[BACKEND] Restored reduction bugfixes

Philippe Tillet
2022-06-03 11:38:52 -07:00
parent a60374a597
commit 8876e53206
11 changed files with 173 additions and 65 deletions


@@ -252,6 +252,7 @@ def matmul_kernel(
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
 @triton.jit
 def leaky_relu(x):
+    x = x + 1
     return tl.where(x >= 0, x, 0.01 * x)
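
For reference, the device function above now computes a leaky ReLU of the shifted input x + 1 rather than of x itself. A minimal plain-PyTorch sketch of the same element-wise computation (the helper name leaky_relu_ref is illustrative and not part of this commit):

import torch

def leaky_relu_ref(x: torch.Tensor) -> torch.Tensor:
    # Mirror the Triton device function: shift by 1, then apply leaky ReLU.
    x = x + 1
    return torch.where(x >= 0, x, 0.01 * x)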
@@ -296,7 +297,7 @@ def matmul(a, b, activation=None):
 torch.manual_seed(0)
 a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-triton_output = matmul(a, b, activation=None)
+triton_output = matmul(a, b, activation=leaky_relu)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
 print(f"torch_output={torch_output}")
@@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output):
 else:
     print("❌ Triton and Torch differ")
+print(matmul_kernel.cache_key)
+exit()

 # %%
 # Benchmark
 # --------------