[Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features

- Allow taking a block of a tensor slice, as long as each dimension is contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantics.
- More general verification for ops that return a shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantics. A follow-up PR may align it with `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3D tensor ops are broken.
- Strided accesses are not allowed.
- May cause a slight performance slowdown, since strides are now passed as values rather than constants (e.g., `int`). Passing strides as attributes would be difficult in the presence of control flow, whereas a block argument can accept tensors with different strides.
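For orientation, here is a minimal sketch (not part of this diff) of how a unit-stride slice is now built with the upstream op. It mirrors the loop-pipeliner changes further down; `builder`, `loc`, `buffer`, `iterIdx` (an index-typed value), and `sliceType` (a `RankedTensorType`) are illustrative names assumed to already exist in the surrounding code:

```cpp
// Hedged sketch of the new slicing path: a 3-D shared buffer is sliced at a
// dynamic offset along dim 0, with unit strides everywhere (the only case
// this PR supports). All names here are illustrative.
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
Value slice = builder.create<tensor::ExtractSliceOp>(
    loc, sliceType, buffer,
    /*offsets=*/SmallVector<OpFoldResult>{iterIdx, intAttr(0), intAttr(0)},
    /*sizes=*/SmallVector<OpFoldResult>{intAttr(1),
                                        intAttr(sliceType.getShape()[0]),
                                        intAttr(sliceType.getShape()[1])},
    /*strides=*/SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
```

Offsets may be static attributes or dynamic values, but sizes must be static and every stride must be 1, matching the limitations listed above.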
@@ -1,4 +1,5 @@
#include "triton/Analysis/Alias.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

@@ -24,18 +25,18 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (maybeSharedAllocationOp(op)) {
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
if (isSharedEncoding(result)) {
// FIXME(Keren): extract and insert are always alias for now
if (auto extractSliceOp = dyn_cast<triton::gpu::ExtractSliceOp>(op)) {
// extract_slice %src, %index
aliasInfo = AliasInfo(operands[0]->getValue());
} else if (auto insertSliceOp =
dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
// insert_slice_async %src, %dst, %index
aliasInfo = AliasInfo(operands[1]->getValue());
} else {
aliasInfo.insert(result);
}
// FIXME(Keren): extract and insert are always alias for now
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
// extract_slice %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (auto insertSliceOp =
dyn_cast<triton::gpu::InsertSliceAsyncOp>(op)) {
// insert_slice_async %src, %dst, %index
aliasInfo = AliasInfo(operands[1]->getValue());
pessimistic = false;
} else if (isSharedEncoding(result)) {
aliasInfo.insert(result);
pessimistic = false;
}
}
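In short: after this change the result of `tensor.extract_slice` is recorded as an alias of its source, the result of `triton_gpu.insert_slice_async` as an alias of its destination, and any other shared-encoding result gets a fresh entry; setting `pessimistic = false` in each branch presumably keeps the analysis from falling back to its pessimistic fixpoint for values it now understands.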
@@ -1,6 +1,7 @@
|
||||
#include "triton/Analysis/Allocation.h"
|
||||
#include "mlir/Analysis/Liveness.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "triton/Analysis/Alias.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
@@ -76,13 +77,13 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto axis = op.axis();
|
||||
|
||||
bool fast_reduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
|
||||
bool fastReduce = axis == 1; // FIXME(Qingyi): The fastest-changing dimension
|
||||
|
||||
SmallVector<unsigned> smemShape;
|
||||
for (auto d : srcShape)
|
||||
smemShape.push_back(d);
|
||||
|
||||
if (fast_reduce) {
|
||||
if (fastReduce) {
|
||||
unsigned sizeInterWarps = srcLayout.getWarpsPerCTA()[axis];
|
||||
smemShape[axis] = sizeInterWarps;
|
||||
} else {
|
||||
@@ -123,7 +124,7 @@ private:
|
||||
// For example: %a = scf.if -> yield
|
||||
// %a must be allocated elsewhere by other operations.
|
||||
// FIXME(Keren): extract and insert are always alias for now
|
||||
if (!maybeSharedAllocationOp(op) || isa<triton::gpu::ExtractSliceOp>(op) ||
|
||||
if (!maybeSharedAllocationOp(op) || isa<tensor::ExtractSliceOp>(op) ||
|
||||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
|
||||
return;
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -43,8 +44,7 @@ void MembarAnalysis::dfsOperation(Operation *operation,
|
||||
void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||
OpBuilder *builder) {
|
||||
if (isa<scf::ForOp>(op) || isa<scf::IfOp>(op) || isa<scf::YieldOp>(op) ||
|
||||
isa<triton::gpu::ExtractSliceOp>(op) ||
|
||||
isa<triton::gpu::AllocTensorOp>(op)) {
|
||||
isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op)) {
|
||||
// Do not insert barriers before control flow operations and
|
||||
// alloc/extract/insert
|
||||
// alloc is an allocation op without memory write.
|
||||
|
@@ -24,7 +24,8 @@ bool maybeSharedAllocationOp(Operation *op) {
|
||||
mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
|
||||
dialect->getTypeID() ==
|
||||
mlir::TypeID::get<arith::ArithmeticDialect>());
|
||||
mlir::TypeID::get<arith::ArithmeticDialect>() ||
|
||||
dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
|
||||
}
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state) {
|
||||
|
@@ -11,6 +11,7 @@
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -50,7 +51,8 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
||||
namespace {
|
||||
|
||||
// Create a 32-bit integer constant.
|
||||
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
|
||||
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
|
||||
int32_t v) {
|
||||
auto i32ty = rewriter.getIntegerType(32);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
|
||||
IntegerAttr::get(i32ty, v));
|
||||
@@ -63,17 +65,17 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
|
||||
}
|
||||
|
||||
// Create a index type constant.
|
||||
Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
|
||||
TypeConverter *converter, int64_t value) {
|
||||
TypeConverter *converter, int64_t value) {
|
||||
Type ty = converter->convertType(builder.getIndexType());
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
}
|
||||
|
||||
// Create an integer constant of \param width bits.
|
||||
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
int64_t value) {
|
||||
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
|
||||
short width, int64_t value) {
|
||||
Type ty = builder.getIntegerType(width);
|
||||
return builder.create<LLVM::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, value));
|
||||
@@ -369,8 +371,8 @@ static T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape) {
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred) {
|
||||
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value ptr, Value val, Value pred) {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||
@@ -383,6 +385,50 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||
}
|
||||
|
||||
struct SharedMemoryObject {
|
||||
Value base; // i32 ptr. The start address of the shared memory object.
|
||||
// We need to store strides as Values but not integers because the
|
||||
// extract_slice instruction can take a slice at arbitrary offsets.
|
||||
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
|
||||
// 32, we need to let the instruction that uses $a to be aware of that.
|
||||
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
|
||||
// we store strides into an attribute array of integers, the information
|
||||
// cannot pass through block argument assignment because attributes are
|
||||
// associated with operations but not Values.
|
||||
// TODO(Keren): We may need to figure out a way to store strides as integers
|
||||
// if we want to support more optimizations.
|
||||
SmallVector<Value>
|
||||
strides; // i32 int. The strides of the shared memory object.
|
||||
|
||||
SharedMemoryObject(Value base, ArrayRef<Value> strides)
|
||||
: base(base), strides(strides.begin(), strides.end()) {}
|
||||
|
||||
SharedMemoryObject(Value base, ArrayRef<int64_t> shape, Location loc,
|
||||
ConversionPatternRewriter &rewriter)
|
||||
: base(base) {
|
||||
auto stride = 1;
|
||||
for (auto dim : llvm::reverse(shape)) {
|
||||
this->strides.emplace_back(i32_val(stride));
|
||||
stride *= dim;
|
||||
}
|
||||
this->strides = llvm::to_vector<4>(llvm::reverse(this->strides));
|
||||
}
|
||||
|
||||
SmallVector<Value> getElems() const {
|
||||
SmallVector<Value> elems;
|
||||
elems.push_back(base);
|
||||
elems.append(strides.begin(), strides.end());
|
||||
return elems;
|
||||
}
|
||||
|
||||
SmallVector<Type> getTypes() const {
|
||||
SmallVector<Type> types;
|
||||
types.push_back(base.getType());
|
||||
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
|
||||
return types;
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||
static SmallVector<Value>
|
||||
getElementsFromStruct(Location loc, Value llvmStruct,
|
||||
@@ -489,6 +535,16 @@ public:
|
||||
return linear;
|
||||
}
|
||||
|
||||
Value dot(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
|
||||
assert(offsets.size() == strides.size());
|
||||
Value ret = idx_val(0);
|
||||
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
|
||||
ret = add(ret, mul(offset, stride));
|
||||
}
|
||||
return ret;
|
||||
}
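As a quick check of the linearization above: with row-major strides {32, 1} (e.g., a 16x32 shared buffer) and offsets {2, 3}, `dot` returns 2*32 + 3*1 = 67, i.e., the element offset of the slice origin from the shared-memory base.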
|
||||
|
||||
// Get an index-base for each dimension for a \param blocked_layout.
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForBlockedLayout(Location loc,
|
||||
@@ -671,6 +727,25 @@ public:
|
||||
return base;
|
||||
}
|
||||
|
||||
static SharedMemoryObject
|
||||
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
|
||||
return SharedMemoryObject(/*base=*/elems[0],
|
||||
/*strides=*/{elems.begin() + 1, elems.end()});
|
||||
}
|
||||
|
||||
static Value
|
||||
getStructFromSharedMemoryObject(Location loc,
|
||||
const SharedMemoryObject &smemObj,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto elems = smemObj.getElems();
|
||||
auto types = smemObj.getTypes();
|
||||
auto structTy =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||
return getStructFromElements(loc, elems, rewriter, structTy);
|
||||
}
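For illustration, a hedged usage sketch (not from the diff) of the pack/unpack round trip inside a pattern that inherits these helpers; `smemBase`, `loc`, and `rewriter` are assumed to exist, and the 64x32 shape is just an example. For a 2-D buffer the lowered value is an LLVM struct `{ptr<elem, 3>, i32, i32}`:

```cpp
// A 64x32 row-major shared buffer is described by its base pointer plus
// strides {32, 1}, packed into one LLVM struct value and unpacked again.
auto smemObj = SharedMemoryObject(smemBase, /*shape=*/{64, 32}, loc, rewriter);
Value packed = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
// ...the struct value can now cross pattern boundaries or block arguments...
SharedMemoryObject unpacked =
    getSharedMemoryObjectFromStruct(loc, packed, rewriter);
```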
|
||||
|
||||
protected:
|
||||
const Allocation *allocation;
|
||||
Value smem;
|
||||
@@ -1734,46 +1809,63 @@ struct AllocTensorOpConversion
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
auto llvmElemTy =
|
||||
getTypeConverter()->convertType(resultTy.getElementType());
|
||||
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
Value resultVal =
|
||||
rewriter.create<LLVM::BitcastOp>(loc, elemPtrTy, smemBase);
|
||||
rewriter.replaceOp(op, resultVal);
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto smemObj =
|
||||
SharedMemoryObject(smemBase, resultTy.getShape(), loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractSliceOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
|
||||
: public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// %dst = extract_slice %src[%offsets]
|
||||
Location loc = op->getLoc();
|
||||
auto srcTy = op.src().getType().dyn_cast<RankedTensorType>();
|
||||
auto srcTy = op.source().getType().dyn_cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
|
||||
assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
|
||||
assert(op.hasUnitStride() &&
|
||||
"Only unit stride supported by ExtractSliceOpConversion");
|
||||
|
||||
// axis > 0 will result in non-contiguous memory access if the result
|
||||
// tensor is an alias of the source tensor.
|
||||
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
||||
assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
|
||||
|
||||
// Example:
|
||||
// %dst = extract_slice %src, %index {axis = 0}
|
||||
// src.shape = [11, 2, 3, 4, 1]
|
||||
// offset = %index * 2 * 3 * 4 * 1
|
||||
auto dstTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
auto base = product<int64_t>(dstTy.getShape());
|
||||
auto baseVal = createIndexAttrConstant(
|
||||
rewriter, loc, getTypeConverter()->getIndexType(), base);
|
||||
Value offset = mul(adaptor.index(), baseVal);
|
||||
|
||||
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
||||
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
|
||||
Value resultVal = gep(elemPtrTy, adaptor.src(), offset);
|
||||
rewriter.replaceOp(op, resultVal);
|
||||
// newBase = base + offset
|
||||
// Triton supports both static and dynamic offsets
|
||||
auto smemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.source(), rewriter);
|
||||
SmallVector<Value, 4> offsetVals;
|
||||
auto mixedOffsets = op.getMixedOffsets();
|
||||
for (auto i = 0; i < mixedOffsets.size(); ++i) {
|
||||
if (op.isDynamicOffset(i))
|
||||
offsetVals.emplace_back(adaptor.offsets()[i]);
|
||||
else
|
||||
offsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
|
||||
}
|
||||
// Compute the offset based on the original strides of the shared memory
|
||||
// object
|
||||
auto offset = dot(rewriter, loc, offsetVals, smemObj.strides);
|
||||
// newShape = rank_reduce(shape)
|
||||
// Triton only supports static tensor sizes
|
||||
SmallVector<Value, 4> strideVals;
|
||||
auto staticSizes = op.static_sizes();
|
||||
for (auto i = 0; i < op.static_sizes().size(); ++i) {
|
||||
if (op.getStaticSize(i) != 1) {
|
||||
strideVals.emplace_back(smemObj.strides[i]);
|
||||
}
|
||||
}
|
||||
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
auto resTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
smemObj =
|
||||
SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset), strideVals);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
};
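Concretely, reading the pattern above: a unit-stride `tensor.extract_slice %src[%i, 0, 0] [1, M, N] [1, 1, 1]` lowers to a new `SharedMemoryObject` whose base is `src.base + %i * src.strides[0]` and whose stride list keeps only the strides of the non-size-1 dimensions (here `{src.strides[1], src.strides[2]}`); no data is moved, which is why the alias analysis above treats the result as an alias of the source.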
|
||||
@@ -2262,8 +2354,9 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
Value src = op.src();
|
||||
Value dst = op.result();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
assert(srcShape.size() == 2 &&
|
||||
"Unexpected rank of ConvertLayout(blocked->shared)");
|
||||
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
@@ -2309,6 +2402,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
|
||||
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
|
||||
smemBase = bitcast(smemBase, elemPtrTy);
|
||||
auto smemObj = SharedMemoryObject(smemBase, dstShape, loc, rewriter);
|
||||
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
|
||||
unsigned numWordsEachRep = product<unsigned>(wordsInEachRep);
|
||||
SmallVector<Value> wordVecs(numWordsEachRep);
|
||||
// TODO: We should get less barriers if it is handled by membar pass
|
||||
@@ -2369,8 +2464,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
rewriter.replaceOp(op, smemBase);
|
||||
// Barrier is not necessary.
|
||||
// The membar pass knows that it writes to shared memory and will handle it
|
||||
// properly.
|
||||
rewriter.replaceOp(op, retVal);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -2380,9 +2477,10 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
|
||||
class MMA16816SmemLoader {
|
||||
public:
|
||||
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase,
|
||||
int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape,
|
||||
int perPhase, int maxPhase, int elemBytes,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
TypeConverter *typeConverter, const Location &loc)
|
||||
: order(order.begin(), order.end()), kOrder(kOrder),
|
||||
tileShape(tileShape.begin(), tileShape.end()),
|
||||
@@ -2393,8 +2491,8 @@ public:
|
||||
cMatShape = matShape[order[0]];
|
||||
sMatShape = matShape[order[1]];
|
||||
|
||||
cTileStride = tileShape[order[1]];
|
||||
sTileStride = tileShape[order[0]];
|
||||
cTileStride = smemStrides[order[0]];
|
||||
sTileStride = smemStrides[order[1]];
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
needTrans = kOrder != order[0];
|
||||
@@ -2497,8 +2595,7 @@ public:
|
||||
for (int i = 0; i < numPtr; ++i) {
|
||||
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
|
||||
cMatOffI = xor_(cMatOffI, phase);
|
||||
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)),
|
||||
mul(sOff, i32_val(sTileStride)));
|
||||
offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
|
||||
}
|
||||
|
||||
return offs;
|
||||
@@ -2534,7 +2631,7 @@ public:
|
||||
Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
|
||||
cOff = urem(cOff, i32_val(tileShape[order[0]]));
|
||||
sOff = urem(sOff, i32_val(tileShape[order[1]]));
|
||||
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride)));
|
||||
offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
|
||||
}
|
||||
}
|
||||
return offs;
|
||||
@@ -2574,7 +2671,7 @@ public:
|
||||
// To prevent out-of-bound access when tile is too small.
|
||||
cOff = urem(cOff, i32_val(tileShape[order[0]]));
|
||||
sOff = urem(sOff, i32_val(tileShape[order[1]]));
|
||||
offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride)));
|
||||
offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2608,14 +2705,15 @@ public:
|
||||
Value ptr = getPtr(ptrIdx);
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
int sOffset =
|
||||
matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
|
||||
Value sOffset =
|
||||
mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
|
||||
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
|
||||
PTXBuilder builder;
|
||||
|
||||
// ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
|
||||
// thread.
|
||||
auto resArgs = builder.newListOperand(4, "=r");
|
||||
auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);
|
||||
auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
|
||||
|
||||
auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
|
||||
->o("trans", needTrans /*predicate*/)
|
||||
@@ -2640,26 +2738,24 @@ public:
|
||||
needTrans) { // Use lds.32 to load tf32 matrices
|
||||
Value ptr2 = getPtr(ptrIdx + 1);
|
||||
assert(sMatStride == 1);
|
||||
int sOffsetElem =
|
||||
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
|
||||
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
|
||||
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
|
||||
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
|
||||
int sOffsetArrElem = sMatStride * sMatShape;
|
||||
Value sOffsetArrElemVal =
|
||||
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
|
||||
|
||||
Value elems[4];
|
||||
Type elemTy = type::f32Ty(ctx);
|
||||
if (kOrder == 1) {
|
||||
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[2] =
|
||||
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
elems[3] =
|
||||
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
|
||||
elems[1] = load(gep(elemTy, ptr2, sOffsetElemVal));
|
||||
elems[2] = load(gep(elemTy, ptr, sOffsetArrElemVal));
|
||||
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
|
||||
} else {
|
||||
elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[1] =
|
||||
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
elems[3] =
|
||||
load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
elems[0] = load(gep(elemTy, ptr, sOffsetElemVal));
|
||||
elems[2] = load(gep(elemTy, ptr2, sOffsetElemVal));
|
||||
elems[1] = load(gep(elemTy, ptr, sOffsetArrElemVal));
|
||||
elems[3] = load(gep(elemTy, ptr2, sOffsetArrElemVal));
|
||||
}
|
||||
|
||||
return {elems[0], elems[1], elems[2], elems[3]};
|
||||
@@ -2680,9 +2776,11 @@ public:
|
||||
};
|
||||
|
||||
assert(sMatStride == 1);
|
||||
int sOffsetElem =
|
||||
matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride;
|
||||
int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride;
|
||||
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
|
||||
Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
|
||||
int sOffsetArrElem = 1 * (sMatStride * sMatShape);
|
||||
Value sOffsetArrElemVal =
|
||||
add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
|
||||
|
||||
std::array<Value, 4> i8v4Elems;
|
||||
std::array<Value, 4> i32Elems;
|
||||
@@ -2692,16 +2790,14 @@ public:
|
||||
Value i8Elems[4][4];
|
||||
Type elemTy = type::i8Ty(ctx);
|
||||
if (kOrder == 1) {
|
||||
Value offset = i32_val(sOffsetElem);
|
||||
|
||||
for (int i = 0; i < 2; ++i)
|
||||
for (int j = 0; j < 4; ++j)
|
||||
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset));
|
||||
i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], sOffsetElemVal));
|
||||
|
||||
offset = i32_val(sOffsetElem + sOffsetArrElem);
|
||||
for (int i = 2; i < 4; ++i)
|
||||
for (int j = 0; j < 4; ++j)
|
||||
i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset));
|
||||
i8Elems[i][j] =
|
||||
load(gep(elemTy, ptrs[i - 2][j], sOffsetArrElemVal));
|
||||
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
for (int e = 0; e < 4; ++e)
|
||||
@@ -2710,16 +2806,14 @@ public:
|
||||
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
|
||||
}
|
||||
} else { // k first
|
||||
Value offset = i32_val(sOffsetElem);
|
||||
for (int j = 0; j < 4; ++j)
|
||||
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset));
|
||||
i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], sOffsetElemVal));
|
||||
for (int j = 0; j < 4; ++j)
|
||||
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset));
|
||||
offset = i32_val(sOffsetElem + sOffsetArrElem);
|
||||
i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], sOffsetElemVal));
|
||||
for (int j = 0; j < 4; ++j)
|
||||
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset));
|
||||
i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], sOffsetArrElemVal));
|
||||
for (int j = 0; j < 4; ++j)
|
||||
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset));
|
||||
i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], sOffsetArrElemVal));
|
||||
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
for (int e = 0; e < 4; ++e)
|
||||
@@ -2752,8 +2846,8 @@ private:
|
||||
int cMatShape;
|
||||
int sMatShape;
|
||||
|
||||
int cTileStride;
|
||||
int sTileStride;
|
||||
Value cTileStride;
|
||||
Value sTileStride;
|
||||
|
||||
bool needTrans;
|
||||
bool canUseLdmatrix;
|
||||
@@ -2922,12 +3016,12 @@ struct DotOpMmaV1ConversionHelper {
|
||||
}
|
||||
|
||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||
Value loadA(Value A, Value llA, Value thread, Value smem, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
// Loading $b from smem to registers, returns a LLVM::Struct.
|
||||
Value loadB(Value B, Value llB, Value thread, Value smem, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
|
||||
Location loc, ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
// Loading $c to registers, returns a LLVM::Struct.
|
||||
Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const;
|
||||
@@ -3334,7 +3428,7 @@ struct MMA16816ConversionHelper {
|
||||
}
|
||||
|
||||
// Loading $a from smem to registers, returns a LLVM::Struct.
|
||||
Value loadA(Value tensor, Value llTensor) const {
|
||||
Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shape = aTensorTy.getShape();
|
||||
|
||||
@@ -3348,7 +3442,7 @@ struct MMA16816ConversionHelper {
|
||||
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
||||
// load from smem
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
||||
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
|
||||
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
|
||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
|
||||
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
||||
@@ -3370,7 +3464,7 @@ struct MMA16816ConversionHelper {
|
||||
}
|
||||
|
||||
// Loading $b from smem to registers, returns a LLVM::Struct.
|
||||
Value loadB(Value tensor, Value llTensor) {
|
||||
Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
|
||||
ValueTable hb;
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
@@ -3380,7 +3474,7 @@ struct MMA16816ConversionHelper {
|
||||
int numRepN = getNumRepN(tensorTy, shape[1]);
|
||||
|
||||
auto loadFn = getLoadMatrixFn(
|
||||
tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
||||
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
|
||||
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
|
||||
|
||||
@@ -3485,10 +3579,10 @@ struct MMA16816ConversionHelper {
|
||||
|
||||
private:
|
||||
std::function<void(int, int)>
|
||||
getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
|
||||
int wpt, uint32_t kOrder, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, Value warpId,
|
||||
ValueTable &vals) const {
|
||||
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
|
||||
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape,
|
||||
Value warpId, ValueTable &vals) const {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
// We assumes that the input operand of Dot should be from shared layout.
|
||||
// TODO(Superjomn) Consider other layouts if needed later.
|
||||
@@ -3507,10 +3601,10 @@ private:
|
||||
|
||||
// (a, b) is the coordinate.
|
||||
auto load = [=, &vals, &ld2](int a, int b) {
|
||||
MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder,
|
||||
tensorTy.getShape() /*tileShape*/, instrShape,
|
||||
matShape, perPhase, maxPhase, elemBytes,
|
||||
rewriter, typeConverter, loc);
|
||||
MMA16816SmemLoader loader(
|
||||
wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
|
||||
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
|
||||
maxPhase, elemBytes, rewriter, typeConverter, loc);
|
||||
SmallVector<Value> offs = loader.computeOffsets(warpId, lane);
|
||||
|
||||
const int numPtrs = loader.getNumPtr();
|
||||
@@ -3519,8 +3613,8 @@ private:
|
||||
|
||||
Type smemPtrTy = helper.getShemPtrTy();
|
||||
for (int i = 0; i < numPtrs; ++i) {
|
||||
ptrs[i] =
|
||||
bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy);
|
||||
ptrs[i] = bitcast(gep(smemPtrTy, smemObj.base, ValueRange({offs[i]})),
|
||||
smemPtrTy);
|
||||
}
|
||||
|
||||
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
|
||||
@@ -3612,6 +3706,7 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
|
||||
assert(mmaLayout);
|
||||
|
||||
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
|
||||
Value res;
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
|
||||
@@ -3620,21 +3715,21 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
|
||||
if (dotOperandLayout.getOpIdx() == 0) {
|
||||
// operand $a
|
||||
res = mmaHelper.loadA(src, adaptor.src());
|
||||
res = mmaHelper.loadA(src, smemObj);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||
// operand $b
|
||||
res = mmaHelper.loadB(src, adaptor.src());
|
||||
res = mmaHelper.loadB(src, smemObj);
|
||||
}
|
||||
} else if (mmaLayout.getVersion() == 1) {
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
if (dotOperandLayout.getOpIdx() == 0) {
|
||||
// operand $a
|
||||
res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc),
|
||||
adaptor.src(), loc, rewriter);
|
||||
res =
|
||||
helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||
} else if (dotOperandLayout.getOpIdx() == 1) {
|
||||
// operand $b
|
||||
res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
|
||||
adaptor.src(), loc, rewriter);
|
||||
res =
|
||||
helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
|
||||
}
|
||||
} else {
|
||||
assert(false && "Unsupported mma layout found");
|
||||
@@ -3671,8 +3766,12 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
|
||||
loadedA = adaptor.a();
|
||||
loadedB = adaptor.b();
|
||||
} else {
|
||||
loadedA = mmaHelper.loadA(op.a(), adaptor.a());
|
||||
loadedB = mmaHelper.loadB(op.b(), adaptor.b());
|
||||
SharedMemoryObject smemA =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
|
||||
SharedMemoryObject smemB =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
|
||||
loadedA = mmaHelper.loadA(op.a(), smemA);
|
||||
loadedB = mmaHelper.loadB(op.b(), smemB);
|
||||
}
|
||||
|
||||
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
|
||||
@@ -3797,8 +3896,12 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
}
|
||||
|
||||
Value DotOpMmaV1ConversionHelper::loadA(
|
||||
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
|
||||
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// smem
|
||||
Value smem = smemObj.base;
|
||||
auto strides = smemObj.strides;
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
@@ -3818,10 +3921,10 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
|
||||
int vecA = sharedLayout.getVec();
|
||||
|
||||
int strideAM = isARow ? shape[1] : 1;
|
||||
int strideAK = isARow ? 1 : shape[0];
|
||||
int strideA0 = isARow ? strideAK : strideAM;
|
||||
int strideA1 = isARow ? strideAM : strideAK;
|
||||
Value strideAM = isARow ? strides[0] : i32_val(1);
|
||||
Value strideAK = isARow ? i32_val(1) : strides[1];
|
||||
Value strideA0 = isARow ? strideAK : strideAM;
|
||||
Value strideA1 = isARow ? strideAM : strideAK;
|
||||
|
||||
int strideRepM = wpt[0] * fpw[0] * 8;
|
||||
int strideRepK = 1;
|
||||
@@ -3847,8 +3950,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
offA0I = udiv(offA0I, i32_val(vecA));
|
||||
offA0I = xor_(offA0I, phaseA);
|
||||
offA0I = xor_(offA0I, i32_val(vecA));
|
||||
offA[i] =
|
||||
add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
|
||||
offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
|
||||
}
|
||||
|
||||
Type f16x2Ty = vec_ty(f16_ty, 2);
|
||||
@@ -3877,8 +3979,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
|
||||
int stepAM = isARow ? m : m / numPtrA * numPtrA;
|
||||
int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
|
||||
Value pa = gep(f16PtrTy, thePtrA,
|
||||
i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK));
|
||||
Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
|
||||
mul(i32_val(stepAK), strideAK));
|
||||
Value pa = gep(f16PtrTy, thePtrA, offset);
|
||||
Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max<int>(vecA / 2, 1)), 3);
|
||||
Value ha = load(bitcast(pa, aPtrTy));
|
||||
// record lds that needs to be moved
|
||||
@@ -3915,8 +4018,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
|
||||
}
|
||||
|
||||
Value DotOpMmaV1ConversionHelper::loadB(
|
||||
Value tensor, Value llTensor, Value thread, Value smem, Location loc,
|
||||
Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// smem
|
||||
Value smem = smemObj.base;
|
||||
auto strides = smemObj.strides;
|
||||
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shape = tensorTy.getShape();
|
||||
@@ -3929,10 +4036,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
||||
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
|
||||
int vecB = sharedLayout.getVec();
|
||||
int strideBN = isBRow ? 1 : shape[0];
|
||||
int strideBK = isBRow ? shape[1] : 1;
|
||||
int strideB0 = isBRow ? strideBN : strideBK;
|
||||
int strideB1 = isBRow ? strideBK : strideBN;
|
||||
Value strideBN = isBRow ? i32_val(1) : strides[1];
|
||||
Value strideBK = isBRow ? strides[0] : i32_val(1);
|
||||
Value strideB0 = isBRow ? strideBN : strideBK;
|
||||
Value strideB1 = isBRow ? strideBK : strideBN;
|
||||
int strideRepN = wpt[1] * fpw[1] * 8;
|
||||
int strideRepK = 1;
|
||||
|
||||
@@ -3957,8 +4064,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
offB0I = udiv(offB0I, i32_val(vecB));
|
||||
offB0I = xor_(offB0I, phaseB);
|
||||
offB0I = mul(offB0I, i32_val(vecB));
|
||||
offB[i] =
|
||||
add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
|
||||
offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
|
||||
}
|
||||
|
||||
Type f16PtrTy = ptr_ty(f16_ty);
|
||||
@@ -3979,8 +4085,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
|
||||
int stepBN = isBRow ? n / numPtrB * numPtrB : n;
|
||||
int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
|
||||
Value pb = gep(f16PtrTy, thePtrB,
|
||||
i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK));
|
||||
Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
|
||||
mul(i32_val(stepBK), strideBK));
|
||||
Value pb = gep(f16PtrTy, thePtrB, offset);
|
||||
Value hb =
|
||||
load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
|
||||
// record lds that needs to be moved
|
||||
@@ -4171,7 +4278,17 @@ public:
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
} else if (auto shared_layout =
|
||||
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
|
||||
return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
|
||||
SmallVector<Type, 4> types;
|
||||
// base ptr
|
||||
auto ptrType =
|
||||
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
|
||||
types.push_back(ptrType);
|
||||
// shape dims
|
||||
auto rank = type.getRank();
|
||||
for (auto i = 0; i < rank; i++) {
|
||||
types.push_back(IntegerType::get(ctx, 32));
|
||||
}
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
} else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
|
||||
@@ -4309,15 +4426,26 @@ struct InsertSliceAsyncOpConversion
|
||||
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
|
||||
|
||||
// %dst
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
|
||||
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
||||
assert(axis == 0 && "insert_slice_async: Only axis=0 is supported for now");
|
||||
auto dstBase = createIndexAttrConstant(rewriter, loc,
|
||||
getTypeConverter()->getIndexType(),
|
||||
product<int64_t>(srcTy.getShape()));
|
||||
Value offset = mul(llIndex, dstBase);
|
||||
auto dstPtrTy = LLVM::LLVMPointerType::get(
|
||||
getTypeConverter()->convertType(resTy.getElementType()), 3);
|
||||
Value dstPtrBase = gep(dstPtrTy, llDst, offset);
|
||||
SmallVector<Value, 4> offsetVals;
|
||||
SmallVector<Value, 4> srcStrides;
|
||||
for (auto i = 0; i < dstShape.size(); ++i) {
|
||||
if (i == axis) {
|
||||
offsetVals.emplace_back(llIndex);
|
||||
} else {
|
||||
offsetVals.emplace_back(i32_val(0));
|
||||
srcStrides.emplace_back(smemObj.strides[i]);
|
||||
}
|
||||
}
|
||||
// Compute the offset based on the original dimensions of the shared memory
|
||||
// object
|
||||
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
|
||||
auto dstPtrTy =
|
||||
ptr_ty(getTypeConverter()->convertType(resTy.getElementType()), 3);
|
||||
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
|
||||
|
||||
// %mask
|
||||
SmallVector<Value> maskElems;
|
||||
@@ -4345,11 +4473,10 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we need to swizzle over
|
||||
// elements across phases. If perPhase * maxPhase == threadsPerCTA,
|
||||
// elements across phases. If perPhase * maxPhase <= threadsPerCTA,
|
||||
// swizzle is not allowed
|
||||
auto numSwizzleRows = std::max<unsigned>(
|
||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||
@@ -4377,7 +4504,6 @@ struct InsertSliceAsyncOpConversion
|
||||
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
|
||||
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
|
||||
threadsPerCTA[inOrder[1]];
|
||||
auto baseOffset = (baseOffsetRow * srcShape[inOrder[0]] + baseOffsetCol);
|
||||
auto tileVecIdxCol = vecIdxCol % numVecCols;
|
||||
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
|
||||
|
||||
@@ -4399,8 +4525,10 @@ struct InsertSliceAsyncOpConversion
|
||||
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
|
||||
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
|
||||
i32_val(maxPhase));
|
||||
Value rowOffset =
|
||||
mul(srcIdx[inOrder[1]], i32_val(srcShape[inOrder[0]]));
|
||||
// srcShape and smemObj.shape may be different if smemObj is a
|
||||
// slice of the original shared memory object.
|
||||
// So we need to use the original shape to compute the offset
|
||||
Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
|
||||
Value colOffset =
|
||||
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
|
||||
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
|
||||
@@ -4420,21 +4548,25 @@ struct InsertSliceAsyncOpConversion
|
||||
auto numWords = vecBitWidth / bitWidth;
|
||||
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
|
||||
|
||||
// XXX(Keren): Tune CG and CA here.
|
||||
// Tune CG and CA here.
|
||||
auto byteWidth = bitWidth / 8;
|
||||
CacheModifier srcCacheModifier =
|
||||
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
|
||||
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
|
||||
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
|
||||
|
||||
auto tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
|
||||
Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
|
||||
Value baseOffset =
|
||||
add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
|
||||
i32_val(baseOffsetCol));
|
||||
Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
|
||||
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
|
||||
PTXBuilder ptxBuilder;
|
||||
auto wordElemIdx = wordIdx * numWordElems;
|
||||
auto ©AsyncOp =
|
||||
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
|
||||
auto *dstOperand = ptxBuilder.newAddrOperand(
|
||||
tileOffset, "r", (wordElemIdx + baseOffset) * resByteWidth);
|
||||
auto *dstOperand =
|
||||
ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
|
||||
auto *srcOperand =
|
||||
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
|
||||
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
|
||||
|
@@ -66,4 +66,4 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
@@ -1,5 +1,6 @@
|
||||
add_mlir_dialect_library(TritonGPUIR
|
||||
Dialect.cpp
|
||||
Traits.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonGPUTableGen
|
||||
|
@@ -474,7 +474,7 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
|
||||
ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> allOperands;
|
||||
SmallVector<OpAsmParser::OperandType, 8> allOperands;
|
||||
Type srcType, dstType;
|
||||
SMLoc allOperandLoc = parser.getCurrentLocation();
|
||||
if (parser.parseOperandList(allOperands) ||
|
||||
@@ -489,14 +489,27 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
|
||||
operandTypes.push_back(dstType); // dst
|
||||
operandTypes.push_back(
|
||||
IntegerType::get(parser.getBuilder().getContext(), 32)); // index
|
||||
if (allOperands.size() >= 4)
|
||||
|
||||
int hasMask = 0, hasOther = 0;
|
||||
if (allOperands.size() >= 4) {
|
||||
operandTypes.push_back(triton::getI1SameShape(srcType)); // mask
|
||||
if (allOperands.size() >= 5)
|
||||
hasMask = 1;
|
||||
}
|
||||
if (allOperands.size() >= 5) {
|
||||
operandTypes.push_back(triton::getPointeeType(srcType)); // other
|
||||
hasOther = 1;
|
||||
}
|
||||
|
||||
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
|
||||
// Deduce operand_segment_sizes from the number of the operands.
|
||||
auto operand_segment_sizesAttrName =
|
||||
InsertSliceAsyncOp::operand_segment_sizesAttrName(result.name);
|
||||
result.addAttribute(
|
||||
operand_segment_sizesAttrName,
|
||||
parser.getBuilder().getI32VectorAttr({1, 1, 1, hasMask, hasOther}));
|
||||
return success();
|
||||
}
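With this change the parser infers `operand_segment_sizes` from the operand count alone: three operands (`src, dst, index`) yield `[1, 1, 1, 0, 0]`, four add a mask (`[1, 1, 1, 1, 0]`), and five add an `other` value (`[1, 1, 1, 1, 1]`), so the attribute no longer needs to be spelled out in the textual IR; the printer below correspondingly elides it.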
|
||||
|
||||
@@ -504,39 +517,16 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
||||
InsertSliceAsyncOp insertSliceAsyncOp) {
|
||||
printer << " ";
|
||||
printer << insertSliceAsyncOp.getOperation()->getOperands();
|
||||
printer.printOptionalAttrDict(insertSliceAsyncOp->getAttrs(),
|
||||
/*elidedAttrs=*/{});
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(
|
||||
insertSliceAsyncOp->getAttrs(),
|
||||
{insertSliceAsyncOp.operand_segment_sizesAttrName()});
|
||||
printer << " : ";
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
|
||||
printer << " -> ";
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::LogicalResult ExtractSliceOp::inferReturnTypes(
|
||||
::mlir::MLIRContext *context, llvm::Optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, mlir::DictionaryAttr attributes,
|
||||
::mlir::RegionRange regions,
|
||||
llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
auto srcType = operands[0].getType().cast<RankedTensorType>();
|
||||
auto encoding = srcType.getEncoding();
|
||||
auto srcShape = srcType.getShape();
|
||||
auto axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
if (axis < 0 || (size_t)axis > srcShape.size())
|
||||
return failure();
|
||||
SmallVector<int64_t, 4> dstShape;
|
||||
for (size_t i = 0; i < srcShape.size(); i++)
|
||||
if (i != (size_t)axis)
|
||||
dstShape.push_back(srcShape[i]);
|
||||
auto returnType =
|
||||
RankedTensorType::get(dstShape, srcType.getElementType(), encoding);
|
||||
inferredReturnTypes.assign({returnType});
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DotOperand Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -631,32 +621,6 @@ void TritonGPUDialect::initialize() {
|
||||
addInterfaces<TritonGPUInferLayoutInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(InsertSliceAsyncOp op) {
|
||||
if (!isSharedEncoding(op.getResult())) {
|
||||
return op.emitOpError(
|
||||
"insert_slice_async should return a shared memory tensor");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ExtractSliceOp op) {
|
||||
if (!isSharedEncoding(op.getResult())) {
|
||||
return op.emitOpError("extract_slice should return a shared memory tensor");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(AllocTensorOp op) {
|
||||
if (!isSharedEncoding(op.getResult())) {
|
||||
return op.emitOpError("alloc_tensor should return a shared memory tensor");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
||||
|
lib/Dialect/TritonGPU/IR/Traits.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
|
||||
mlir::LogicalResult
|
||||
mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
|
||||
if (failed(verifyAtLeastNResults(op, 1)))
|
||||
return failure();
|
||||
|
||||
for (auto result : op->getResults())
|
||||
if (!isSharedEncoding(result))
|
||||
return op->emitOpError() << "requires all results to be shared encoding";
|
||||
|
||||
return success();
|
||||
};
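This new trait replaces the per-op `verify(InsertSliceAsyncOp)`, `verify(ExtractSliceOp)`, and `verify(AllocTensorOp)` functions removed from `Dialect.cpp` above: any op that must produce shared-encoding results can attach the trait (presumably via its TableGen definition, which is not shown in this excerpt) and gets the check for free, which is the "more general verification" noted in the PR description.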
|
@@ -111,37 +111,41 @@ public:
|
||||
// cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2))
|
||||
auto insert_slice = dyn_cast<triton::gpu::InsertSliceAsyncOp>(arg);
|
||||
if (insert_slice) {
|
||||
auto newType = op->getResult(0).getType();
|
||||
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
// Ensure that the new insert_slice op is placed in the same place as the
|
||||
// old insert_slice op. Otherwise, the new insert_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(insert_slice);
|
||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, insert_slice.dst());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
|
||||
op, newType, insert_slice.src(), new_arg.getResult(),
|
||||
op, newType, insert_slice.src(), newArg.getResult(),
|
||||
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
|
||||
insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
|
||||
insert_slice.axis());
|
||||
insert_slice.cache(), insert_slice.evict(),
|
||||
insert_slice.isVolatile(), insert_slice.axis());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(extract_slice(x), type2) ->extract_slice(cvt(x, type2))
|
||||
auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
|
||||
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
|
||||
auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
|
||||
if (extract_slice) {
|
||||
auto origType = extract_slice.src().getType().cast<RankedTensorType>();
|
||||
auto origType = extract_slice.source().getType().cast<RankedTensorType>();
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(),
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||
auto resType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
// Ensure that the new extract_slice op is placed in the same place as the
|
||||
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(extract_slice);
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(),
|
||||
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
|
||||
auto new_arg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, extract_slice.src());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
|
||||
op, new_arg.getResult(), extract_slice.index(), extract_slice.axis());
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newType, extract_slice.source());
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
|
||||
op, resType, newArg.getResult(), extract_slice.offsets(),
|
||||
extract_slice.sizes(), extract_slice.strides(),
|
||||
extract_slice.static_offsets(), extract_slice.static_sizes(),
|
||||
extract_slice.static_strides());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type2, x)
|
||||
@@ -198,7 +202,7 @@ static LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
inline bool expensive_to_remat(Operation *op) {
|
||||
if (!op)
|
||||
return true;
|
||||
if (isa<triton::gpu::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||
if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
|
||||
triton::gpu::InsertSliceAsyncOp, triton::LoadOp, triton::StoreOp,
|
||||
triton::AtomicRMWOp, triton::AtomicCASOp, triton::DotOp>(op))
|
||||
return true;
|
||||
|
@@ -339,14 +339,20 @@ void LoopPipeliner::emitPrologue() {
|
||||
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
||||
} // for (int stage = 0; stage < numStages - 1; ++stage)
|
||||
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
// async.wait & extract_slice
|
||||
builder.create<triton::gpu::AsyncWaitOp>(loads[0].getLoc(),
|
||||
loads.size() * (numStages - 2));
|
||||
loopIterIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
for (Value loadOp : loads) {
|
||||
Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
|
||||
loadOp.getLoc(), loadsMapping[loadOp].getType(),
|
||||
loadStageBuffer[loadOp][numStages - 1], loopIterIdx, /*axis*/ 0);
|
||||
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
||||
Value extractSlice = builder.create<tensor::ExtractSliceOp>(
|
||||
loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
|
||||
SmallVector<OpFoldResult>{intAttr(0), intAttr(0), intAttr(0)},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(sliceType.getShape()[0]),
|
||||
intAttr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||
loadsExtract[loadOp] = extractSlice;
|
||||
}
|
||||
// bump up loopIterIdx, this is used for getting the correct slice for the
|
||||
@@ -477,6 +483,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
Value extractSliceIndex = builder.create<arith::RemSIOp>(
|
||||
nextIV.getLoc(), loopIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
|
||||
extractSliceIndex = builder.create<arith::IndexCastOp>(
|
||||
extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);
|
||||
|
||||
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
|
||||
|
||||
for (Operation *op : orderedDeps) {
|
||||
Operation *nextOp = nullptr;
|
||||
@@ -503,9 +513,14 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
nextMapping.lookupOrDefault(loadOp.other()), loadOp.cache(),
|
||||
loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0);
|
||||
nextBuffers.push_back(insertAsyncOp);
|
||||
nextOp = builder.create<triton::gpu::ExtractSliceOp>(
|
||||
op->getLoc(), loadsMapping[loadOp].getType(), insertAsyncOp,
|
||||
extractSliceIndex, /*axis*/ 0);
|
||||
auto sliceType = loadsMapping[loadOp].getType().cast<RankedTensorType>();
|
||||
nextOp = builder.create<tensor::ExtractSliceOp>(
|
||||
op->getLoc(), sliceType, insertAsyncOp,
|
||||
SmallVector<OpFoldResult>{extractSliceIndex, intAttr(0), intAttr(0)},
|
||||
SmallVector<OpFoldResult>{intAttr(1),
|
||||
intAttr(sliceType.getShape()[0]),
|
||||
intAttr(sliceType.getShape()[1])},
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1), intAttr(1)});
|
||||
extractSlices.push_back(nextOp->getResult(0));
|
||||
} else
|
||||
nextOp = builder.clone(*op, nextMapping);