[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
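The gist of the layout requirement, as a hedged sketch (the helper name below is invented for illustration; the actual rewrite lives in the TritonDotPattern hunk at the end of this diff): every tt.dot operand must carry a DotOperandEncodingAttr whose parent is the encoding of the dot result, and an operand that does not already have one is routed through a ConvertLayoutOp.

// Illustrative only -- a condensed form of the operand handling shown in the
// TritonDotPattern hunk below; wrapDotOperand is not a function in this patch.
static Value wrapDotOperand(ConversionPatternRewriter &rewriter, Value v,
                            unsigned opIdx, Attribute dEncoding) {
  auto type = v.getType().cast<RankedTensorType>();
  // Operand already carries a dot_operand layout: nothing to do.
  if (type.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
    return v;
  // Otherwise re-encode it as dot_operand (opIdx 0 for $a, 1 for $b) with the
  // dot result's encoding as parent, and insert a layout conversion.
  Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
      type.getContext(), opIdx, dEncoding);
  auto dstType =
      RankedTensorType::get(type.getShape(), type.getElementType(), encoding);
  return rewriter.create<triton::gpu::ConvertLayoutOp>(v.getLoc(), dstType, v);
}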
@@ -2326,6 +2326,19 @@ private:
   LogicalResult
   lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const;
+
+  // shared -> dot_operand if the result layout is mma
+  Value lowerSharedToDotOperandMMA(
+      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
+      const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
+
+  // shared -> dot_operand if the result layout is blocked
+  Value lowerSharedToDotOperandBlocked(
+      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter,
+      const BlockedEncodingAttr &blockedLayout,
+      const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
 };

 void ConvertLayoutOpConversion::processReplica(
@@ -3011,6 +3024,7 @@ public:
     Value i8Elems[4][4];
     Type elemTy = type::i8Ty(ctx);
     Type elemPtrTy = ptr_ty(elemTy);
+    Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
     if (kOrder == 1) {
       for (int i = 0; i < 2; ++i)
         for (int j = 0; j < 4; ++j)
@@ -3025,7 +3039,7 @@ public:
         for (int e = 0; e < 4; ++e)
           i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
                                         i8Elems[m][e], i32_val(e));
-        i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
+        i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty);
       }
     } else { // k first
       for (int j = 0; j < 4; ++j)
@@ -3041,7 +3055,7 @@ public:
         for (int e = 0; e < 4; ++e)
           i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
                                         i8Elems[m][e], i32_val(e));
-        i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
+        i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty);
       }
     }

@@ -3725,8 +3739,7 @@ struct MMA16816ConversionHelper {
         loadFn(2 * m, 2 * k);

     // step2. Format the values to LLVM::Struct to passing to mma codegen.
-    Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
-    return result;
+    return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
   }

   // Loading $b from smem to registers, returns a LLVM::Struct.
@@ -3963,31 +3976,14 @@ private:
   }
 };

-LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
+Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+    ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
+    const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
   auto loc = op.getLoc();
   Value src = op.src();
   Value dst = op.result();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-
-  auto dotOperandLayout =
-      dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-
-  MmaEncodingAttr mmaLayout =
-      dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
-  assert(mmaLayout);
-
-  bool isOuter{};
-  {
-    int K{};
-    if (dotOperandLayout.getOpIdx() == 0) // $a
-      K = dstTensorTy.getShape()[1];
-    else // $b
-      K = dstTensorTy.getShape()[0];
-    isOuter = K == 1;
-  }
-
   // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
   // is an attribute of DotOp.
   bool allowTF32 = false;
@@ -4023,6 +4019,41 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   } else {
     assert(false && "Unsupported mma layout found");
   }
+  return res;
+}
+
+LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  Value src = op.src();
+  Value dst = op.result();
+  auto dstTensorTy = dst.getType().cast<RankedTensorType>();
+  auto srcTensorTy = src.getType().cast<RankedTensorType>();
+  auto dotOperandLayout =
+      dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
+
+  bool isOuter{};
+  int K{};
+  if (dotOperandLayout.getOpIdx() == 0) // $a
+    K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]];
+  else // $b
+    K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]];
+  isOuter = K == 1;
+
+  Value res;
+  if (auto mmaLayout =
+          dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>()) {
+    res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout,
+                                     dotOperandLayout, isOuter);
+  } else if (auto blockedLayout =
+                 dotOperandLayout.getParent()
+                     .dyn_cast_or_null<BlockedEncodingAttr>()) {
+    assert(false && "Blocked layout is not supported yet");
+  } else {
+    assert(false && "Unsupported dot operand layout found");
+  }
+
   rewriter.replaceOp(op, res);
   return success();
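A hedged worked example of the K / isOuter computation in the new lowerSharedToDotOperand above (the shapes and order values are made up for illustration):

// $a tile of shape [128, 32] with shared order {1, 0}:
//   K = shape[order[0]] = shape[1] = 32  -> isOuter == false
// $a tile of shape [128, 1] with the same order:
//   K = shape[order[0]] = shape[1] = 1   -> isOuter == true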
@@ -4046,23 +4077,13 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();

-  Value loadedA, loadedB, loadedC;
-  // We support two kinds of operand layouts: 1. both $a, $b are dot_operand
-  // layout, 2. both of them are shared layout.
-  if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
-    assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
-           "Both $a and %b should be DotOperand layout.");
-    loadedA = adaptor.a();
-    loadedB = adaptor.b();
-  } else {
-    SharedMemoryObject smemA =
-        getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
-    SharedMemoryObject smemB =
-        getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
-    loadedA = mmaHelper.loadA(op.a(), smemA);
-    loadedB = mmaHelper.loadB(op.b(), smemB);
-  }
-
+  assert(ATensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
+         BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
+         "Both $a and %b should be DotOperand layout.");
+
+  Value loadedA, loadedB, loadedC;
+  loadedA = adaptor.a();
+  loadedB = adaptor.b();
   loadedC = mmaHelper.loadC(op.c(), adaptor.c());

   return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
@@ -4753,20 +4774,26 @@ public:
     auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
     auto wpt = mmaLayout.getWarpsPerCTA();
     Type elemTy = type.getElementType();
+    auto vecSize = 1;
+    if (elemTy.getIntOrFloatBitWidth() == 16) {
+      vecSize = 2;
+    } else if (elemTy.getIntOrFloatBitWidth() == 8) {
+      vecSize = 4;
+    } else {
+      assert(false && "Unsupported element type");
+    }
+    Type vecTy = vec_ty(elemTy, vecSize);
     if (mmaLayout.getVersion() == 2) {
-
       if (dot_op_layout.getOpIdx() == 0) { // $a
         int elems =
             MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
-        Type x2Ty = vec_ty(elemTy, 2);
         return LLVM::LLVMStructType::getLiteral(
-            ctx, SmallVector<Type>(elems, x2Ty));
+            ctx, SmallVector<Type>(elems, vecTy));
       }
       if (dot_op_layout.getOpIdx() == 1) { // $b
         int elems =
             MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
-        Type x2Ty = vec_ty(elemTy, 2);
-        return struct_ty(SmallVector<Type>(elems, x2Ty));
+        return struct_ty(SmallVector<Type>(elems, vecTy));
       }
     }

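A hedged summary of the type-converter hunks above and below: the packed vector width now follows the element bit width instead of being hard-coded to 2. The helper name is invented for illustration; the patch inlines this logic.

// Hypothetical helper mirroring the selection above: 16-bit elements are
// packed two to a vector, 8-bit elements four to a vector.
static unsigned pickDotOperandVecSize(Type elemTy) {
  switch (elemTy.getIntOrFloatBitWidth()) {
  case 16:
    return 2; // e.g. f16 -> struct of vector<2 x f16> values per thread
  case 8:
    return 4; // e.g. i8  -> struct of vector<4 x i8> values per thread
  default:
    assert(false && "Unsupported element type");
    return 1;
  }
}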
@@ -4775,13 +4802,11 @@ public:

       if (dot_op_layout.getOpIdx() == 0) { // $a
         int elems = helper.numElemsPerThreadA(type);
-        Type x2Ty = vec_ty(elemTy, 2);
-        return struct_ty(SmallVector<Type>(elems, x2Ty));
+        return struct_ty(SmallVector<Type>(elems, vecTy));
       }
       if (dot_op_layout.getOpIdx() == 1) { // $b
         int elems = helper.numElemsPerThreadB(type);
-        Type x2Ty = vec_ty(elemTy, 2);
-        return struct_ty(SmallVector<Type>(elems, x2Ty));
+        return struct_ty(SmallVector<Type>(elems, vecTy));
       }
     }

@@ -221,6 +221,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
   matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = getTypeConverter()->convertType(op.getType());
+    Attribute dEncoding = retType.cast<RankedTensorType>().getEncoding();
     // a & b must be of smem layout
     auto aType = adaptor.a().getType().cast<RankedTensorType>();
     auto bType = adaptor.b().getType().cast<RankedTensorType>();
@@ -230,17 +231,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
       return failure();
     Value a = adaptor.a();
     Value b = adaptor.b();
-    SmallVector<unsigned, 2> order{1, 0};
-    if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+    if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
       Attribute encoding =
-          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+          triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
       auto dstType = RankedTensorType::get(aType.getShape(),
                                            aType.getElementType(), encoding);
       a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
     }
-    if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+    if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
       Attribute encoding =
-          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
+          triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
       auto dstType = RankedTensorType::get(bType.getShape(),
                                            bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);