[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)

This massively simplifies the implementation of `reassociate` and also fixes
a bunch of bugs. The pass could still be improved, but it can already be used
to generate constant pointer offsets in e.g. the matmul epilogue.
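For context on the "constant pointer offsets in the matmul epilogue" remark, here is a minimal illustrative sketch (not part of this commit; the kernel name, arguments, and launch are made up) of the epilogue-style pointer arithmetic involved. The output pointers mix a runtime base with offsets built from `tl.arange`, and reassociating the additions is what can expose compile-time-constant address deltas as immediate offsets in the generated code.

import torch
import triton
import triton.language as tl

# Illustrative sketch only: mimics the C-tile store in the matmul epilogue.
@triton.jit
def _store_c_tile(C, stride_cm, stride_cn, **META):
    BLOCK_M, BLOCK_N = META['BLOCK_M'], META['BLOCK_N']
    rm = tl.arange(0, BLOCK_M)  # row offsets, known at compile time
    rn = tl.arange(0, BLOCK_N)  # column offsets, known at compile time
    # base (runtime) + rm-term + rn-term: the arange-based parts give each
    # thread addresses whose pairwise deltas are compile-time constants,
    # which reassociation can turn into immediate offsets on the store.
    ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32))

# Hypothetical launch, for illustration only:
# c = torch.empty((128, 128), device='cuda', dtype=torch.float32)
# _store_c_tile[(1,)](c, c.stride(0), c.stride(1), BLOCK_M=128, BLOCK_N=128)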
Philippe Tillet
2021-05-07 17:54:37 -04:00
committed by Philippe Tillet
parent e16bee1a27
commit 840140bf26
12 changed files with 204 additions and 667 deletions


@@ -137,8 +137,8 @@ def swish(x):
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
#triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
],
key=['M', 'N', 'K'],
)
@@ -202,11 +202,12 @@ def matmul(a, b, activation=None):
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
_matmul[grid](
pgm = _matmul[grid](
a, b, c, M, N, K, \
a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
ACTIVATION = activation
)
#print(pgm.asm('ttir'))
# return output
return c
@@ -218,13 +219,14 @@ def matmul(a, b, activation=None):
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)
#torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))
# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# c_0 = matmul(a, b, activation=None)
# c_1 = torch.matmul(a, b)
# print(c_0)
# print(c_1)
# print(triton.testing.allclose(c_0, c_1))
# exit()
# %%
# Benchmark
@@ -238,7 +240,7 @@ print(triton.testing.allclose(c_0, c_1))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
x_vals=[8192], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['cublas', 'triton'], # possible values for `line_arg`
line_names=["cuBLAS", "Triton"], # label name for the lines