[BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937)
This commit is contained in:
@@ -81,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
|
||||
return;
|
||||
}
|
||||
|
||||
if (isa<triton::gpu::AsyncWaitOp>(op)) {
|
||||
// If the current op is an async wait, we insert a barrier op and sync
|
||||
// previous reads and writes.
|
||||
if (isa<triton::gpu::AsyncWaitOp>(op) &&
|
||||
!isa<gpu::BarrierOp>(op->getNextNode())) {
|
||||
// If the current op is an async wait and the next op is not a barrier we
|
||||
// insert a barrier op and sync
|
||||
regionInfo->sync();
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
builder->setInsertionPointAfter(op);
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
|
@@ -708,19 +708,19 @@ public:
|
||||
Type elemTy = type::f32Ty(ctx);
|
||||
Type elemPtrTy = ptr_ty(elemTy);
|
||||
if (kOrder == 1) {
|
||||
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
|
||||
elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
|
||||
elems[2] =
|
||||
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
|
||||
elems[3] =
|
||||
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
|
||||
} else {
|
||||
elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
|
||||
elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
|
||||
elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
|
||||
elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
|
||||
elems[1] =
|
||||
load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
|
||||
elems[3] =
|
||||
load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
|
||||
load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
|
||||
}
|
||||
return {elems[0], elems[1], elems[2], elems[3]};
|
||||
|
||||
|
@@ -3327,10 +3327,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
// We cannot get both the operand types(in TypeConverter), here we assume the
|
||||
// types of both the operands are identical here.
|
||||
// TODO[Superjomn]: Find a better way to implement it.
|
||||
static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
|
||||
static bool isDotHMMA(TensorType operand, int mmaVersion) {
|
||||
auto elemTy = operand.getElementType();
|
||||
return elemTy.isF16() || elemTy.isBF16() ||
|
||||
(elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
|
||||
(elemTy.isF32() && mmaVersion >= 2) ||
|
||||
(elemTy.isInteger(8) && mmaVersion >= 2);
|
||||
}
|
||||
|
||||
@@ -3354,11 +3354,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
Value src = op.src();
|
||||
Value dst = op.result();
|
||||
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
|
||||
// TODO[Superjomn]: allowTF32 is not accessible here for it is an attribute of
|
||||
// an Op instance.
|
||||
bool allowTF32 = false;
|
||||
bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
|
||||
mmaLayout.getVersion());
|
||||
bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, mmaLayout.getVersion());
|
||||
|
||||
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
|
||||
Value res;
|
||||
@@ -3421,25 +3417,16 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
|
||||
} else if (auto blockedLayout =
|
||||
dotOperandLayout.getParent()
|
||||
.dyn_cast_or_null<BlockedEncodingAttr>()) {
|
||||
// TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
|
||||
// is an attribute of DotOp.
|
||||
bool allowTF32 = false;
|
||||
bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
|
||||
if (isFMADot) {
|
||||
auto dotOpLayout =
|
||||
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
|
||||
DotOpFMAConversionHelper helper(blockedLayout);
|
||||
auto thread = getThreadId(rewriter, loc);
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
|
||||
rewriter);
|
||||
} else { // $b
|
||||
res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
|
||||
rewriter);
|
||||
}
|
||||
} else
|
||||
assert(false && "Unsupported dot operand layout found");
|
||||
auto dotOpLayout = dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
DotOpFMAConversionHelper helper(blockedLayout);
|
||||
auto thread = getThreadId(rewriter, loc);
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
|
||||
rewriter);
|
||||
} else { // $b
|
||||
res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
|
||||
rewriter);
|
||||
}
|
||||
} else {
|
||||
assert(false && "Unsupported dot operand layout found");
|
||||
}
|
||||
@@ -3805,13 +3792,6 @@ public:
|
||||
auto ctx = type.getContext();
|
||||
Attribute layout = type.getEncoding();
|
||||
auto shape = type.getShape();
|
||||
|
||||
// TODO[Keren, Superjomn]: fix it, allowTF32 is not accessible here for it
|
||||
// is bound to an Op instance.
|
||||
bool allowTF32 = false;
|
||||
bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
|
||||
layout.dyn_cast_or_null<DotOperandEncodingAttr>();
|
||||
|
||||
if (layout &&
|
||||
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
|
||||
layout.isa<MmaEncodingAttr>())) {
|
||||
@@ -3835,37 +3815,39 @@ public:
|
||||
return LLVM::LLVMStructType::getLiteral(ctx, types);
|
||||
} else if (auto dotOpLayout =
|
||||
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
|
||||
if (isFMADot) { // for parent is blocked layout
|
||||
if (dotOpLayout.getParent()
|
||||
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
|
||||
int numElemsPerThread =
|
||||
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
|
||||
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
|
||||
|
||||
} else { // for parent is MMA layout
|
||||
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
Type elemTy = convertType(type.getElementType());
|
||||
auto vecSize = 1;
|
||||
if (elemTy.getIntOrFloatBitWidth() == 16) {
|
||||
vecSize = 2;
|
||||
} else if (elemTy.getIntOrFloatBitWidth() == 8) {
|
||||
vecSize = 4;
|
||||
} else {
|
||||
assert(false && "Unsupported element type");
|
||||
}
|
||||
Type vecTy = vec_ty(elemTy, vecSize);
|
||||
if (mmaLayout.getVersion() == 2) {
|
||||
const llvm::DenseMap<int, Type> targetTyMap = {
|
||||
{32, elemTy},
|
||||
{16, vec_ty(elemTy, 2)},
|
||||
{8, vec_ty(elemTy, 4)},
|
||||
};
|
||||
Type targetTy;
|
||||
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
|
||||
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
|
||||
} else {
|
||||
assert(false && "Unsupported element type");
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 0) { // $a
|
||||
int elems =
|
||||
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
ctx, SmallVector<Type>(elems, vecTy));
|
||||
ctx, SmallVector<Type>(elems, targetTy));
|
||||
}
|
||||
if (dotOpLayout.getOpIdx() == 1) { // $b
|
||||
int elems =
|
||||
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
|
||||
return struct_ty(SmallVector<Type>(elems, vecTy));
|
||||
return struct_ty(SmallVector<Type>(elems, targetTy));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3995,10 +3977,10 @@ struct InsertSliceAsyncOpConversion
|
||||
// %other
|
||||
SmallVector<Value> otherElems;
|
||||
if (llOther) {
|
||||
// TODO(Keren): support "other" tensor.
|
||||
// FIXME(Keren): always assume other is 0 for now
|
||||
// It's not necessary for now because the pipeline pass will skip
|
||||
// generating insert_slice_async if the load op has any "other" tensor.
|
||||
assert(false && "insert_slice_async: Other value not supported yet");
|
||||
// assert(false && "insert_slice_async: Other value not supported yet");
|
||||
otherElems = getLLVMElems(other, llOther, rewriter, loc);
|
||||
assert(srcElems.size() == otherElems.size());
|
||||
}
|
||||
|
@@ -220,14 +220,17 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
||||
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
|
||||
[32, 32, 16, 4, 32, 32, 16],
|
||||
[32, 16, 16, 4, 32, 32, 16],
|
||||
[128, 8, 8, 4, 32, 32, 16],
|
||||
# TODO[Superjomn]: fix it later
|
||||
# [127, 41, 43, 4, 32, 32, 16],
|
||||
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [
|
||||
[32, 32, 16, 4, 32, 32, 16, False],
|
||||
[32, 32, 16, 4, 32, 32, 16, True],
|
||||
[32, 16, 16, 4, 32, 32, 16, False],
|
||||
[32, 16, 16, 4, 32, 32, 16, True],
|
||||
[127, 41, 43, 4, 32, 32, 16, False],
|
||||
[127, 41, 43, 4, 32, 32, 16, True],
|
||||
[128, 8, 8, 4, 32, 32, 16, False],
|
||||
[128, 8, 8, 4, 32, 32, 16, True]
|
||||
])
|
||||
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
@@ -236,6 +239,7 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
@@ -253,10 +257,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
|
||||
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
|
||||
a = tl.load(a_ptrs, a_mask)
|
||||
b = tl.load(b_ptrs, b_mask)
|
||||
# NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
|
||||
accumulator += tl.dot(a, b, allow_tf32=False)
|
||||
a = tl.load(a_ptrs, a_mask, other=0.0)
|
||||
b = tl.load(b_ptrs, b_mask, other=0.0)
|
||||
accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
offs_k += BLOCK_SIZE_K
|
||||
@@ -267,6 +270,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, c_mask)
|
||||
|
||||
# Configure the pytorch counterpart
|
||||
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
|
||||
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
@@ -277,8 +283,12 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32)
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
|
||||
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
|
||||
if allow_tf32:
|
||||
# TF32 is not accurate enough
|
||||
torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
|
||||
else:
|
||||
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
|
||||
|
@@ -923,6 +923,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// -----
|
||||
|
||||
#mma = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: matmul_tf32dot
|
||||
func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
// CHECK-SAME: (f32, f32, f32, f32)
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
// CHECK-SAME: (f32, f32, f32, f32)
|
||||
%a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
|
||||
%b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
|
||||
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
|
||||
%28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
|
||||
%38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
|
||||
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %38 : tensor<32x32xf32, #blocked>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32
|
||||
|
Reference in New Issue
Block a user