[Triton-MLIR] Minor fixes related to scf/swizzling support (#791)
1. Disable static loop unrolling in the frontend by default; it can be re-enabled with the TRITON_STATIC_LOOP_UNROLLING environment variable. 2. A minor fix in the axis analysis (AxisInfo) so that block arguments introduced by scf ops are handled. 3. A minor fix in TritonGPUToLLVM (lowerBlockedToShared) to support swizzling.
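As an illustration of the first change (a sketch added for this description, not part of the diff): a `range` loop whose bounds are all constexpr is no longer unrolled while the AST is lowered; it reaches the IR as a single runtime loop unless the new environment variable is set before the kernel is compiled.

    # Minimal sketch; the only thing assumed from this commit is the
    # TRITON_STATIC_LOOP_UNROLLING variable, everything else is ordinary Triton.
    import triton
    import triton.language as tl


    @triton.jit
    def sum_blocks_kernel(x_ptr, out_ptr, N: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        acc = tl.zeros((BLOCK,), dtype=tl.float32)
        # With the new default this constexpr-bounded loop stays a runtime loop
        # in the generated IR instead of being unrolled by the frontend.
        for i in range(0, N, BLOCK):
            acc += tl.load(x_ptr + i + offs).to(tl.float32)
        tl.store(out_ptr + offs, acc)


    # To opt back into frontend unrolling of static loops:
    #   import os; os.environ['TRITON_STATIC_LOOP_UNROLLING'] = '1'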
@@ -40,7 +40,8 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
   if (TensorType ty = value.getType().dyn_cast<TensorType>())
     rank = ty.getRank();
   int divHint = 1;
-  if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
+  BlockArgument blockArg = value.dyn_cast<BlockArgument>();
+  if (blockArg && blockArg.getOwner()->isEntryBlock()) {
     Operation *op = blockArg.getOwner()->getParentOp();
     if (FuncOp fun = dyn_cast<FuncOp>(op)) {
       Attribute attr =
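For context on why this is needed for scf (an illustrative sketch, not code from this commit): once runtime loops are kept, values that are re-assigned inside a Python for loop are carried as block arguments of the scf.for body, so the pessimistic state must not treat every BlockArgument as a function argument whose divisibility hint lives on the enclosing FuncOp; hence the explicit isEntryBlock() check above.

    # Hypothetical kernel for illustration only: `src` and `dst` are updated in the
    # loop, so after lowering they become scf.for block arguments that the axis
    # analysis has to tolerate. Assumes n is a multiple of BLOCK; masking omitted.
    import triton
    import triton.language as tl


    @triton.jit
    def blocked_copy_kernel(src, dst, n, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        for _ in range(0, n, BLOCK):
            tl.store(dst + offs, tl.load(src + offs))
            src += BLOCK
            dst += BLOCK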
@@ -1867,8 +1867,8 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
     unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads;
     auto multiDimIdxInNanoTile = getMultiDimIndex<unsigned>(
         linearIdxInNanoTile, srcBlockedLayout.getSizePerThread());
-    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
     unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec;
+    multiDimIdxInNanoTile[inOrd[0]] /= minVec;
     unsigned wordVecIdx =
         getLinearIndex<unsigned>(multiDimIdxInNanoTile, wordsInEachRep);
     wordVecs[wordVecIdx] =
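To see why the two statements are swapped, a small worked example with assumed numbers (minVec = 4, index 7 along the fastest-varying dimension inOrd[0]): the element's slot inside its minVec-wide word must be taken before the index is divided down to a word coordinate.

    # Worked example with assumed values; not code from the commit.
    minVec = 4
    idx = 7                               # index along the fastest-varying dimension
    pos_old = (idx // minVec) % minVec    # old order divided first: 1, the wrong slot
    pos_new = idx % minVec                # new order: 3, the element's slot in its word
    word = idx // minVec                  # 1, which minVec-wide word the element maps to
    assert (word, pos_new) == (1, 3)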
@@ -1,13 +1,13 @@
-# import pytest
-# import torch
-# from torch.testing import assert_close
+import pytest
+import torch
+from torch.testing import assert_close
 
 import triton
 import triton.language as tl
 
 
 @triton.jit
-def matmul_kernel(
+def matmul_no_scf_kernel(
     a_ptr, b_ptr, c_ptr,
     stride_am, stride_ak,
     stride_bk, stride_bn,
@@ -30,23 +30,72 @@ def matmul_kernel(
 # TODO: num_warps could only be 4 for now
 
 
-# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
-#     [128, 256, 32, 4],
-#     [256, 128, 16, 4],
-#     [128, 16, 32, 4],
-#     [32, 128, 64, 4],
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
+    [128, 256, 32, 4],
+    [256, 128, 16, 4],
+    [128, 16, 32, 4],
+    [32, 128, 64, 4],
+])
+def test_gemm_no_scf(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
+    a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
+    b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
+    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
+    grid = lambda META: (1, )
+    matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
+                               stride_am=a.stride(0), stride_ak=a.stride(1),
+                               stride_bk=b.stride(0), stride_bn=b.stride(1),
+                               stride_cm=c.stride(0), stride_cn=c.stride(1),
+                               M=SIZE_M, N=SIZE_N, K=SIZE_K,
+                               num_warps=NUM_WARPS)
+    golden = torch.matmul(a, b)
+    torch.set_printoptions(profile="full")
+    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
+
+
+@triton.jit
+def matmul_kernel(
+    a_ptr, b_ptr, c_ptr,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    stride_cm, stride_cn,
+    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+):
+    offs_m = tl.arange(0, BLOCK_SIZE_M)
+    offs_n = tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K, BLOCK_SIZE_K):
+        a = tl.load(a_ptrs)
+        b = tl.load(b_ptrs)
+        accumulator += tl.dot(a, b)
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    tl.store(c_ptrs, accumulator)
+
+
+# TODO: DotConversion in TritonGPUToLLVM cannot support non-splat C for the moment
+# @pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
+#     [128, 256, 128, 4, 128, 256, 32],
+#     # [256, 128, 64, 4, 256, 128, 16],
+#     # [128, 16, 128, 4, 128, 16, 32],
+#     # [32, 128, 256, 4, 32, 128, 64],
 # ])
-# def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
+# def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
 #     a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
 #     b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
 #     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
 #     grid = lambda META: (1, )
 #     matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
 #                         stride_am=a.stride(0), stride_ak=a.stride(1),
 #                         stride_bk=b.stride(0), stride_bn=b.stride(1),
 #                         stride_cm=c.stride(0), stride_cn=c.stride(1),
-#                         M=SIZE_M, N=SIZE_N, K=SIZE_K,
+#                         M=a.shape[0], N=b.shape[1], K=a.shape[1],
+#                         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
 #                         num_warps=NUM_WARPS)
 #     golden = torch.matmul(a, b)
 #     torch.set_printoptions(profile="full")
 #     assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
@@ -558,7 +558,10 @@ class CodeGenerator(ast.NodeVisitor):
             raise RuntimeError('Only `range` iterator currently supported')
         # static for loops: all iterator arguments are constexpr
         iter_args = [self.visit(arg) for arg in node.iter.args]
-        is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
+        static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
+        is_static = False
+        if static_unrolling:
+            is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
         if is_static:
             iter_args = [arg.value for arg in iter_args]
             range = iterator(*iter_args)