[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
Author: Da Yan
Date: 2022-11-10 13:57:27 +08:00 (committed by GitHub)
Parent: 8832e32683
Commit: 4946167241
29 changed files with 1227 additions and 507 deletions
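
In short, after this change both operands of tt.dot are expected to carry a DotOperandEncodingAttr (opIdx 0 for $a, 1 for $b) whose parent is the encoding of the dot result, instead of being accepted in shared layout directly. Below is a minimal sketch of that invariant, written in the same MLIR C++ style as the diff; the helper name operandsHaveDotOperandLayout is hypothetical and not part of this commit.

// Sketch only: checks the operand contract this commit introduces.
static bool operandsHaveDotOperandLayout(triton::DotOp op) {
  Attribute dEncoding = op.getType().cast<RankedTensorType>().getEncoding();
  auto aTy = op.a().getType().cast<RankedTensorType>();
  auto bTy = op.b().getType().cast<RankedTensorType>();
  auto aEnc =
      aTy.getEncoding().dyn_cast_or_null<triton::gpu::DotOperandEncodingAttr>();
  auto bEnc =
      bTy.getEncoding().dyn_cast_or_null<triton::gpu::DotOperandEncodingAttr>();
  if (!aEnc || !bEnc)
    return false;
  // $a must be operand 0, $b operand 1, and both must share the result layout
  // (an MMA or blocked encoding) as their parent.
  return aEnc.getOpIdx() == 0 && bEnc.getOpIdx() == 1 &&
         aEnc.getParent() == dEncoding && bEnc.getParent() == dEncoding;
}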

@@ -2326,6 +2326,19 @@ private:
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const;
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
// shared -> dot_operand if the result layout is blocked
Value lowerSharedToDotOperandBlocked(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blockedLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const;
};
void ConvertLayoutOpConversion::processReplica(
@@ -3011,6 +3024,7 @@ public:
Value i8Elems[4][4];
Type elemTy = type::i8Ty(ctx);
Type elemPtrTy = ptr_ty(elemTy);
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
if (kOrder == 1) {
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 4; ++j)
@@ -3025,7 +3039,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty);
}
} else { // k first
for (int j = 0; j < 4; ++j)
@@ -3041,7 +3055,7 @@ public:
for (int e = 0; e < 4; ++e)
i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
i8Elems[m][e], i32_val(e));
i32Elems[m] = bitcast(i8v4Elems[m], i32_ty);
i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty);
}
}
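
In the two hunks above, each group of four i8 values is assembled with insert_element into a four-lane i8 vector (i8x4Ty); the four bytes together make up one 32-bit mma operand value, which is what the i32Elems slots hold. A standalone C++ sketch of the equivalent byte packing, for illustration only and assuming little-endian lane order:

#include <cstdint>
#include <cstring>

// Pack four int8 lanes into a single 32-bit word, lane 0 in the lowest byte.
// Conceptually the same packing the emitted insert_element/bitcast sequence
// performs when building each packed operand register.
static uint32_t packI8x4(const int8_t lanes[4]) {
  uint32_t packed = 0;
  std::memcpy(&packed, lanes, sizeof(packed));
  return packed;
}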
@@ -3725,8 +3739,7 @@ struct MMA16816ConversionHelper {
loadFn(2 * m, 2 * k);
// step2. Format the values to LLVM::Struct to pass to mma codegen.
Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
return result;
return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
}
// Loading $b from smem to registers, returns an LLVM::Struct.
@@ -3963,31 +3976,14 @@ private:
}
};
LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
assert(mmaLayout);
bool isOuter{};
{
int K{};
if (dotOperandLayout.getOpIdx() == 0) // $a
K = dstTensorTy.getShape()[1];
else // $b
K = dstTensorTy.getShape()[0];
isOuter = K == 1;
}
// TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp because it
// is an attribute of DotOp.
bool allowTF32 = false;
@@ -4023,6 +4019,41 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
} else {
assert(false && "Unsupported mma layout found");
}
return res;
}
LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
bool isOuter{};
int K{};
if (dotOperandLayout.getOpIdx() == 0) // $a
K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]];
else // $b
K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]];
isOuter = K == 1;
Value res;
if (auto mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>()) {
res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout,
dotOperandLayout, isOuter);
} else if (auto blockedLayout =
dotOperandLayout.getParent()
.dyn_cast_or_null<BlockedEncodingAttr>()) {
assert(false && "Blocked layout is not supported yet");
} else {
assert(false && "Unsupported dot operand layout found");
}
rewriter.replaceOp(op, res);
return success();
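
As a worked example of the K / isOuter logic above (numbers illustrative, not from the commit): for $a of shape {128, 64} with order {1, 0}, K = shape[order[0]] = shape[1] = 64, so isOuter is false; only a degenerate K == 1 selects the outer-product path. A minimal standalone sketch of the same computation:

#include <array>
#include <cstdint>

// Mirrors the dispatcher: $a (opIdx 0) reads K from shape[order[0]],
// $b (opIdx 1) from shape[order[1]]; isOuter means K == 1 (outer product).
static bool isOuterProduct(unsigned opIdx, std::array<int64_t, 2> shape,
                           std::array<unsigned, 2> order) {
  int64_t K = (opIdx == 0) ? shape[order[0]] : shape[order[1]];
  return K == 1;
}
// e.g. isOuterProduct(0, {128, 64}, {1, 0}) == false
//      isOuterProduct(1, {1, 64}, {1, 0})   == true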
@@ -4046,23 +4077,13 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
Value loadedA, loadedB, loadedC;
// We support two kinds of operand layouts: 1. both $a, $b are dot_operand
// layout, 2. both of them are shared layout.
if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
"Both $a and %b should be DotOperand layout.");
loadedA = adaptor.a();
loadedB = adaptor.b();
} else {
SharedMemoryObject smemA =
getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
SharedMemoryObject smemB =
getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
loadedA = mmaHelper.loadA(op.a(), smemA);
loadedB = mmaHelper.loadB(op.b(), smemB);
}
assert(ATensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
"Both $a and %b should be DotOperand layout.");
Value loadedA, loadedB, loadedC;
loadedA = adaptor.a();
loadedB = adaptor.b();
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
@@ -4753,20 +4774,26 @@ public:
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
auto vecSize = 1;
if (elemTy.getIntOrFloatBitWidth() == 16) {
vecSize = 2;
} else if (elemTy.getIntOrFloatBitWidth() == 8) {
vecSize = 4;
} else {
assert(false && "Unsupported element type");
}
Type vecTy = vec_ty(elemTy, vecSize);
if (mmaLayout.getVersion() == 2) {
if (dot_op_layout.getOpIdx() == 0) { // $a
int elems =
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
Type x2Ty = vec_ty(elemTy, 2);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, x2Ty));
ctx, SmallVector<Type>(elems, vecTy));
}
if (dot_op_layout.getOpIdx() == 1) { // $b
int elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
return struct_ty(SmallVector<Type>(elems, vecTy));
}
}
@@ -4775,13 +4802,11 @@ public:
if (dot_op_layout.getOpIdx() == 0) { // $a
int elems = helper.numElemsPerThreadA(type);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
return struct_ty(SmallVector<Type>(elems, vecTy));
}
if (dot_op_layout.getOpIdx() == 1) { // $b
int elems = helper.numElemsPerThreadB(type);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
return struct_ty(SmallVector<Type>(elems, vecTy));
}
}
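
Net effect of the type-converter hunks above: a dot_operand tensor lowers to an LLVM struct whose members are small vectors, and the vector width now follows the element bit width (2 lanes for 16-bit types, 4 lanes for 8-bit types) instead of the previously hard-coded 2. A small sketch of that width selection; the helper name is illustrative only:

#include <cassert>

// Matches the vecSize logic added above: 16-bit elements pack two per 32-bit
// mma register, 8-bit elements pack four; other widths are rejected.
static unsigned dotOperandVecSize(unsigned elemBitWidth) {
  if (elemBitWidth == 16)
    return 2;
  if (elemBitWidth == 8)
    return 4;
  assert(false && "Unsupported element type");
  return 0;
}
// e.g. an f16 operand becomes a struct of vector<2 x f16> members,
// an i8 operand a struct of vector<4 x i8> members.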

@@ -221,6 +221,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
Attribute dEncoding = retType.cast<RankedTensorType>().getEncoding();
// a & b must be of smem layout
auto aType = adaptor.a().getType().cast<RankedTensorType>();
auto bType = adaptor.b().getType().cast<RankedTensorType>();
@@ -230,17 +231,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure();
Value a = adaptor.a();
Value b = adaptor.b();
SmallVector<unsigned, 2> order{1, 0};
if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
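
Both branches of TritonDotPattern above follow the same shape, so the change reads as one rule: any dot operand that does not already have a dot_operand encoding is wrapped in a convert_layout whose destination encoding is a DotOperandEncodingAttr with the matching opIdx and the dot result's encoding as parent. A consolidated sketch; wrapDotOperand is a hypothetical helper, not code from this commit:

// Sketch of the shared logic of the two branches above.
static Value wrapDotOperand(Value v, unsigned opIdx, Attribute dEncoding,
                            ConversionPatternRewriter &rewriter) {
  auto ty = v.getType().cast<RankedTensorType>();
  if (ty.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
    return v; // already in dot_operand layout
  Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
      rewriter.getContext(), opIdx, dEncoding);
  auto dstType =
      RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding);
  return rewriter.create<triton::gpu::ConvertLayoutOp>(v.getLoc(), dstType, v);
}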