[CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101)
This massively simplifies the implementation of `reassociate` and also fixes a bunch of bugs. The pass could still be improved, but it can already be used to generate constant pointer offsets in, e.g., the matmul epilogue.
Committed by Philippe Tillet
Parent: e16bee1a27
Commit: 840140bf26
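For context, a rough sketch of how the change can be exercised from the tutorial touched below: the kernel launch now returns a handle (`pgm` in the diff), and dumping its Triton IR with `asm('ttir')` is one way to check that the epilogue's pointer offsets come out as constants. This assumes the tutorial's `_matmul` kernel (not shown in this diff); the input shapes and `ACTIVATION=None` are arbitrary choices for illustration, not part of the commit.

    import torch
    import triton

    # Toy fp16 inputs; any shapes the tutorial kernel supports would do.
    a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Same launch as in the matmul() wrapper below, but keeping the returned handle.
    # _matmul is the @triton.jit kernel defined earlier in the tutorial (not shown here).
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    pgm = _matmul[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        ACTIVATION=None,
    )

    # With reassociation folded into isel, the epilogue's pointer arithmetic in the
    # dumped IR should use constant offsets rather than chains of dependent adds.
    print(pgm.asm('ttir'))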
@@ -137,8 +137,8 @@ def swish(x):
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4),
        #triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@@ -202,11 +202,12 @@ def matmul(a, b, activation=None):
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # launch kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    _matmul[grid](
    pgm = _matmul[grid](
        a, b, c, M, N, K, \
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
        ACTIVATION = activation
    )
    #print(pgm.asm('ttir'))
    # return output
    return c

@@ -218,13 +219,14 @@ def matmul(a, b, activation=None):
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel)

#torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=swish)
c_1 = torch.nn.SiLU()(torch.matmul(a, b))
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))
# a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
# c_0 = matmul(a, b, activation=None)
# c_1 = torch.matmul(a, b)
# print(c_0)
# print(c_1)
# print(triton.testing.allclose(c_0, c_1))
# exit()

# %%
# Benchmark
@@ -238,7 +240,7 @@ print(triton.testing.allclose(c_0, c_1))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
        x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name`
        x_vals=[8192], # different possible values for `x_name`
        line_arg='provider', # argument name whose value corresponds to a different line in the plot
        line_vals=['cublas', 'triton'], # possible values for `line_arg``
        line_names=["cuBLAS", "Triton"], # label name for the lines