[BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937)

Author: Keren Zhou
Date: 2022-12-03 11:14:12 -08:00
Committed by: GitHub
Parent: 8edfe813a5
Commit: f2fcaeabf3
5 changed files with 105 additions and 72 deletions
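
For context: allowtf32 permits an f32 dot to run on tensor cores in TF32 precision, which keeps f32's 8-bit exponent but only 10 explicit mantissa bits. A minimal standalone C++ sketch of the precision effect (bit-mask truncation; the actual hardware rounds rather than truncates):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Approximate TF32 by zeroing the 13 low mantissa bits of an IEEE f32.
// This is truncation only; real TF32 conversion rounds.
float toTF32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u; // keep sign(1) + exponent(8) + top 10 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  float x = 0.1f; // not exactly representable; 10 mantissa bits lose precision
  std::printf("f32: %.9f  tf32: %.9f\n", x, toTF32(x));
}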


@@ -81,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
     return;
   }
-  if (isa<triton::gpu::AsyncWaitOp>(op)) {
-    // If the current op is an async wait, we insert a barrier op and sync
-    // previous reads and writes.
+  if (isa<triton::gpu::AsyncWaitOp>(op) &&
+      !isa<gpu::BarrierOp>(op->getNextNode())) {
+    // If the current op is an async wait and the next op is not a barrier, we
+    // insert a barrier op and sync previous reads and writes.
     regionInfo->sync();
     OpBuilder::InsertionGuard g(*builder);
     builder->setInsertionPointAfter(op);
     builder->create<gpu::BarrierOp>(op->getLoc());
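
The added guard avoids emitting a second barrier when the async wait is already followed by one. The same dedup idea as a standalone sketch (OpKind and insertBarriers are illustrative stand-ins, not the MLIR API):

#include <cassert>
#include <vector>

enum class OpKind { AsyncWait, Barrier, Other };

// Insert a Barrier after each AsyncWait unless the next op is already one.
void insertBarriers(std::vector<OpKind> &ops) {
  for (size_t i = 0; i < ops.size(); ++i) {
    if (ops[i] == OpKind::AsyncWait &&
        (i + 1 == ops.size() || ops[i + 1] != OpKind::Barrier)) {
      ops.insert(ops.begin() + i + 1, OpKind::Barrier);
      ++i; // skip over the barrier we just inserted
    }
  }
}

int main() {
  std::vector<OpKind> ops = {OpKind::AsyncWait, OpKind::Barrier,
                             OpKind::AsyncWait, OpKind::Other};
  insertBarriers(ops);
  // The first wait keeps its existing barrier; the second gains one.
  assert(ops.size() == 5 && ops[3] == OpKind::Barrier);
}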


@@ -708,19 +708,19 @@ public:
     Type elemTy = type::f32Ty(ctx);
     Type elemPtrTy = ptr_ty(elemTy);
     if (kOrder == 1) {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
       elems[2] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
       elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     } else {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
       elems[1] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
       elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     }
     return {elems[0], elems[1], elems[2], elems[3]};
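
The load pattern itself is unchanged: four loads from two base pointers, each at an element offset and at an array-stride offset, with the middle two results swapped when kOrder != 1; the new sOffsetElemVal / sOffsetArrElemVal presumably cache those offsets as precomputed values instead of re-materializing i32_val constants at every use. A standalone sketch of the access pattern (plain floats stand in for LLVM values):

#include <array>
#include <cassert>

// Load four f32 elements from two base pointers at two offsets.
// kOrder decides whether the strided pair lands in elems[2..3] or
// is interleaved into elems[1] and elems[3].
std::array<float, 4> loadQuad(const float *ptr, const float *ptr2,
                              int sOffsetElem, int sOffsetArrElem, int kOrder) {
  std::array<float, 4> elems;
  if (kOrder == 1) {
    elems[0] = ptr[sOffsetElem];
    elems[1] = ptr2[sOffsetElem];
    elems[2] = ptr[sOffsetElem + sOffsetArrElem];
    elems[3] = ptr2[sOffsetElem + sOffsetArrElem];
  } else {
    elems[0] = ptr[sOffsetElem];
    elems[2] = ptr2[sOffsetElem];
    elems[1] = ptr[sOffsetElem + sOffsetArrElem];
    elems[3] = ptr2[sOffsetElem + sOffsetArrElem];
  }
  return elems;
}

int main() {
  float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  auto q = loadQuad(data, data + 1, 2, 4, 0);
  assert(q[0] == 2 && q[2] == 3 && q[1] == 6 && q[3] == 7);
}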


@@ -3327,10 +3327,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   // We cannot get both operand types (in TypeConverter), so we assume the
   // types of both operands are identical here.
   // TODO[Superjomn]: Find a better way to implement it.
-  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+  static bool isDotHMMA(TensorType operand, int mmaVersion) {
     auto elemTy = operand.getElementType();
     return elemTy.isF16() || elemTy.isBF16() ||
-           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isF32() && mmaVersion >= 2) ||
            (elemTy.isInteger(8) && mmaVersion >= 2);
   }
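
With allowTF32 folded out of the signature, the predicate becomes a pure function of element type and MMA version (the commit title suggests this path is now only reached when allowtf32 is true). A standalone sketch of the same decision table (ElemKind stands in for the MLIR type queries):

#include <cassert>

enum class ElemKind { F16, BF16, F32, I8, Other };

// f16/bf16 always map to HMMA; f32 (via TF32) and i8 need mma >= v2.
bool isDotHMMA(ElemKind elem, int mmaVersion) {
  return elem == ElemKind::F16 || elem == ElemKind::BF16 ||
         (elem == ElemKind::F32 && mmaVersion >= 2) ||
         (elem == ElemKind::I8 && mmaVersion >= 2);
}

int main() {
  assert(isDotHMMA(ElemKind::F16, 1));  // f16 works on any version
  assert(isDotHMMA(ElemKind::F32, 2));  // f32 -> TF32 MMA on v2
  assert(!isDotHMMA(ElemKind::F32, 1)); // no TF32 before v2
  assert(!isDotHMMA(ElemKind::Other, 2));
}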
@@ -3354,11 +3354,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value src = op.src();
   Value dst = op.result();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-  // TODO[Superjomn]: allowTF32 is not accessible here for it is an attribute of
-  // an Op instance.
-  bool allowTF32 = false;
-  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
-                                           mmaLayout.getVersion());
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, mmaLayout.getVersion());
   auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
   Value res;
@@ -3421,25 +3417,16 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   } else if (auto blockedLayout =
                  dotOperandLayout.getParent()
                      .dyn_cast_or_null<BlockedEncodingAttr>()) {
-    // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
-    // is an attribute of DotOp.
-    bool allowTF32 = false;
-    bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
-    if (isFMADot) {
-      auto dotOpLayout =
-          dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
-      DotOpFMAConversionHelper helper(blockedLayout);
-      auto thread = getThreadId(rewriter, loc);
-      if (dotOpLayout.getOpIdx() == 0) { // $a
-        res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      } else { // $b
-        res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      }
-    } else
-      assert(false && "Unsupported dot operand layout found");
+    auto dotOpLayout = dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    DotOpFMAConversionHelper helper(blockedLayout);
+    auto thread = getThreadId(rewriter, loc);
+    if (dotOpLayout.getOpIdx() == 0) { // $a
+      res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    } else { // $b
+      res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    }
   } else {
     assert(false && "Unsupported dot operand layout found");
   }
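
The rewritten branch keys the lowering off the dot operand's parent encoding alone: a blocked parent takes the FMA path, picking loadA or loadB by operand index, and anything else is rejected. A hypothetical sketch of that dispatch shape (the enum and function names are stand-ins, not the real API):

#include <cassert>
#include <string>

enum class ParentLayout { Mma, Blocked, Other };

// Mirror of the lowering dispatch: MMA parents use the tensor-core path,
// blocked parents use the FMA helper, selected per operand ($a vs $b).
std::string lowerDotOperand(ParentLayout parent, int opIdx) {
  if (parent == ParentLayout::Mma)
    return "lowerSharedToDotOperandMMA";
  if (parent == ParentLayout::Blocked)
    return opIdx == 0 ? "FMAHelper::loadA" : "FMAHelper::loadB";
  assert(false && "Unsupported dot operand layout found");
  return "";
}

int main() {
  assert(lowerDotOperand(ParentLayout::Blocked, 0) == "FMAHelper::loadA");
  assert(lowerDotOperand(ParentLayout::Blocked, 1) == "FMAHelper::loadB");
  assert(lowerDotOperand(ParentLayout::Mma, 0) == "lowerSharedToDotOperandMMA");
}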
@@ -3805,13 +3792,6 @@ public:
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
     auto shape = type.getShape();
-    // TODO[Keren, Superjomn]: fix it, allowTF32 is not accessible here for it
-    // is bound to an Op instance.
-    bool allowTF32 = false;
-    bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
-                    layout.dyn_cast_or_null<DotOperandEncodingAttr>();
     if (layout &&
         (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
          layout.isa<MmaEncodingAttr>())) {
@@ -3835,37 +3815,39 @@ public:
       return LLVM::LLVMStructType::getLiteral(ctx, types);
     } else if (auto dotOpLayout =
                    layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
-      if (isFMADot) { // for parent is blocked layout
+      if (dotOpLayout.getParent()
+              .isa<BlockedEncodingAttr>()) { // for parent is blocked layout
         int numElemsPerThread =
             DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
         return LLVM::LLVMStructType::getLiteral(
             ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
       } else { // for parent is MMA layout
         auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
         auto wpt = mmaLayout.getWarpsPerCTA();
         Type elemTy = convertType(type.getElementType());
-        auto vecSize = 1;
-        if (elemTy.getIntOrFloatBitWidth() == 16) {
-          vecSize = 2;
-        } else if (elemTy.getIntOrFloatBitWidth() == 8) {
-          vecSize = 4;
-        } else {
-          assert(false && "Unsupported element type");
-        }
-        Type vecTy = vec_ty(elemTy, vecSize);
         if (mmaLayout.getVersion() == 2) {
+          const llvm::DenseMap<int, Type> targetTyMap = {
+              {32, elemTy},
+              {16, vec_ty(elemTy, 2)},
+              {8, vec_ty(elemTy, 4)},
+          };
+          Type targetTy;
+          if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
+            targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
+          } else {
+            assert(false && "Unsupported element type");
+          }
           if (dotOpLayout.getOpIdx() == 0) { // $a
             int elems =
                 MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
             return LLVM::LLVMStructType::getLiteral(
-                ctx, SmallVector<Type>(elems, vecTy));
+                ctx, SmallVector<Type>(elems, targetTy));
           }
           if (dotOpLayout.getOpIdx() == 1) { // $b
             int elems =
                 MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
-            return struct_ty(SmallVector<Type>(elems, vecTy));
+            return struct_ty(SmallVector<Type>(elems, targetTy));
           }
         }
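
The DenseMap encodes how many elements pack into each 32-bit MMA operand register: one 32-bit value (f32/tf32), two 16-bit values, or four 8-bit values; this is what lets the f32 case through where the old vecSize logic asserted. The same packing rule as a standalone sketch (function names are illustrative):

#include <cassert>
#include <map>

// Elements packed per 32-bit MMA operand register, keyed by element bit width.
int elemsPerRegister(int bitWidth) {
  static const std::map<int, int> packing = {
      {32, 1}, // one f32/tf32 per register
      {16, 2}, // f16x2 / bf16x2
      {8, 4},  // i8x4
  };
  auto it = packing.find(bitWidth);
  assert(it != packing.end() && "Unsupported element type");
  return it->second;
}

// Number of packed registers needed to carry `numElems` operand elements.
int numRegisters(int numElems, int bitWidth) {
  return numElems / elemsPerRegister(bitWidth);
}

int main() {
  assert(numRegisters(8, 16) == 4); // 8 f16 values -> 4 vec2 registers
  assert(numRegisters(8, 32) == 8); // tf32 values travel one per register
  assert(numRegisters(8, 8) == 2);  // 8 i8 values -> 2 vec4 registers
}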
@@ -3995,10 +3977,10 @@ struct InsertSliceAsyncOpConversion
     // %other
     SmallVector<Value> otherElems;
     if (llOther) {
-      // TODO(Keren): support "other" tensor.
-      assert(false && "insert_slice_async: Other value not supported yet");
+      // FIXME(Keren): always assume other is 0 for now
+      // It's not necessary for now because the pipeline pass will skip
+      // generating insert_slice_async if the load op has any "other" tensor.
+      // assert(false && "insert_slice_async: Other value not supported yet");
       otherElems = getLLVMElems(other, llOther, rewriter, loc);
       assert(srcElems.size() == otherElems.size());
     }
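
The "other" operand is the fallback a masked load yields for lanes whose mask is off; lifting the assert lets those elements flow through, on the commit's stated assumption that the pipeline pass still skips insert_slice_async for loads with a nontrivial "other". A standalone sketch of the masked-load semantics in play:

#include <cassert>
#include <vector>

// Semantics of a masked load with an "other" fallback:
// out[i] = mask[i] ? src[i] : other[i].
std::vector<int> maskedLoad(const std::vector<int> &src,
                            const std::vector<bool> &mask,
                            const std::vector<int> &other) {
  assert(src.size() == mask.size() && src.size() == other.size());
  std::vector<int> out(src.size());
  for (size_t i = 0; i < src.size(); ++i)
    out[i] = mask[i] ? src[i] : other[i];
  return out;
}

int main() {
  auto out = maskedLoad({10, 20, 30}, {true, false, true}, {0, 0, 0});
  assert(out[0] == 10 && out[1] == 0 && out[2] == 30);
}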