* Simplified the `triton.kernel` API to achieve lower latency:
  * `.data_ptr()` must now be passed as a kernel argument; there is no more implicit conversion from `torch.tensor`.
  * Compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`.
  * `torch.device` must now be passed explicitly to `triton.kernel` (it is no longer inferred from `torch.tensor` arguments).
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
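The change to compilation options replaces a method lookup with plain attribute access. Triton's actual option object is not shown in this commit. The sketch below therefore uses two hypothetical stand-in classes, `OldStyleOptions` and `NewStyleOptions`, plus made-up option names `TM`/`TN`. It only contrasts the two access patterns and shows one way such attributes can be made constant; it is not Triton's implementation.

```python
# Hypothetical illustration only: OldStyleOptions / NewStyleOptions are NOT
# Triton classes, and TM / TN are made-up option names. They mimic the two
# access patterns opt.d('VAR') (old) versus opt.VAR (new).


class OldStyleOptions:
    """Options looked up by name through a method call, e.g. opt.d('TM')."""

    def __init__(self, defines):
        self._defines = dict(defines)

    def d(self, name):
        return self._defines[name]


class NewStyleOptions:
    """Options exposed as read-only attributes, e.g. opt.TM."""

    def __init__(self, defines):
        for name, value in defines.items():
            # Bypass our own __setattr__ guard while populating the options.
            object.__setattr__(self, name, value)

    def __setattr__(self, name, value):
        # Options are constants once constructed.
        raise AttributeError(f"option {name!r} is a constant and cannot be reassigned")


defines = {"TM": 128, "TN": 64}
old = OldStyleOptions(defines)
new = NewStyleOptions(defines)
print(old.d("TM"), new.TM)  # 128 128
```

Plain attribute access avoids a method call and a dictionary lookup per option. Raising in `__setattr__` is one simple way to make the values behave as constants.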
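```python
import torch
import triton


def test_op(M=1024, N=1024, dtype=torch.float32):
    # Random M x N input on the GPU
    x = torch.randn(M, N, dtype=dtype, device='cuda')
    # Reference result: PyTorch softmax over the last dimension (each row)
    th_y = torch.softmax(x, dim=-1)
    # Pre-written softmax from the new triton.ops package
    tt_y = triton.ops.softmax(x)
    # Both results should agree within torch.allclose's default tolerances
    assert torch.allclose(tt_y, th_y)
```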