[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features
- Allow taking a block slice of a tensor, as long as each dimension is contiguous (unit stride); see the example sketched below.
- Fix some problems in `insert_slice_async`'s semantics.
- More general verification for ops that return a shared layout encoding.

## Known Limitations
- `insert_slice_async` still uses the old semantics. A follow-up PR may give it semantics similar to `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed.
- May cause a small performance slowdown, since strides are now passed as values rather than as constants (e.g., `int`). Passing strides as attributes would be difficult in the presence of control flow, whereas a block argument can accept tensors with different strides.
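For illustration, a minimal sketch of the slicing form this enables (the shapes, offsets, and element type are hypothetical, and the shared-layout encoding is omitted for brevity):

```mlir
// Hypothetical example: extract a contiguous (unit-stride) 16x16 block from a
// 32x32 tensor. Offsets may be dynamic SSA values; sizes and strides are
// static, and every stride must be 1.
%tile = tensor.extract_slice %src[%row_off, %col_off] [16, 16] [1, 1]
    : tensor<32x32xf16> to tensor<16x16xf16>
```

During lowering, the static or dynamic offsets are multiplied by the strides carried in the `SharedMemoryObject` to compute the new shared-memory base pointer, which is why strides have to be passed as values.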
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -50,7 +51,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
namespace {
// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
@@ -63,17 +65,17 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
}
// Create an index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
// Create an integer constant of \param width bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value) {
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
short width, int64_t value) {
Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
@@ -369,8 +371,8 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
return linearIndex;
}
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -383,6 +385,50 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
return builder.launch(rewriter, loc, void_ty(ctx));
}
struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object.
// We need to store strides as Values but not integers because the
// extract_slice instruction can take a slice at arbitrary offsets.
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
// we store strides into an attribute array of integers, the information
// cannot pass through block argument assignment because attributes are
// associated with operations but not Values.
// TODO(Keren): We may need to figure out a way to store strides as integers
// if we want to support more optimizations.
SmallVector<Value>
strides; // i32 int. The strides of the shared memory object.
SharedMemoryObject(Value base, ArrayRef<Value> strides)
: base(base), strides(strides.begin(), strides.end()) {}
SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
ConversionPatternRewriter &rewriter)
: base(base) {
auto stride = 1;
for (auto dim : llvm::reverse(shape)) {
this->strides.emplace_back(i32_val(stride));
stride *= dim;
}
this->strides = llvm::to_vector<4>(llvm::reverse(this->strides));
}
SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
elems.append(strides.begin(), strides.end());
return elems;
}
SmallVector<Type> getTypes() const {
SmallVector<Type> types;
types.push_back(base.getType());
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
return types;
}
};
struct ConvertTritonGPUOpToLLVMPatternBase {
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct,
@@ -489,6 +535,16 @@ public:
return linear;
}
Value dot(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
assert(offsets.size() == strides.size());
Value ret = idx_val(0);
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
ret = add(ret, mul(offset, stride));
}
return ret;
}
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc,
@@ -671,6 +727,25 @@ public:
return base;
}
static SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
return SharedMemoryObject(/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.end()});
}
static Value
getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
ConversionPatternRewriter &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
return getStructFromElements(loc, elems, rewriter, structTy);
}
protected:
const Allocation *allocation;
Value smem;
@@ -1734,46 +1809,63 @@ struct AllocTensorOpConversion
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto llvmElemTy =
getTypeConverter()->convertType(resultTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value resultVal =
rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
rewriter.replaceOp(op, resultVal);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto smemObj =
SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
struct ExtractSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
: public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = extract_slice %src[%offsets]
Location loc = op->getLoc();
auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
auto srcTy = op.source().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
assert(op.hasUnitStride() &&
"Only unit stride supported by ExtractSliceOpConversion");
// axis > 0 will result in non-contiguous memory access if the result
// tensor is an alias of the source tensor.
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
// Example:
// %dst = extract_slice %src, %index {axis = 0}
// src.shape = [11, 2, 3, 4, 1]
// offset = %index * 2 * 3 * 4 * 1
auto dstTy = op.getType().dyn_cast<RankedTensorType>();
auto base = product<int64_t>(dstTy.getShape());
auto baseVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), base);
Value offset = mul(adaptor.index(), baseVal);
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
rewriter.replaceOp(op, resultVal);
// newBase = base + offset
// Triton supports both static and dynamic offsets
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
SmallVector<Value, 4> offsetVals;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i))
offsetVals.emplace_back(adaptor.offsets()[i]);
else
offsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
}
// Compute the offset based on the original strides of the shared memory
// object
auto offset = dot(rewriter, loc, offsetVals, smemObj.strides);
// newShape = rank_reduce(shape)
// Triton only supports static tensor sizes
SmallVector<Value, 4> strideVals;
auto staticSizes = op.static_sizes();
for (auto i = 0; i < op.static_sizes().size(); ++i) {
if (op.getStaticSize(i) != 1) {
strideVals.emplace_back(smemObj.strides[i]);
}
}
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
auto resTy = op.getType().dyn_cast<RankedTensorType>();
smemObj =
SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
@@ -2262,8 +2354,9 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
@@ -2309,6 +2402,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
SmallVector<Value> wordVecs(numWordsEachRep);
// TODO: We should get less barriers if it is handled by membar pass
@@ -2369,8 +2464,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
}
}
}
barrier();
rewriter.replaceOp(op, smemBase);
// Barrier is not necessary.
// The membar pass knows that it writes to shared memory and will handle it
// properly.
rewriter.replaceOp(op, retVal);
return success();
}
@@ -2380,9 +2477,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
int elemBytes, ConversionPatternRewriter &rewriter,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, const Location &loc)
: order(order.begin(), order.end()), kOrder(kOrder),
tileShape(tileShape.begin(), tileShape.end()),
@@ -2393,8 +2491,8 @@ public:
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
cTileStride = tileShape[order[1]];
sTileStride = tileShape[order[0]];
cTileStride = smemStrides[order[0]];
sTileStride = smemStrides[order[1]];
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
@@ -2497,8 +2595,7 @@ public:
for (int i = 0; i < numPtr; ++i) {
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
cMatOffI = xor_(cMatOffI, phase);
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)),
mul(sOff, i32_val(sTileStride)));
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
}
return offs;
@@ -2534,7 +2631,7 @@ public:
Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
cOff = urem(cOff, i32_val(tileShape[order[0]]));
sOff = urem(sOff, i32_val(tileShape[order[1]]));
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride)));
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
}
}
return offs;
@@ -2574,7 +2671,7 @@ public:
// To prevent out-of-bound access when tile is too small.
cOff = urem(cOff, i32_val(tileShape[order[0]]));
sOff = urem(sOff, i32_val(tileShape[order[1]]));
offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride)));
offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
}
}
}
@@ -2608,14 +2705,15 @@ public:
Value ptr = getPtr(ptrIdx);
if (canUseLdmatrix) {
int sOffset =
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
Value sOffset =
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
PTXBuilder builder;
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
// thread.
auto resArgs = builder.newListOperand(4, "=r");
auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
->o("trans", needTrans /*predicate*/)
@@ -2640,26 +2738,24 @@ public:
needTrans) { // Use lds.32 to load tf32 matrices
Value ptr2 = getPtr(ptrIdx + 1);
assert(sMatStride == 1);
int sOffsetElem =
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
int sOffsetArrElem = sMatStride * sMatShape;
Value sOffsetArrElemVal =
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
Value elems[4];
Type elemTy = type::f32Ty(ctx);
if (kOrder == 1) {
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
elems[2] =
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
elems[3] =
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
} else {
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
elems[1] =
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
elems[3] =
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
}
return {elems[0], elems[1], elems[2], elems[3]};
@@ -2680,9 +2776,11 @@ public:
};
assert(sMatStride == 1);
int sOffsetElem =
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
int sOffsetArrElem = 1 * (sMatStride * sMatShape);
Value sOffsetArrElemVal =
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
std::array<Value, 4> i8v4Elems;
std::array<Value, 4> i32Elems;
@@ -2692,16 +2790,14 @@ public:
Value i8Elems[4][4];
Type elemTy = type::i8Ty(ctx);
if (kOrder == 1) {
Value offset = i32_val(sOffsetElem);
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 4; ++j)
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
offset = i32_val(sOffsetElem + sOffsetArrElem);
for (int i = 2; i < 4; ++i)
for (int j = 0; j < 4; ++j)
i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
i8Elems[i][j] =
load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < 4; ++e)
@@ -2710,16 +2806,14 @@ public:
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
}
} else { // k first
Value offset = i32_val(sOffsetElem);
for (int j = 0; j < 4; ++j)
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
offset = i32_val(sOffsetElem + sOffsetArrElem);
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
for (int j = 0; j < 4; ++j)
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < 4; ++e)
@@ -2752,8 +2846,8 @@ private:
int cMatShape;
int sMatShape;
int cTileStride;
int sTileStride;
Value cTileStride;
Value sTileStride;
bool needTrans;
bool canUseLdmatrix;
@@ -2922,12 +3016,12 @@ struct DotOpMmaV1ConversionHelper {
}
// Loading $a from smem to registers, returns an LLVM::Struct.
Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const;
Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
// Loading $b from smem to registers, returns an LLVM::Struct.
Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
ConversionPatternRewriter &rewriter) const;
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
// Loading $c to registers, returns an LLVM::Struct.
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
@@ -3334,7 +3428,7 @@ struct MMA16816ConversionHelper {
}
// Loading $a from smem to registers, returns an LLVM::Struct.
Value loadA(Value tensor, Value llTensor) const {
Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = aTensorTy.getShape();
@@ -3348,7 +3442,7 @@ struct MMA16816ConversionHelper {
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
// load from smem
loadFn = getLoadMatrixFn(
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
@@ -3370,7 +3464,7 @@ struct MMA16816ConversionHelper {
}
// Loading $b from smem to registers, returns an LLVM::Struct.
Value loadB(Value tensor, Value llTensor) {
Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
ValueTable hb;
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
@@ -3380,7 +3474,7 @@ struct MMA16816ConversionHelper {
int numRepN = getNumRepN(tensorTy, shape[1]);
auto loadFn = getLoadMatrixFn(
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
@@ -3485,10 +3579,10 @@ struct MMA16816ConversionHelper {
private:
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
ArrayRef<int> matShape, Value warpId,
ValueTable &vals) const {
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
Value warpId, ValueTable &vals) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
// We assume that the input operand of Dot comes from a shared layout.
// TODO(Superjomn) Consider other layouts if needed later.
@@ -3507,10 +3601,10 @@ private:
// (a, b) is the coordinate.
auto load = [=, &vals, &ld2](int a, int b) {
MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
tensorTy.getShape() /*tileShape*/, instrShape,
matShape, perPhase, maxPhase, elemBytes,
rewriter, typeConverter, loc);
MMA16816SmemLoader loader(
wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
maxPhase, elemBytes, rewriter, typeConverter, loc);
SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
const int numPtrs = loader.getNumPtr();
@@ -3519,8 +3613,8 @@ private:
Type smemPtrTy = helper.getShemPtrTy();
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})),
smemPtrTy);
}
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
@@ -3612,6 +3706,7 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
assert(mmaLayout);
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
Value res;
if (mmaLayout.getVersion() == 2) {
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
@@ -3620,21 +3715,21 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, adaptor.src());
res = mmaHelper.loadA(src, smemObj);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, adaptor.src());
res = mmaHelper.loadB(src, smemObj);
}
} else if (mmaLayout.getVersion() == 1) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
adaptor.src(), loc, rewriter);
res =
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
adaptor.src(), loc, rewriter);
res =
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
}
} else {
assert(false && "Unsupported mma layout found");
@@ -3671,8 +3766,12 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
loadedA = adaptor.a();
loadedB = adaptor.b();
} else {
loadedA = mmaHelper.loadA(op.a(), adaptor.a());
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
SharedMemoryObject smemA =
getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
SharedMemoryObject smemB =
getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
loadedA = mmaHelper.loadA(op.a(), smemA);
loadedB = mmaHelper.loadB(op.b(), smemB);
}
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
@@ -3797,8 +3896,12 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
}
Value DotOpMmaV1ConversionHelper::loadA(
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
// smem
Value smem = smemObj.base;
auto strides = smemObj.strides;
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
@@ -3818,10 +3921,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
int vecA = sharedLayout.getVec();
int strideAM = isARow ? shape[1] : 1;
int strideAK = isARow ? 1 : shape[0];
int strideA0 = isARow ? strideAK : strideAM;
int strideA1 = isARow ? strideAM : strideAK;
Value strideAM = isARow ? strides[0] : i32_val(1);
Value strideAK = isARow ? i32_val(1) : strides[1];
Value strideA0 = isARow ? strideAK : strideAM;
Value strideA1 = isARow ? strideAM : strideAK;
int strideRepM = wpt[0] * fpw[0] * 8;
int strideRepK = 1;
@@ -3847,8 +3950,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
offA0I = udiv(offA0I, i32_val(vecA));
offA0I = xor_(offA0I, phaseA);
offA0I = xor_(offA0I, i32_val(vecA));
offA[i] =
add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
}
Type f16x2Ty = vec_ty(f16_ty, 2);
@@ -3877,8 +3979,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
int stepAM = isARow ? m : m / numPtrA * numPtrA;
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
Value pa = gep(f16PtrTy, thePtrA,
i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
mul(i32_val(stepAK), strideAK));
Value pa = gep(f16PtrTy, thePtrA, offset);
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
Value ha = load(bitcast(pa, aPtrTy));
// record lds that needs to be moved
@@ -3915,8 +4018,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
}
Value DotOpMmaV1ConversionHelper::loadB(
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
ConversionPatternRewriter &rewriter) const {
// smem
Value smem = smemObj.base;
auto strides = smemObj.strides;
auto *ctx = rewriter.getContext();
auto tensorTy = tensor.getType().cast<RankedTensorType>();
auto shape = tensorTy.getShape();
@@ -3929,10 +4036,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
int vecB = sharedLayout.getVec();
int strideBN = isBRow ? 1 : shape[0];
int strideBK = isBRow ? shape[1] : 1;
int strideB0 = isBRow ? strideBN : strideBK;
int strideB1 = isBRow ? strideBK : strideBN;
Value strideBN = isBRow ? i32_val(1) : strides[1];
Value strideBK = isBRow ? strides[0] : i32_val(1);
Value strideB0 = isBRow ? strideBN : strideBK;
Value strideB1 = isBRow ? strideBK : strideBN;
int strideRepN = wpt[1] * fpw[1] * 8;
int strideRepK = 1;
@@ -3957,8 +4064,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
offB0I = udiv(offB0I, i32_val(vecB));
offB0I = xor_(offB0I, phaseB);
offB0I = mul(offB0I, i32_val(vecB));
offB[i] =
add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
}
Type f16PtrTy = ptr_ty(f16_ty);
@@ -3979,8 +4085,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
Value pb = gep(f16PtrTy, thePtrB,
i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
mul(i32_val(stepBK), strideBK));
Value pb = gep(f16PtrTy, thePtrB, offset);
Value hb =
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
// record lds that needs to be moved
@@ -4171,7 +4278,17 @@ public:
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
for (auto i = 0; i < rank; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
@@ -4309,15 +4426,26 @@ struct InsertSliceAsyncOpConversion
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
// %dst
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
auto dstBase = createIndexAttrConstant(rewriter, loc,
getTypeConverter()->getIndexType(),
product<int64_t>(srcTy.getShape()));
Value offset = mul(llIndex, dstBase);
auto dstPtrTy = LLVM::LLVMPointerType::get(
getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, llDst, offset);
SmallVector<Value, 4> offsetVals;
SmallVector<Value, 4> srcStrides;
for (auto i = 0; i < dstShape.size(); ++i) {
if (i == axis) {
offsetVals.emplace_back(llIndex);
} else {
offsetVals.emplace_back(i32_val(0));
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original dimensions of the shared memory
// object
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
auto dstPtrTy =
ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
// %mask
SmallVector<Value> maskElems;
@@ -4345,11 +4473,10 @@ struct InsertSliceAsyncOpConversion
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
// swizzle is not allowed
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
@@ -4377,7 +4504,6 @@ struct InsertSliceAsyncOpConversion
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
@@ -4399,8 +4525,10 @@ struct InsertSliceAsyncOpConversion
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
Value rowOffset =
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
// srcShape and smemObj.shape may be different if smemObj is a
// slice of the original shared memory object.
// So we need to use the original shape to compute the offset
Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
@@ -4420,21 +4548,25 @@ struct InsertSliceAsyncOpConversion
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
// XXX(Keren): Tune CG and CA here.
// Tune CG and CA here.
auto byteWidth = bitWidth / 8;
CacheModifier srcCacheModifier =
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
Value baseOffset =
add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
i32_val(baseOffsetCol));
Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto wordElemIdx = wordIdx * numWordElems;
auto &copyAsyncOp =
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
auto *dstOperand = ptxBuilder.newAddrOperand(
tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
auto *dstOperand =
ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
auto *srcOperand =
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);