[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)

## Features

- Allow taking a block of tensor slice, as long as each dimension is
contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantic.
- More general verification for ops that return shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantic. May submit another
PR later to support similar semantic like `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3d tensor ops are broken.
- Strided accesses are not allowed.
- May cause a little performance slowdown since we are passing strides
as values but not constants (e.g., int).
It would be difficult to pass strides as attributes when we have control
flows. A block argument is possible to accept tensors with different
strides.
This commit is contained in:
Keren Zhou
2022-11-06 22:59:03 -08:00
committed by GitHub
parent a4ff0c362c
commit fdd59900f7
26 changed files with 507 additions and 339 deletions

View File

@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -50,7 +51,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
namespace {
// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
@@ -63,17 +65,17 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
}
// Create a index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
// Create an integer constant of \param width bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value) {
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
short width, int64_t value) {
Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
@@ -369,8 +371,8 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
return linearIndex;
}
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -383,6 +385,50 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
return builder.launch(rewriter, loc, void_ty(ctx));
}
struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object.
// We need to store strides as Values but not integers because the
// extract_slice instruction can take a slice at artibary offsets.
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a to be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
// we store strides into an attribute array of integers, the information
// cannot pass through block argument assignment because attributes are
// associated with operations but not Values.
// TODO(Keren): We may need to figure out a way to store strides as integers
// if we want to support more optimizations.
SmallVector<Value>
strides; // i32 int. The strides of the shared memory object.
SharedMemoryObject(Value base, ArrayRef<Value> strides)
: base(base), strides(strides.begin(), strides.end()) {}
SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
ConversionPatternRewriter &rewriter)
: base(base) {
auto stride = 1;
for (auto dim : llvm::reverse(shape)) {
this->strides.emplace_back(i32_val(stride));
stride *= dim;
}
this->strides = llvm::to_vector<4>(llvm::reverse(this->strides));
}
SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
elems.append(strides.begin(), strides.end());
return elems;
}
SmallVector<Type> getTypes() const {
SmallVector<Type> types;
types.push_back(base.getType());
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
return types;
}
};
struct ConvertTritonGPUOpToLLVMPatternBase {
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct,
@@ -489,6 +535,16 @@ public:
return linear;
}
Value dot(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
assert(offsets.size() == strides.size());
Value ret = idx_val(0);
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
ret = add(ret, mul(offset, stride));
}
return ret;
}
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc,
@@ -671,6 +727,25 @@ public:
return base;
}
static SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
return SharedMemoryObject(/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.end()});
}
static Value
getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
ConversionPatternRewriter &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
return getStructFromElements(loc, elems, rewriter, structTy);
}
protected:
const Allocation *allocation;
Value smem;
@@ -1734,46 +1809,63 @@ struct AllocTensorOpConversion
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value resultVal =
rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
rewriter.replaceOp(op, resultVal);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto smemObj =
SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
struct ExtractSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
: public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = extract_slice %src[%offsets]
Location loc = op->getLoc();
auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
auto srcTy = op.source().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
assert(op.hasUnitStride() &&
"Only unit stride supported by ExtractSliceOpConversion");
// axis > 0 will result in non-contiguous memory access if the result
// tensor is an alias of the source tensor.
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
// Example:
// %dst = extract_slice %src, %index {axis = 0}
// src.shape = [11, 2, 3, 4, 1]
// offset = %index * 2 * 3 * 4 * 1
auto dstTy = op.getType().dyn_cast<RankedTensorType>();
auto base = product<int64_t>(dstTy.getShape());
auto baseVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), base);
Value offset = mul(adaptor.index(), baseVal);
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
rewriter.replaceOp(op, resultVal);
// newBase = base + offset
// Triton support either static and dynamic offsets
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i))
offsetVals.emplace_back(adaptor.offsets()[i]);
else
offsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
}
// Compute the offset based on the original strides of the shared memory
// object
auto offset = dot(rewriter, loc, offsetVals, smemObj.strides);
// newShape = rank_reduce(shape)
// Triton only supports static tensor sizes
SmallVector<Value, 4> strideVals;
auto staticSizes = op.static_sizes();
for (auto i = 0; i < op.static_sizes().size(); ++i) {
if (op.getStaticSize(i) != 1) {
strideVals.emplace_back(smemObj.strides[i]);
}
}
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
auto resTy = op.getType().dyn_cast<RankedTensorType>();
smemObj =
SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
@@ -2262,8 +2354,9 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -2309,6 +2402,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
// TODO: We should get less barriers if it is handled by membar pass
@@ -2369,8 +2464,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
}
}
}
barrier();
rewriter.replaceOp(op, smemBase);
// Barrier is not necessary.
// The membar pass knows that it writes to shared memory and will handle it
// properly.
rewriter.replaceOp(op, retVal);
return success();
}
@@ -2380,9 +2477,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
int elemBytes, ConversionPatternRewriter &rewriter,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, const Location &loc)
: order(order.begin(), order.end()), kOrder(kOrder),
tileShape(tileShape.begin(), tileShape.end()),
@@ -2393,8 +2491,8 @@ public:
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
cTileStride = tileShape[order[1]];
sTileStride = tileShape[order[0]];
cTileStride = smemStrides[order[0]];
sTileStride = smemStrides[order[1]];
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
@@ -2497,8 +2595,7 @@ public:
for (int i = 0; i < numPtr; ++i) {
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
cMatOffI = xor_(cMatOffI, phase);
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)),
mul(sOff, i32_val(sTileStride)));
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
}
return offs;
@@ -2534,7 +2631,7 @@ public:
Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
cOff = urem(cOff, i32_val(tileShape[order[0]]));
sOff = urem(sOff, i32_val(tileShape[order[1]]));
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride)));
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
}
}
return offs;
@@ -2574,7 +2671,7 @@ public:
// To prevent out-of-bound access when tile is too small.
cOff = urem(cOff, i32_val(tileShape[order[0]]));
sOff = urem(sOff, i32_val(tileShape[order[1]]));
offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride)));
offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
}
}
}
@@ -2608,14 +2705,15 @@ public:
Value ptr = getPtr(ptrIdx);
if (canUseLdmatrix) {
int sOffset =
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
Value sOffset =
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
PTXBuilder builder;
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
// thread.
auto resArgs = builder.newListOperand(4, "=r");
auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
->o("trans", needTrans /*predicate*/)
@@ -2640,26 +2738,24 @@ public:
needTrans) { // Use lds.32 to load tf32 matrices
Value ptr2 = getPtr(ptrIdx + 1);
assert(sMatStride == 1);
int sOffsetElem =
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
int sOffsetArrElem = sMatStride * sMatShape;
Value sOffsetArrElemVal =
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
Value elems[4];
Type elemTy = type::f32Ty(ctx);
if (kOrder == 1) {
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
elems[2] =
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
elems[3] =
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
} else {
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
elems[1] =
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
elems[3] =
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
}
return {elems[0], elems[1], elems[2], elems[3]};
@@ -2680,9 +2776,11 @@ public:
};
assert(sMatStride == 1);
int sOffsetElem =
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
int sOffsetArrElem = 1 * (sMatStride * sMatShape);
Value sOffsetArrElemVal =
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
std::array<Value, 4> i8v4Elems;
std::array<Value, 4> i32Elems;
@@ -2692,16 +2790,14 @@ public:
Value i8Elems[4][4];
Type elemTy = type::i8Ty(ctx);
if (kOrder == 1) {
Value offset = i32_val(sOffsetElem);
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 4; ++j)
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
offset = i32_val(sOffsetElem + sOffsetArrElem);
for (int i = 2; i < 4; ++i)
for (int j = 0; j < 4; ++j)
i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
i8Elems[i][j] =
load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < 4; ++e)
@@ -2710,16 +2806,14 @@ public:
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
}
} else { // k first
Value offset = i32_val(sOffsetElem);
for (int j = 0; j < 4; ++j)
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
offset = i32_val(sOffsetElem + sOffsetArrElem);
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < 4; ++e)
@@ -2752,8 +2846,8 @@ private:
int cMatShape;
int sMatShape;
int cTileStride;
int sTileStride;
Value cTileStride;
Value sTileStride;
bool needTrans;
bool canUseLdmatrix;
@@ -2922,12 +3016,12 @@ struct DotOpMmaV1ConversionHelper {
}
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const;
Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const;
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
// Loading $c to registers, returns a LLVM::Struct.
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
@@ -3334,7 +3428,7 @@ struct MMA16816ConversionHelper {
}
// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value tensor, Value llTensor) const {
Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = aTensorTy.getShape();
@@ -3348,7 +3442,7 @@ struct MMA16816ConversionHelper {
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
// load from smem
loadFn = getLoadMatrixFn(
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
@@ -3370,7 +3464,7 @@ struct MMA16816ConversionHelper {
}
// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value tensor, Value llTensor) {
Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
ValueTable hb;
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
@@ -3380,7 +3474,7 @@ struct MMA16816ConversionHelper {
int numRepN = getNumRepN(tensorTy, shape[1]);
auto loadFn = getLoadMatrixFn(
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
@@ -3485,10 +3579,10 @@ struct MMA16816ConversionHelper {
private:
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
ArrayRef<int> matShape, Value warpId,
ValueTable &vals) const {
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
Value warpId, ValueTable &vals) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
// We assumes that the input operand of Dot should be from shared layout.
// TODO(Superjomn) Consider other layouts if needed later.
@@ -3507,10 +3601,10 @@ private:
// (a, b) is the coordinate.
auto load = [=, &vals, &ld2](int a, int b) {
MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
tensorTy.getShape() /*tileShape*/, instrShape,
matShape, perPhase, maxPhase, elemBytes,
rewriter, typeConverter, loc);
MMA16816SmemLoader loader(
wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
maxPhase, elemBytes, rewriter, typeConverter, loc);
SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
const int numPtrs = loader.getNumPtr();
@@ -3519,8 +3613,8 @@ private:
Type smemPtrTy = helper.getShemPtrTy();
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})),
smemPtrTy);
}
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
@@ -3612,6 +3706,7 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
assert(mmaLayout);
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
Value res;
if (mmaLayout.getVersion() == 2) {
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
@@ -3620,21 +3715,21 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, adaptor.src());
res = mmaHelper.loadA(src, smemObj);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, adaptor.src());
res = mmaHelper.loadB(src, smemObj);
}
} else if (mmaLayout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
adaptor.src(), loc, rewriter);
res =
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
adaptor.src(), loc, rewriter);
res =
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
}
} else {
assert(false && "Unsupported mma layout found");
@@ -3671,8 +3766,12 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
loadedA = adaptor.a();
loadedB = adaptor.b();
} else {
loadedA = mmaHelper.loadA(op.a(), adaptor.a());
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
SharedMemoryObject smemA =
getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
SharedMemoryObject smemB =
getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
loadedA = mmaHelper.loadA(op.a(), smemA);
loadedB = mmaHelper.loadB(op.b(), smemB);
}
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
@@ -3797,8 +3896,12 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
}
Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
// smem
Value smem = smemObj.base;
auto strides = smemObj.strides;
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
@@ -3818,10 +3921,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
int vecA = sharedLayout.getVec();
int strideAM = isARow ? shape[1] : 1;
int strideAK = isARow ? 1 : shape[0];
int strideA0 = isARow ? strideAK : strideAM;
int strideA1 = isARow ? strideAM : strideAK;
Value strideAM = isARow ? strides[0] : i32_val(1);
Value strideAK = isARow ? i32_val(1) : strides[1];
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1;
@@ -3847,8 +3950,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
offA0I = udiv(offA0I, i32_val(vecA));
offA0I = xor_(offA0I, phaseA);
offA0I = xor_(offA0I, i32_val(vecA));
offA[i] =
add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
}
Type f16x2Ty = vec_ty(f16_ty, 2);
@@ -3877,8 +3979,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
int stepAM = isARow ? m : m / numPtrA * numPtrA;
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
Value pa = gep(f16PtrTy, thePtrA,
i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
mul(i32_val(stepAK), strideAK));
Value pa = gep(f16PtrTy, thePtrA, offset);
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
Value ha = load(bitcast(pa, aPtrTy));
// record lds that needs to be moved
@@ -3915,8 +4018,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
Value DotOpMmaV1ConversionHelper::loadB(
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
// smem
Value smem = smemObj.base;
auto strides = smemObj.strides;
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
@@ -3929,10 +4036,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
int vecB = sharedLayout.getVec();
int strideBN = isBRow ? 1 : shape[0];
int strideBK = isBRow ? shape[1] : 1;
int strideB0 = isBRow ? strideBN : strideBK;
int strideB1 = isBRow ? strideBK : strideBN;
Value strideBN = isBRow ? i32_val(1) : strides[1];
Value strideBK = isBRow ? strides[0] : i32_val(1);
Value strideB0 = isBRow ? strideBN : strideBK;
Value strideB1 = isBRow ? strideBK : strideBN;
int strideRepN = wpt[1] * fpw[1] * 8;
int strideRepK = 1;
@@ -3957,8 +4064,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
offB0I = udiv(offB0I, i32_val(vecB));
offB0I = xor_(offB0I, phaseB);
offB0I = mul(offB0I, i32_val(vecB));
offB[i] =
add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
}
Type f16PtrTy = ptr_ty(f16_ty);
@@ -3979,8 +4085,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
Value pb = gep(f16PtrTy, thePtrB,
i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
mul(i32_val(stepBK), strideBK));
Value pb = gep(f16PtrTy, thePtrB, offset);
Value hb =
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
// record lds that needs to be moved
@@ -4171,7 +4278,17 @@ public:
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
for (auto i = 0; i < rank; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
@@ -4309,15 +4426,26 @@ struct InsertSliceAsyncOpConversion
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
// %dst
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
auto dstBase = createIndexAttrConstant(rewriter, loc,
getTypeConverter()->getIndexType(),
product<int64_t>(srcTy.getShape()));
Value offset = mul(llIndex, dstBase);
auto dstPtrTy = LLVM::LLVMPointerType::get(
getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, llDst, offset);
SmallVector<Value, 4> offsetVals;
SmallVector<Value, 4> srcStrides;
for (auto i = 0; i < dstShape.size(); ++i) {
if (i == axis) {
offsetVals.emplace_back(llIndex);
} else {
offsetVals.emplace_back(i32_val(0));
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original dimensions of the shared memory
// object
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
auto dstPtrTy =
ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
// %mask
SmallVector<Value> maskElems;
@@ -4345,11 +4473,10 @@ struct InsertSliceAsyncOpConversion
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
// swizzle is not allowd
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
@@ -4377,7 +4504,6 @@ struct InsertSliceAsyncOpConversion
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
@@ -4399,8 +4525,10 @@ struct InsertSliceAsyncOpConversion
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
Value rowOffset =
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
// srcShape and smemObj.shape maybe different if smemObj is a
// slice of the original shared memory object.
// So we need to use the original shape to compute the offset
Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
@@ -4420,21 +4548,25 @@ struct InsertSliceAsyncOpConversion
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
// XXX(Keren): Tune CG and CA here.
// Tune CG and CA here.
auto byteWidth = bitWidth / 8;
CacheModifier srcCacheModifier =
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
Value baseOffset =
add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
i32_val(baseOffsetCol));
Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto wordElemIdx = wordIdx * numWordElems;
auto &copyAsyncOp =
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
auto *dstOperand = ptxBuilder.newAddrOperand(
tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
auto *dstOperand =
ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
auto *srcOperand =
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);