8 lines
244 B
Python
8 lines
244 B
Python
![]() |
import torch
|
||
|
import triton
|
||
|
|
||
|
def test_op(M=1024, N=1024, dtype=torch.float32):
    """Compare ``triton.ops.softmax`` against ``torch.softmax`` on random data.

    A random ``(M, N)`` tensor of the requested ``dtype`` is created on the
    ``'cuda'`` device. The PyTorch reference is taken over the last
    dimension, and the Triton result must match it under ``torch.allclose``
    with its default tolerances.

    Args:
        M: Number of rows in the random input.
        N: Number of columns in the random input.
        dtype: Torch dtype of the random input.

    Raises:
        AssertionError: If the two results are not element-wise close.
    """
    inputs = torch.randn(M, N, dtype=dtype, device='cuda')

    # Reference result from PyTorch, normalised along the last dimension.
    expected = torch.softmax(inputs, dim=-1)

    # NOTE(review): presumably triton.ops.softmax also reduces over the last
    # dimension -- confirm against the installed Triton version.
    actual = triton.ops.softmax(inputs)

    assert torch.allclose(actual, expected)
|