[Triton-MLIR][Backend] Port FMADot conversion for DotOp (#844)

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Yan Chunwei
2022-11-09 12:57:50 +08:00
committed by GitHub
parent de5b84c476
commit 0c87360657
4 changed files with 240 additions and 5 deletions

View File

@@ -3201,10 +3201,7 @@ private:
                               ConversionPatternRewriter &rewriter) const;
   LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-    assert(false && "Not implemented yet.");
-    return failure();
-  }
+                              ConversionPatternRewriter &rewriter) const;
 };
 
 // Helper for conversion of DotOp with mma<version=1>, that is sm<80
@@ -4497,6 +4494,155 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
   return rcds;
 }
 
+LogicalResult
+DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+  auto *ctx = rewriter.getContext();
+  auto loc = op.getLoc();
+  auto threadId = getThreadId(rewriter, loc);
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  auto A = op.a();
+  auto B = op.b();
+  auto C = op.c();
+  auto D = op.getResult();
+
+  auto aTensorTy = A.getType().cast<RankedTensorType>();
+  auto bTensorTy = B.getType().cast<RankedTensorType>();
+  auto cTensorTy = C.getType().cast<RankedTensorType>();
+  auto dTensorTy = D.getType().cast<RankedTensorType>();
+
+  auto aShape = aTensorTy.getShape();
+  auto bShape = bTensorTy.getShape();
+  auto cShape = cTensorTy.getShape();
+
+  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+  auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
+
+  auto aOrder = aLayout.getOrder();
+  auto bOrder = bLayout.getOrder();
+  auto order = dLayout.getOrder();
+
+  bool isARow = aOrder[0] == 1;
+  bool isBRow = bOrder[0] == 1;
+
+  int strideAM = isARow ? aShape[1] : 1;
+  int strideAK = isARow ? 1 : aShape[0];
+  int strideBN = isBRow ? 1 : bShape[0];
+  int strideBK = isBRow ? bShape[1] : 1;
+  int strideA0 = isARow ? strideAK : strideAM;
+  int strideA1 = isARow ? strideAM : strideAK;
+  int strideB0 = isBRow ? strideBN : strideBK;
+  int strideB1 = isBRow ? strideBK : strideBN;
+  int lda = isARow ? strideAM : strideAK;
+  int ldb = isBRow ? strideBK : strideBN;
+  int aPerPhase = aLayout.getPerPhase();
+  int aMaxPhase = aLayout.getMaxPhase();
+  int bPerPhase = bLayout.getPerPhase();
+  int bMaxPhase = bLayout.getMaxPhase();
+  int aNumPtr = 8;
+  int bNumPtr = 8;
+  int NK = aShape[1];
+
+  auto shapePerCTA = getShapePerCTA(dLayout);
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+  Value mContig = i32_val(sizePerThread[order[1]]);
+  Value nContig = i32_val(sizePerThread[order[0]]);
+
+  // Delinearize threadId into per-dimension thread ids in the blocked layout.
+  SmallVector<Value> threadIds;
+  {
+    int dim = cShape.size();
+    threadIds.resize(dim);
+    for (unsigned k = 0; k < dim - 1; k++) {
+      Value dimK = i32_val(shapePerCTA[order[k]]);
+      Value rem = urem(threadId, dimK);
+      threadId = udiv(threadId, dimK);
+      threadIds[order[k]] = rem;
+    }
+    Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
+    threadIds[order[dim - 1]] = urem(threadId, dimK);
+  }
+  Value threadIdM = threadIds[0];
+  Value threadIdN = threadIds[1];
+
+  // Per-thread base offsets into the shared-memory tiles of A and B.
+  Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
+  Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
+  SmallVector<Value> aOff(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i) {
+    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+  }
+
+  Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
+  Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
+  SmallVector<Value> bOff(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i) {
+    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+  }
+
+  auto aSmem = getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
+  auto bSmem = getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> aPtrs(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i)
+    aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
+  SmallVector<Value> bPtrs(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i)
+    bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
+
+  // has/hbs cache loaded A/B elements keyed by (index, k) to avoid reloading.
+  ValueTable has, hbs;
+  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
+  SmallVector<Value> ret = cc;
+
+  // is this compatible with blocked layout?
+  for (unsigned k = 0; k < NK; k++) {
+    int z = 0;
+    for (unsigned i = 0; i < cShape[order[1]]; i += shapePerCTA[order[1]])
+      for (unsigned j = 0; j < cShape[order[0]]; j += shapePerCTA[order[0]])
+        for (unsigned ii = 0; ii < sizePerThread[order[1]]; ++ii)
+          for (unsigned jj = 0; jj < sizePerThread[order[0]]; ++jj) {
+            unsigned m = order[0] == 1 ? i : j;
+            unsigned n = order[0] == 1 ? j : i;
+            unsigned mm = order[0] == 1 ? ii : jj;
+            unsigned nn = order[0] == 1 ? jj : ii;
+            if (!has.count({m + mm, k})) {
+              Value pa = gep(f32PtrTy, aPtrs[0],
+                             i32_val((m + mm) * strideAM + k * strideAK));
+              Value va = load(pa);
+              has[{m + mm, k}] = va;
+            }
+            if (!hbs.count({n + nn, k})) {
+              Value pb = gep(f32PtrTy, bPtrs[0],
+                             i32_val((n + nn) * strideBN + k * strideBK));
+              Value vb = load(pb);
+              hbs[{n + nn, k}] = vb;
+            }
+            ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
+                                                      hbs[{n + nn, k}], ret[z]);
+            ++z;
+          }
+  }
+
+  // Repack the accumulated values into the LLVM struct carrying the result.
+  auto res = getStructFromElements(
+      loc, ret, rewriter,
+      struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 /// ====================== mma codegen end ============================
 
 Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
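
Note: the loop nest above computes D[m,n] = C[m,n] + sum_k A[m,k] * B[k,n] as a K-deep chain of fused multiply-adds, emitting one LLVM::FMulAddOp per (m, n, k) step. A minimal Python sketch of the same computation (illustrative only; fma_dot_reference is a hypothetical name, not part of this commit):

# Hypothetical reference model of the FMA lowering above: each output
# element is built by a K-deep chain of fused multiply-adds seeded with C.
def fma_dot_reference(A, B, C):
    M, K, N = len(A), len(A[0]), len(B[0])
    D = [row[:] for row in C]  # accumulator starts at C, like `ret = cc` above
    for k in range(K):
        for m in range(M):
            for n in range(N):
                # one fmuladd per (m, n, k) triple
                D[m][n] = A[m][k] * B[k][n] + D[m][n]
    return D

# e.g. fma_dot_reference([[1.0, 2.0]], [[3.0], [4.0]], [[0.5]]) == [[11.5]]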

View File

@@ -576,6 +576,14 @@ public:
     auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
     if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
+
+    auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
+    auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
+    // For FMA, the dot should retain the blocked layout.
+    if (A.getElementType().isF32() && B.getElementType().isF32() &&
+        !dotOp.allowTF32())
+      return failure();
+
     // get MMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
@@ -629,4 +637,4 @@ public:
 std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
   return std::make_unique<TritonGPUCombineOpsPass>();
-}
+}
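
Note: the new early return above is what routes dots to the FMA path: an f32 x f32 dot with TF32 disabled keeps its blocked layout instead of being rewritten to an MMA layout. A sketch of the same predicate in Python (stays_on_fma_path is a hypothetical helper for illustration, not part of the pass):

def stays_on_fma_path(a_dtype: str, b_dtype: str, allow_tf32: bool) -> bool:
    # Mirrors the combine-pass early return: f32 x f32 with allow_tf32=False
    # is left in blocked layout and later lowered via llvm.intr.fmuladd.
    return a_dtype == "f32" and b_dtype == "f32" and not allow_tf32

assert stays_on_fma_path("f32", "f32", False)
assert not stays_on_fma_path("f32", "f32", True)  # may take the MMA path instead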

View File

@@ -169,3 +169,65 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
     torch.set_printoptions(profile="full")
     assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
+
+
+@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
+    [32, 32, 16, 4, 32, 32, 16],
+    [32, 16, 16, 4, 32, 32, 16],
+    [128, 8, 8, 4, 32, 32, 16],
+    [127, 41, 43, 4, 32, 32, 16],
+])
+def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+    @triton.jit
+    def matmul_kernel(
+        a_ptr, b_ptr, c_ptr,
+        M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        stride_cm, stride_cn,
+        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        pid_m = pid // num_pid_n
+        pid_n = pid % num_pid_n
+
+        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
+            b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
+            a = tl.load(a_ptrs, a_mask)
+            b = tl.load(b_ptrs, b_mask)
+            # NOTE: allow_tf32 must be False to force FMA lowering of the dot op
+            accumulator += tl.dot(a, b, allow_tf32=False)
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+            offs_k += BLOCK_SIZE_K
+
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, accumulator, c_mask)
+
+    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
+    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
+    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
+    matmul_kernel[grid](a, b, c,
+                        M, N, K,
+                        stride_am=a.stride(0), stride_ak=a.stride(1),
+                        stride_bk=b.stride(0), stride_bn=b.stride(1),
+                        stride_cm=c.stride(0), stride_cn=c.stride(1),
+                        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
+
+    golden = torch.matmul(a, b)
+    torch.testing.assert_close(c, golden)
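
Note: the (127, 41, 43) case exercises the boundary masks, since no dimension is a multiple of its block size. As a quick sanity check of the launch grid for that case (illustrative arithmetic only, not part of the test):

import math

M, N = 127, 41
BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 32
num_pid_m = math.ceil(M / BLOCK_SIZE_M)  # ceil(127 / 32) = 4
num_pid_n = math.ceil(N / BLOCK_SIZE_N)  # ceil(41 / 32) = 2
assert num_pid_m * num_pid_n == 8  # 8 programs; the masks cover the ragged edges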

View File

@@ -811,3 +811,22 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+                      %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    // CHECK: llvm.intr.fmuladd
+    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked>
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %28 : tensor<32x32xf32, #blocked>
+    return
+  }
+}
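
Note: the `// -----` separator indicates this lit test file is run in split-input-file mode, and the CHECK line only asserts that the lowered module contains an llvm.intr.fmuladd. A trivial sketch of the same assertion outside FileCheck (contains_fma and the sample string are hypothetical, for illustration):

def contains_fma(lowered_ir: str) -> bool:
    # FileCheck analogue of `// CHECK: llvm.intr.fmuladd`
    return "llvm.intr.fmuladd" in lowered_ir

assert contains_fma('%0 = "llvm.intr.fmuladd"(%a, %b, %c) : (f32, f32, f32) -> f32')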