From c4726333bf76ec8420b8879dedaccea3a0e36baf Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Fri, 21 Oct 2022 11:46:28 +0800
Subject: [PATCH] [Triton-MLIR] Minor fixes related with scf/swizzling support (#791)

1, Disable static loop unrolling in the frontend by default;
2, A minor fix in axisAnalysis in order to support scf;
3, A minor fix in TritonGPUToLLVM to support swizzling.
---
 lib/Analysis/AxisInfo.cpp                     |  3 +-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       |  2 +-
 python/tests/test_gemm.py                     | 95 ++++++++++++++-----
 python/triton/compiler.py                     |  5 +-
 4 files changed, 79 insertions(+), 26 deletions(-)

diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 7c205fe0c..d0296d5ab 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -40,7 +40,8 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
   if (TensorType ty = value.getType().dyn_cast<TensorType>())
     rank = ty.getRank();
   int divHint = 1;
-  if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
+  BlockArgument blockArg = value.dyn_cast<BlockArgument>();
+  if (blockArg && blockArg.getOwner()->isEntryBlock()) {
     Operation *op = blockArg.getOwner()->getParentOp();
     if (FuncOp fun = dyn_cast<FuncOp>(op)) {
       Attribute attr =
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 56499f3f8..58a59c477 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1867,8 +1867,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
     auto multiDimIdxInNanoTile = getMultiDimIndex(
         linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
     unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
+    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
     unsigned wordVecIdx =
         getLinearIndex(multiDimIdxInNanoTile, wordsInEachRep);
     wordVecs[wordVecIdx] =
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 6b559f7ec..4f1ff5fdc 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -1,13 +1,13 @@
-# import pytest
-# import torch
-# from torch.testing import assert_close
+import pytest
+import torch
+from torch.testing import assert_close
 
 import triton
 import triton.language as tl
 
 
 @triton.jit
-def matmul_kernel(
+def matmul_no_scf_kernel(
     a_ptr, b_ptr, c_ptr,
     stride_am, stride_ak,
     stride_bk, stride_bn,
@@ -30,23 +30,72 @@
 
 
 # TODO: num_warps could only be 4 for now
-# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-#     [128, 256, 32, 4],
-#     [256, 128, 16, 4],
-#     [128, 16, 32, 4],
-#     [32, 128, 64, 4],
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
+    [128, 256, 32, 4],
+    [256, 128, 16, 4],
+    [128, 16, 32, 4],
+    [32, 128, 64, 4],
+])
+def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
+    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
+    grid = lambda META: (1, )
+    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+                               stride_am=a.stride(0), stride_ak=a.stride(1),
+                               stride_bk=b.stride(0), stride_bn=b.stride(1),
+                               stride_cm=c.stride(0), stride_cn=c.stride(1),
+                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
+                               num_warps=NUM_WARPS)
+    golden = torch.matmul(a, b)
+    torch.set_printoptions(profile="full")
+    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
+
+
+@triton.jit
+def matmul_kernel(
+    a_ptr, b_ptr, c_ptr,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    stride_cm, stride_cn,
+    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+):
+    offs_m = tl.arange(0, BLOCK_SIZE_M)
+    offs_n = tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K, BLOCK_SIZE_K):
+        a = tl.load(a_ptrs)
+        b = tl.load(b_ptrs)
+        accumulator += tl.dot(a, b)
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    tl.store(c_ptrs, accumulator)
+
+# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
+# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
+#     [128, 256, 128, 4, 128, 256, 32],
+#     # [256, 128, 64, 4, 256, 128, 16],
+#     # [128, 16, 128, 4, 128, 16, 32],
+#     # [32, 128, 256, 4, 32, 128, 64],
 # ])
-# def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
-#     a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
-#     b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
-#     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
-#     grid = lambda META: (1, )
-#     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
-#                         stride_am=a.stride(0), stride_ak=a.stride(1),
-#                         stride_bk=b.stride(0), stride_bn=b.stride(1),
-#                         stride_cm=c.stride(0), stride_cn=c.stride(1),
-#                         M=SIZE_M, N=SIZE_N, K=SIZE_K,
-#                         num_warps=NUM_WARPS)
-#     golden = torch.matmul(a, b)
-#     torch.set_printoptions(profile="full")
-#     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
+# def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
+#     a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+#     b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+#     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
+#     grid = lambda META: (1, )
+#     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+#                         stride_am=a.stride(0), stride_ak=a.stride(1),
+#                         stride_bk=b.stride(0), stride_bn=b.stride(1),
+#                         stride_cm=c.stride(0), stride_cn=c.stride(1),
+#                         M=a.shape[0], N=b.shape[1], K=a.shape[1],
+#                         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
+#                         num_warps=NUM_WARPS)
+#     golden = torch.matmul(a, b)
+#     torch.set_printoptions(profile="full")
+#     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 51f7ee8fd..355bfc605 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -558,7 +558,10 @@ class CodeGenerator(ast.NodeVisitor):
             raise RuntimeError('Only `range` iterator currently supported')
         # static for loops: all iterator arguments are constexpr
         iter_args = [self.visit(arg) for arg in node.iter.args]
-        is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
+        static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
+        is_static = False
+        if static_unrolling:
+            is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
         if is_static:
             iter_args = [arg.value for arg in iter_args]
             range = iterator(*iter_args)
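
Usage note (editor's sketch, not part of the upstream patch): with this change, `for` loops in the frontend are no longer statically unrolled by default, so loops over `range(...)` lower through scf instead. The old unrolling path is opt-in: the compiler.py hunk above only checks os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False), so any non-empty string (even "0") re-enables unrolling when all range arguments are constexpr. A minimal way to opt back in, assuming the variable is read exactly as shown in that hunk:

    import os
    # Must be set before the kernel is JIT-compiled; the value is read inside
    # CodeGenerator when the for-loop is visited (see the compiler.py hunk above).
    os.environ["TRITON_STATIC_LOOP_UNROLLING"] = "1"

    # ...then compile/launch kernels as usual, e.g. matmul_kernel from
    # python/tests/test_gemm.py, whose range(0, K, BLOCK_SIZE_K) bounds are
    # tl.constexpr and therefore eligible for static unrolling.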