[BACKEND] Restored reduction bugfixes
@@ -252,6 +252,7 @@ def matmul_kernel(
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
 @triton.jit
 def leaky_relu(x):
+    x = x + 1
     return tl.where(x >= 0, x, 0.01 * x)
 
 
@@ -296,7 +297,7 @@ def matmul(a, b, activation=None):
 torch.manual_seed(0)
 a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-triton_output = matmul(a, b, activation=None)
+triton_output = matmul(a, b, activation=leaky_relu)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
 print(f"torch_output={torch_output}")
@@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output):
 else:
     print("❌ Triton and Torch differ")
 
+print(matmul_kernel.cache_key)
+exit()
 # %%
 # Benchmark
 # --------------
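For context, the `ACTIVATION` meta-parameter mentioned in the first hunk enables epilogue fusion: because it is passed as a compile-time constant, the check on it is resolved during compilation and the activation is inlined into the kernel body instead of requiring a separate elementwise kernel launch. Below is a minimal sketch of that pattern on a simple elementwise kernel rather than the full matmul; the kernel name `scale_kernel`, the host wrapper `scale`, the block size, and the string-flag style (`ACTIVATION == "leaky_relu"`, as in newer versions of the upstream tutorial) are illustrative assumptions, not code from this commit — the tutorial version in this diff passes the activation function object itself.

import torch
import triton
import triton.language as tl


@triton.jit
def leaky_relu(x):
    # Same activation as in the tutorial diff (without the debug `x = x + 1`).
    return tl.where(x >= 0, x, 0.01 * x)


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements,
                 BLOCK_SIZE: tl.constexpr, ACTIVATION: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask) * 2.0
    # ACTIVATION is a compile-time constant, so this branch is resolved at
    # compile time and leaky_relu is inlined into the kernel's epilogue.
    if ACTIVATION == "leaky_relu":
        x = leaky_relu(x)
    tl.store(out_ptr + offsets, x, mask=mask)


def scale(x, activation=""):
    # Hypothetical host-side wrapper mirroring the tutorial's matmul(a, b, activation=...).
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    scale_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024, ACTIVATION=activation)
    return out


if __name__ == "__main__":
    x = torch.randn(4096, device='cuda', dtype=torch.float16)
    y = scale(x, activation="leaky_relu")  # fused epilogue: a single kernel launch

Because each distinct `ACTIVATION` value produces its own compiled specialization, switching the activation changes which cached binary is used, which is presumably why the diff also prints `matmul_kernel.cache_key` while debugging.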