[BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937)

Authored by Keren Zhou on 2022-12-03 11:14:12 -08:00, committed by GitHub
parent 8edfe813a5
commit f2fcaeabf3
5 changed files with 105 additions and 72 deletions
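What this enables, in user terms: an fp32 tl.dot with allow_tf32=True can now lower to tensor-core mma (tf32 on mma v2) when its result carries an mma encoding, instead of being forced onto the FMA path. Below is a minimal sketch of that usage; it is illustrative only and not part of this commit, and the kernel and helper names are made up.

import torch
import triton
import triton.language as tl

@triton.jit
def tf32_dot_kernel(a_ptr, b_ptr, c_ptr,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                    ALLOW_TF32: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Load one row-major tile of each operand.
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    # With ALLOW_TF32=True, the fp32 dot targets the mma/tf32 lowering on mma v2 hardware.
    c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)

def run_tf32_dot(M=32, N=32, K=16, allow_tf32=True):
    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
    c = torch.empty((M, N), device='cuda', dtype=torch.float32)
    tf32_dot_kernel[(1,)](a, b, c, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, ALLOW_TF32=allow_tf32)
    return c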


@@ -81,9 +81,11 @@ void MembarAnalysis::transfer(Operation *op, RegionInfo *regionInfo,
    return;
  }
-  if (isa<triton::gpu::AsyncWaitOp>(op)) {
-    // If the current op is an async wait, we insert a barrier op and sync
-    // previous reads and writes.
+  if (isa<triton::gpu::AsyncWaitOp>(op) &&
+      !isa<gpu::BarrierOp>(op->getNextNode())) {
+    // If the current op is an async wait and the next op is not a barrier we
+    // insert a barrier op and sync
    regionInfo->sync();
    OpBuilder::InsertionGuard g(*builder);
    builder->setInsertionPointAfter(op);
    builder->create<gpu::BarrierOp>(op->getLoc());


@@ -708,19 +708,19 @@ public:
    Type elemTy = type::f32Ty(ctx);
    Type elemPtrTy = ptr_ty(elemTy);
    if (kOrder == 1) {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
      elems[2] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
      elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
    } else {
-      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
-      elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal));
+      elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal));
      elems[1] =
-          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, sOffsetArrElemVal));
      elems[3] =
-          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, sOffsetArrElemVal));
    }
    return {elems[0], elems[1], elems[2], elems[3]};


@@ -3327,10 +3327,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
  // We cannot get both the operand types(in TypeConverter), here we assume the
  // types of both the operands are identical here.
  // TODO[Superjomn]: Find a better way to implement it.
-  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+  static bool isDotHMMA(TensorType operand, int mmaVersion) {
    auto elemTy = operand.getElementType();
    return elemTy.isF16() || elemTy.isBF16() ||
-           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isF32() && mmaVersion >= 2) ||
           (elemTy.isInteger(8) && mmaVersion >= 2);
  }
@@ -3354,11 +3354,7 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
  Value src = op.src();
  Value dst = op.result();
  auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-  // TODO[Superjomn]: allowTF32 is not accessible here for it is an attribute of
-  // an Op instance.
-  bool allowTF32 = false;
-  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
-                                           mmaLayout.getVersion());
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, mmaLayout.getVersion());
  auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
  Value res;
@@ -3421,25 +3417,16 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
  } else if (auto blockedLayout =
                 dotOperandLayout.getParent()
                     .dyn_cast_or_null<BlockedEncodingAttr>()) {
-    // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
-    // is an attribute of DotOp.
-    bool allowTF32 = false;
-    bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
-    if (isFMADot) {
-      auto dotOpLayout =
-          dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
-      auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
-      DotOpFMAConversionHelper helper(blockedLayout);
-      auto thread = getThreadId(rewriter, loc);
-      if (dotOpLayout.getOpIdx() == 0) { // $a
-        res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      } else { // $b
-        res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
-                           rewriter);
-      }
-    } else
-      assert(false && "Unsupported dot operand layout found");
+    auto dotOpLayout = dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+    DotOpFMAConversionHelper helper(blockedLayout);
+    auto thread = getThreadId(rewriter, loc);
+    if (dotOpLayout.getOpIdx() == 0) { // $a
+      res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    } else { // $b
+      res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
+                         rewriter);
+    }
  } else {
    assert(false && "Unsupported dot operand layout found");
  }
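After this change, the lowering no longer guesses allowTF32 here; whether an fp32 dot operand goes through the FMA helper or the mma path is read off the operand's parent layout, which is presumably established upstream from allowTF32 and the element type when the dot op is given its encoding. A rough plain-Python summary of that dispatch on mma v2 hardware (hypothetical helper, not Triton API):

def dot_lowering_kind(elem_is_f32: bool, allow_tf32: bool) -> str:
    """Sketch of which lowering a dot operand ends up on after this commit (mma v2)."""
    if elem_is_f32 and not allow_tf32:
        return "fma"   # operand keeps a blocked parent layout -> DotOpFMAConversionHelper
    return "mma"       # operand gets an mma parent layout; fp32 maps to tf32 mma.sync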
@@ -3805,13 +3792,6 @@ public:
  auto ctx = type.getContext();
  Attribute layout = type.getEncoding();
  auto shape = type.getShape();
-  // TODO[Keren, Superjomn]: fix it, allowTF32 is not accessible here for it
-  // is bound to an Op instance.
-  bool allowTF32 = false;
-  bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
-                  layout.dyn_cast_or_null<DotOperandEncodingAttr>();
  if (layout &&
      (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
       layout.isa<MmaEncodingAttr>())) {
@@ -3835,37 +3815,39 @@ public:
      return LLVM::LLVMStructType::getLiteral(ctx, types);
    } else if (auto dotOpLayout =
                   layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
-      if (isFMADot) { // for parent is blocked layout
+      if (dotOpLayout.getParent()
+              .isa<BlockedEncodingAttr>()) { // for parent is blocked layout
        int numElemsPerThread =
            DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
        return LLVM::LLVMStructType::getLiteral(
            ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
      } else { // for parent is MMA layout
        auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
        auto wpt = mmaLayout.getWarpsPerCTA();
        Type elemTy = convertType(type.getElementType());
-        auto vecSize = 1;
-        if (elemTy.getIntOrFloatBitWidth() == 16) {
-          vecSize = 2;
-        } else if (elemTy.getIntOrFloatBitWidth() == 8) {
-          vecSize = 4;
-        } else {
-          assert(false && "Unsupported element type");
-        }
-        Type vecTy = vec_ty(elemTy, vecSize);
        if (mmaLayout.getVersion() == 2) {
+          const llvm::DenseMap<int, Type> targetTyMap = {
+              {32, elemTy},
+              {16, vec_ty(elemTy, 2)},
+              {8, vec_ty(elemTy, 4)},
+          };
+          Type targetTy;
+          if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
+            targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
+          } else {
+            assert(false && "Unsupported element type");
+          }
          if (dotOpLayout.getOpIdx() == 0) { // $a
            int elems =
                MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
            return LLVM::LLVMStructType::getLiteral(
-                ctx, SmallVector<Type>(elems, vecTy));
+                ctx, SmallVector<Type>(elems, targetTy));
          }
          if (dotOpLayout.getOpIdx() == 1) { // $b
            int elems =
                MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
-            return struct_ty(SmallVector<Type>(elems, vecTy));
+            return struct_ty(SmallVector<Type>(elems, targetTy));
          }
        }
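For the mma-parent case above, operand elements are now packed by bit width via targetTyMap: 32-bit elements (f32/tf32) stay scalar, 16-bit elements pack two per 32-bit slot, and 8-bit elements pack four. A tiny illustrative restatement of that mapping in Python (hypothetical helper, not the C++ TypeConverter):

def mma_v2_elems_per_slot(bitwidth):
    """How many operand elements of a given bit width share one 32-bit register slot."""
    packing = {32: 1, 16: 2, 8: 4}   # mirrors {32: elemTy, 16: vec2, 8: vec4} above
    if bitwidth not in packing:
        raise ValueError(f"unsupported element bit width: {bitwidth}")
    return packing[bitwidth]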
@@ -3995,10 +3977,10 @@ struct InsertSliceAsyncOpConversion
      // %other
      SmallVector<Value> otherElems;
      if (llOther) {
-        // TODO(Keren): support "other" tensor.
+        // FIXME(Keren): always assume other is 0 for now
        // It's not necessary for now because the pipeline pass will skip
        // generating insert_slice_async if the load op has any "other" tensor.
-        assert(false && "insert_slice_async: Other value not supported yet");
+        // assert(false && "insert_slice_async: Other value not supported yet");
        otherElems = getLLVMElems(other, llOther, rewriter, loc);
        assert(srcElems.size() == otherElems.size());
      }


@@ -220,14 +220,17 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
    assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
-@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
-    [32, 32, 16, 4, 32, 32, 16],
-    [32, 16, 16, 4, 32, 32, 16],
-    [128, 8, 8, 4, 32, 32, 16],
-    # TODO[Superjomn]: fix it later
-    # [127, 41, 43, 4, 32, 32, 16],
+@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [
+    [32, 32, 16, 4, 32, 32, 16, False],
+    [32, 32, 16, 4, 32, 32, 16, True],
+    [32, 16, 16, 4, 32, 32, 16, False],
+    [32, 16, 16, 4, 32, 32, 16, True],
+    [127, 41, 43, 4, 32, 32, 16, False],
+    [127, 41, 43, 4, 32, 32, 16, True],
+    [128, 8, 8, 4, 32, 32, 16, False],
+    [128, 8, 8, 4, 32, 32, 16, True]
])
-def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
    @triton.jit
    def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
@@ -236,6 +239,7 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+        ALLOW_TF32: tl.constexpr
    ):
        pid = tl.program_id(axis=0)
        # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -253,10 +257,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
        for k in range(0, K, BLOCK_SIZE_K):
            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
            b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
-            a = tl.load(a_ptrs, a_mask)
-            b = tl.load(b_ptrs, b_mask)
-            # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
-            accumulator += tl.dot(a, b, allow_tf32=False)
+            a = tl.load(a_ptrs, a_mask, other=0.0)
+            b = tl.load(b_ptrs, b_mask, other=0.0)
+            accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk
            offs_k += BLOCK_SIZE_K
@@ -267,6 +270,9 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, accumulator, c_mask)
+    # Configure the pytorch counterpart
+    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
@@ -277,8 +283,12 @@ def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_bk=b.stride(0), stride_bn=b.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
-        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
+        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32)
    golden = torch.matmul(a, b)
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
-    torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
+    if allow_tf32:
+        # TF32 is not accurate enough
+        torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
+    else:
+        torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
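The looser 1e-2 tolerance in the TF32 branch reflects TF32's reduced (10-bit) mantissa. A standalone way to see the size of that gap on the PyTorch side alone, independent of the Triton kernel (illustrative sketch, not part of the test suite):

import torch

def tf32_vs_fp32_gap(M=127, N=41, K=43):
    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
    torch.backends.cuda.matmul.allow_tf32 = False
    exact = torch.matmul(a, b)          # full fp32 multiplication
    torch.backends.cuda.matmul.allow_tf32 = True
    approx = torch.matmul(a, b)         # tf32 tensor-core path (Ampere and newer)
    return (exact - approx).abs().max().item()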


@@ -923,6 +923,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
+#mma = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+                       %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK-SAME: (f32, f32, f32, f32)
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    // CHECK-SAME: (f32, f32, f32, f32)
+    %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma>
+    %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %38 : tensor<32x32xf32, #blocked>
+    return
+  }
+}
+// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32