[BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937)
@@ -81,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
     return;
   }
-  if (isa<triton::gpu::AsyncWaitOp>(op)) {
-    // If the current op is an async wait, we insert a barrier op and sync
-    // previous reads and writes.
+
+  if (isa<triton::gpu::AsyncWaitOp>(op) &&
+      !isa<gpu::BarrierOp>(op->getNextNode())) {
+    // If the current op is an async wait and the next op is not a barrier we
+    // insert a barrier op and sync
     regionInfo->sync();
     OpBuilder::InsertionGuard g(*builder);
     builder->setInsertionPointAfter(op);
     builder->create<gpu::BarrierOp>(op->getLoc());
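Note on the hunk above: the new `!isa<gpu::BarrierOp>(op->getNextNode())` guard stops the membar analysis from emitting a second, redundant barrier when the op right after the async wait is already a barrier. Distilled into a standalone sketch (hypothetical helper name; like the original, it assumes the wait is not the last op in its block, since getNextNode() would return null there):

    // Sketch: emit a barrier after an async wait only when the next op is
    // not already a barrier, so back-to-back barriers are never generated.
    static void syncAfterAsyncWait(Operation *op, OpBuilder &builder) {
      if (isa<triton::gpu::AsyncWaitOp>(op) &&
          !isa<gpu::BarrierOp>(op->getNextNode())) {
        OpBuilder::InsertionGuard g(builder); // restore insertion point on exit
        builder.setInsertionPointAfter(op);
        builder.create<gpu::BarrierOp>(op->getLoc());
      }
    }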
@@ -708,19 +708,19 @@ public:
     Type elemTy = type::f32Ty(ctx);
     Type elemPtrTy = ptr_ty(elemTy);
     if (kOrder == 1) {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
       elems[2] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
       elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     } else {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
       elems[1] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
       elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
     }
     return {elems[0], elems[1], elems[2], elems[3]};
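The change above swaps the per-load `i32_val(...)` constants for `sOffsetElemVal` and `sOffsetArrElemVal`, Values defined once before this hunk (their definitions fall outside the shown context). A hypothetical reconstruction of that precomputation, matching the constants the old code built inline; in the actual commit these are presumably combined with a dynamic swizzle phase, which is the point of making the offset a runtime Value rather than a fresh constant at each call:

    // Assumed precomputation (not shown in the diff): hoist the two shared
    // memory offsets into Values so every load below can reuse them.
    Value sOffsetElemVal = i32_val(sOffsetElem);
    Value sOffsetArrElemVal = i32_val(sOffsetElem + sOffsetArrElem);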
@@ -3327,10 +3327,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
   // We cannot get both the operand types(in TypeConverter), here we assume the
   // types of both the operands are identical here.
   // TODO[Superjomn]: Find a better way to implement it.
-  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+  static bool isDotHMMA(TensorType operand, int mmaVersion) {
     auto elemTy = operand.getElementType();
     return elemTy.isF16() || elemTy.isBF16() ||
-           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isF32() && mmaVersion >= 2) ||
            (elemTy.isInteger(8) && mmaVersion >= 2);
   }

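Dropping the `allowTF32` parameter means the LLVM lowering treats any f32 operand as tensor-core eligible once mmaVersion >= 2; whether TF32 is actually allowed is decided earlier, when the layout pass chooses an mma encoding for the dot output (hence the PR title). For intuition, TF32 keeps f32's 8-bit exponent but only 10 explicit mantissa bits. A rough, illustration-only emulation of the operand rounding (assumes round-to-nearest with ties away from zero, as PTX's cvt.rna.tf32.f32 does; NaN/Inf are not handled):

    #include <cstdint>
    #include <cstring>

    // Round an f32 to TF32 precision: drop the low 13 of 23 mantissa bits,
    // rounding the magnitude to nearest (ties away from zero). A mantissa
    // carry propagates into the exponent naturally.
    float roundToTF32(float x) {
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      bits += 1u << 12;            // half of the 13 bits being dropped
      bits &= ~((1u << 13) - 1);   // clear the dropped mantissa bits
      float y;
      std::memcpy(&y, &bits, sizeof(y));
      return y;
    }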
@@ -3354,11 +3354,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value src = op.src();
   Value dst = op.result();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-  // TODO[Superjomn]: allowTF32 is not accessible here for it is an attribute of
-  // an Op instance.
-  bool allowTF32 = false;
-  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
-                                           mmaLayout.getVersion());
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, mmaLayout.getVersion());

   auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
   Value res;
@@ -3421,25 +3417,16 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   } else if (auto blockedLayout =
                  dotOperandLayout.getParent()
                      .dyn_cast_or_null<BlockedEncodingAttr>()) {
-    // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
-    // is an attribute of DotOp.
-    bool allowTF32 = false;
-    bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
-    if (isFMADot) {
-      auto dotOpLayout =
-          dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
-      DotOpFMAConversionHelper helper(blockedLayout);
-      auto thread = getThreadId(rewriter, loc);
-      if (dotOpLayout.getOpIdx() == 0) { // $a
-        res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      } else { // $b
-        res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      }
-    } else
-      assert(false && "Unsupported dot operand layout found");
+    auto dotOpLayout = dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    DotOpFMAConversionHelper helper(blockedLayout);
+    auto thread = getThreadId(rewriter, loc);
+    if (dotOpLayout.getOpIdx() == 0) { // $a
+      res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    } else { // $b
+      res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    }
   } else {
     assert(false && "Unsupported dot operand layout found");
   }
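Net effect of the hunk above: the FMA path no longer gates on a locally hard-coded `allowTF32 = false`; reaching this `else if` already proves the operand's parent encoding is blocked, so the conversion can call the FMA helper directly and the inner re-check (plus its unreachable `else assert`) disappears. The resulting dispatch, distilled into a sketch with a hypothetical helper name:

    // Sketch: the dot-operand lowering path is now chosen purely from the
    // parent encoding of the operand layout, not from element type + allowTF32.
    static llvm::StringRef pickDotOperandPath(DotOperandEncodingAttr layout) {
      Attribute parent = layout.getParent();
      if (parent.dyn_cast_or_null<MmaEncodingAttr>())
        return "mma"; // tensor-core path, handled earlier in the function
      if (parent.dyn_cast_or_null<BlockedEncodingAttr>())
        return "fma"; // DotOpFMAConversionHelper::loadA / loadB, as above
      assert(false && "Unsupported dot operand layout found");
      return "";
    }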
@@ -3805,13 +3792,6 @@ public:
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
     auto shape = type.getShape();
-
-    // TODO[Keren, Superjomn]: fix it, allowTF32 is not accessible here for it
-    // is bound to an Op instance.
-    bool allowTF32 = false;
-    bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
-                    layout.dyn_cast_or_null<DotOperandEncodingAttr>();
-
     if (layout &&
         (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
          layout.isa<MmaEncodingAttr>())) {
@@ -3835,37 +3815,39 @@ public:
       return LLVM::LLVMStructType::getLiteral(ctx, types);
     } else if (auto dotOpLayout =
                    layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
-      if (isFMADot) { // for parent is blocked layout
+      if (dotOpLayout.getParent()
+              .isa<BlockedEncodingAttr>()) { // for parent is blocked layout
        int numElemsPerThread =
            DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);

        return LLVM::LLVMStructType::getLiteral(
            ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));

      } else { // for parent is MMA layout
        auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
        auto wpt = mmaLayout.getWarpsPerCTA();
        Type elemTy = convertType(type.getElementType());
-       auto vecSize = 1;
-       if (elemTy.getIntOrFloatBitWidth() == 16) {
-         vecSize = 2;
-       } else if (elemTy.getIntOrFloatBitWidth() == 8) {
-         vecSize = 4;
-       } else {
-         assert(false && "Unsupported element type");
-       }
-       Type vecTy = vec_ty(elemTy, vecSize);
        if (mmaLayout.getVersion() == 2) {
+         const llvm::DenseMap<int, Type> targetTyMap = {
+             {32, elemTy},
+             {16, vec_ty(elemTy, 2)},
+             {8, vec_ty(elemTy, 4)},
+         };
+         Type targetTy;
+         if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
+           targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
+         } else {
+           assert(false && "Unsupported element type");
+         }
          if (dotOpLayout.getOpIdx() == 0) { // $a
            int elems =
                MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
            return LLVM::LLVMStructType::getLiteral(
-               ctx, SmallVector<Type>(elems, vecTy));
+               ctx, SmallVector<Type>(elems, targetTy));
          }
          if (dotOpLayout.getOpIdx() == 1) { // $b
            int elems =
                MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
-           return struct_ty(SmallVector<Type>(elems, vecTy));
+           return struct_ty(SmallVector<Type>(elems, targetTy));
          }
        }
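This hunk is what actually unblocks TF32 inputs: the old code computed a vector width only for 16-bit and 8-bit elements and asserted on everything else, so a 32-bit (f32-as-TF32) dot operand could never be given an LLVM type. The new `targetTyMap` keys on element bit width and maps 32-bit elements to the scalar type itself, i.e. one element per 32-bit register, which is how mma v2 consumes its operands. The packing rule, distilled (hypothetical helper name; `vec_ty` is the conversion file's own macro for LLVM fixed-vector types):

    // Packing of mma v2 dot operands into 32-bit registers:
    //   32-bit (f32 / tf32) -> the scalar type, 1 element per register
    //   16-bit (f16 / bf16) -> vec_ty(elemTy, 2), 2 elements per register
    //    8-bit (i8)         -> vec_ty(elemTy, 4), 4 elements per register
    static Type packedMmaOperandTy(Type elemTy) {
      switch (elemTy.getIntOrFloatBitWidth()) {
      case 32: return elemTy;
      case 16: return vec_ty(elemTy, 2);
      case 8:  return vec_ty(elemTy, 4);
      default: assert(false && "Unsupported element type"); return Type();
      }
    }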
@@ -3995,10 +3977,10 @@ struct InsertSliceAsyncOpConversion
     // %other
     SmallVector<Value> otherElems;
     if (llOther) {
-      // TODO(Keren): support "other" tensor.
+      // FIXME(Keren): always assume other is 0 for now
       // It's not necessary for now because the pipeline pass will skip
       // generating insert_slice_async if the load op has any "other" tensor.
-      assert(false && "insert_slice_async: Other value not supported yet");
+      // assert(false && "insert_slice_async: Other value not supported yet");
       otherElems = getLLVMElems(other, llOther, rewriter, loc);
       assert(srcElems.size() == otherElems.size());
     }