[Triton-MLIR][Backend] Port FMADot conversion for DotOp (#844)
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
This commit is contained in:
@@ -3201,10 +3201,7 @@ private:
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
assert(false && "Not implemented yet.");
|
||||
return failure();
|
||||
}
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
// Helper for conversion of DotOp with mma<version=1>, that is sm<80
|
||||
@@ -4497,6 +4494,155 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
|
||||
return rcds;
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto *ctx = rewriter.getContext();
|
||||
auto loc = op.getLoc();
|
||||
auto threadId = getThreadId(rewriter, loc);
|
||||
|
||||
using ValueTable = std::map<std::pair<int, int>, Value>;
|
||||
|
||||
auto A = op.a();
|
||||
auto B = op.b();
|
||||
auto C = op.c();
|
||||
auto D = op.getResult();
|
||||
|
||||
auto aTensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = B.getType().cast<RankedTensorType>();
|
||||
auto cTensorTy = C.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = D.getType().cast<RankedTensorType>();
|
||||
|
||||
auto aShape = aTensorTy.getShape();
|
||||
auto bShape = bTensorTy.getShape();
|
||||
auto cShape = cTensorTy.getShape();
|
||||
|
||||
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
|
||||
auto aOrder = aLayout.getOrder();
|
||||
auto bOrder = bLayout.getOrder();
|
||||
|
||||
auto order = dLayout.getOrder();
|
||||
|
||||
bool isARow = aOrder[0] == 1;
|
||||
bool isBRow = bOrder[0] == 1;
|
||||
|
||||
int strideAM = isARow ? aShape[1] : 1;
|
||||
int strideAK = isARow ? 1 : aShape[0];
|
||||
int strideBN = isBRow ? 1 : bShape[0];
|
||||
int strideBK = isBRow ? bShape[1] : 1;
|
||||
int strideA0 = isARow ? strideAK : strideAM;
|
||||
int strideA1 = isARow ? strideAM : strideAK;
|
||||
int strideB0 = isBRow ? strideBN : strideBK;
|
||||
int strideB1 = isBRow ? strideBK : strideBN;
|
||||
int lda = isARow ? strideAM : strideAK;
|
||||
int ldb = isBRow ? strideBK : strideBN;
|
||||
int aPerPhase = aLayout.getPerPhase();
|
||||
int aMaxPhase = aLayout.getMaxPhase();
|
||||
int bPerPhase = bLayout.getPerPhase();
|
||||
int bMaxPhase = bLayout.getMaxPhase();
|
||||
int aNumPtr = 8;
|
||||
int bNumPtr = 8;
|
||||
int NK = aShape[1];
|
||||
|
||||
auto shapePerCTA = getShapePerCTA(dLayout);
|
||||
|
||||
auto sizePerThread = getSizePerThread(dLayout);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
|
||||
Value mContig = i32_val(sizePerThread[order[1]]);
|
||||
Value nContig = i32_val(sizePerThread[order[0]]);
|
||||
|
||||
// threadId in blocked layout
|
||||
SmallVector<Value> threadIds;
|
||||
{
|
||||
int dim = cShape.size();
|
||||
threadIds.resize(dim);
|
||||
for (unsigned k = 0; k < dim - 1; k++) {
|
||||
Value dimK = i32_val(shapePerCTA[order[k]]);
|
||||
Value rem = urem(threadId, dimK);
|
||||
threadId = udiv(threadId, dimK);
|
||||
threadIds[order[k]] = rem;
|
||||
}
|
||||
Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
|
||||
threadIds[order[dim - 1]] = urem(threadId, dimK);
|
||||
}
|
||||
|
||||
Value threadIdM = threadIds[0];
|
||||
Value threadIdN = threadIds[1];
|
||||
|
||||
Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
|
||||
Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
|
||||
SmallVector<Value> aOff(aNumPtr);
|
||||
for (int i = 0; i < aNumPtr; ++i) {
|
||||
aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
|
||||
}
|
||||
|
||||
Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
|
||||
Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
|
||||
SmallVector<Value> bOff(bNumPtr);
|
||||
for (int i = 0; i < bNumPtr; ++i) {
|
||||
bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
|
||||
}
|
||||
|
||||
auto aSmem = getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
|
||||
auto bSmem = getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
|
||||
|
||||
Type f32PtrTy = ptr_ty(f32_ty);
|
||||
SmallVector<Value> aPtrs(aNumPtr);
|
||||
for (int i = 0; i < aNumPtr; ++i)
|
||||
aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
|
||||
|
||||
SmallVector<Value> bPtrs(bNumPtr);
|
||||
for (int i = 0; i < bNumPtr; ++i)
|
||||
bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
|
||||
|
||||
ValueTable has, hbs;
|
||||
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
|
||||
SmallVector<Value> ret = cc;
|
||||
// is this compatible with blocked layout?
|
||||
|
||||
for (unsigned k = 0; k < NK; k++) {
|
||||
int z = 0;
|
||||
for (unsigned i = 0; i < cShape[order[1]]; i += shapePerCTA[order[1]])
|
||||
for (unsigned j = 0; j < cShape[order[0]]; j += shapePerCTA[order[0]])
|
||||
for (unsigned ii = 0; ii < sizePerThread[order[1]]; ++ii)
|
||||
for (unsigned jj = 0; jj < sizePerThread[order[0]]; ++jj) {
|
||||
unsigned m = order[0] == 1 ? i : j;
|
||||
unsigned n = order[0] == 1 ? j : i;
|
||||
unsigned mm = order[0] == 1 ? ii : jj;
|
||||
unsigned nn = order[0] == 1 ? jj : ii;
|
||||
if (!has.count({m + mm, k})) {
|
||||
Value pa = gep(f32PtrTy, aPtrs[0],
|
||||
i32_val((m + mm) * strideAM + k * strideAK));
|
||||
Value va = load(pa);
|
||||
has[{m + mm, k}] = va;
|
||||
}
|
||||
if (!hbs.count({n + nn, k})) {
|
||||
Value pb = gep(f32PtrTy, bPtrs[0],
|
||||
i32_val((n + nn) * strideBN + k * strideBK));
|
||||
Value vb = load(pb);
|
||||
hbs[{n + nn, k}] = vb;
|
||||
}
|
||||
|
||||
ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
|
||||
hbs[{n + nn, k}], ret[z]);
|
||||
++z;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = getStructFromElements(
|
||||
loc, ret, rewriter,
|
||||
struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
|
||||
rewriter.replaceOp(op, res);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// ====================== mma codegen end ============================
|
||||
|
||||
Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout,
|
||||
|
@@ -576,6 +576,14 @@ public:
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
|
||||
// for FMA, should retain the blocked layout.
|
||||
if (A.getElementType().isF32() && B.getElementType().isF32() &&
|
||||
!dotOp.allowTF32())
|
||||
return failure();
|
||||
|
||||
// get MMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
@@ -629,4 +637,4 @@ public:
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
||||
}
|
||||
|
@@ -169,3 +169,65 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
|
||||
[32, 32, 16, 4, 32, 32, 16],
|
||||
[32, 16, 16, 4, 32, 32, 16],
|
||||
[128, 8, 8, 4, 32, 32, 16],
|
||||
[127, 41, 43, 4, 32, 32, 16],
|
||||
])
|
||||
def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
|
||||
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
|
||||
a = tl.load(a_ptrs, a_mask)
|
||||
b = tl.load(b_ptrs, b_mask)
|
||||
# NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
|
||||
accumulator += tl.dot(a, b, allow_tf32=False)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
offs_k += BLOCK_SIZE_K
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, c_mask)
|
||||
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
matmul_kernel[grid](a, b, c,
|
||||
M, N, K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
torch.testing.assert_close(c, golden)
|
||||
|
@@ -811,3 +811,22 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
|
||||
// CHECK: llvm.intr.fmuladd
|
||||
%28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked>
|
||||
%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
|
||||
%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
|
||||
tt.store %36, %28 : tensor<32x32xf32, #blocked>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user