[OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834)
This commit is contained in:
@@ -157,15 +157,6 @@ import triton.language as tl
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@@ -318,13 +309,13 @@ else:
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 33)
|
||||
8192
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider', # argument name whose value corresponds to a different line in the plot
|
||||
# possible values for `line_arg``
|
||||
line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],
|
||||
line_vals=['cublas', 'triton'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
@@ -336,18 +327,9 @@ def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
|
||||
if provider == 'cublas + relu':
|
||||
torch_relu = torch.nn.ReLU(inplace=True)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch_relu(torch.matmul(a, b))
|
||||
)
|
||||
if provider == 'triton + relu':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, activation=leaky_relu)
|
||||
)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100)
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
Reference in New Issue
Block a user