[PYTHON][TESTS][DOC] Various improvement of the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency: > .data_ptr() must now be passed as kernel argument. No more implicit conversion from torch.tensor > compilation options are now constant attributes, i.e., opt.d('VAR') becomes opt.VAR > torch.device must now be passed explicitly to triton.kernel (no longer inferred from torch.tensor arguments) * C++ tests moved to `python/tests/` * C++ tutorial created in `tutorials/` * Python tutorial created in python/tutorials/ * Version changed to 1.0alpha * No longer copying C++ headers into the Python package * added python/triton/ops/ package for pre-written Triton ops
This commit is contained in:
50
python/tests/test_blocksparse.py
Normal file
50
python/tests/test_blocksparse.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import itertools
|
||||
import torch
|
||||
import triton as tt
|
||||
import pytest
|
||||
|
||||
def sparsify_tensor(x, mask, block):
|
||||
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
|
||||
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
|
||||
ret[:, idx, :, :] = x[:, h, i*block: (i+1)*block, j*block: (j+1)*block]
|
||||
return ret
|
||||
|
||||
def mask_tensor(x, mask, block, value = 0):
|
||||
ret = x.clone()
|
||||
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
|
||||
ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value
|
||||
return ret
|
||||
|
||||
@pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK",
|
||||
[
|
||||
(mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\
|
||||
for at in [False, True]\
|
||||
for bt in [False, True]\
|
||||
for block in [16, 32, 64]
|
||||
]
|
||||
)
|
||||
def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE = torch.float16, Z = 3, H = 2, M = 128, N = 256, K = 384):
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
# create inputs
|
||||
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
|
||||
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
|
||||
shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
|
||||
layout = torch.randint(2, (H, shape[0]//BLOCK, shape[1]//BLOCK))
|
||||
# triton result
|
||||
op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
|
||||
ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
|
||||
rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
|
||||
rc = op(ra, rb)
|
||||
# torch result
|
||||
ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
|
||||
tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
|
||||
ta = ta.transpose(2, 3) if TRANS_A else ta
|
||||
tb = tb.transpose(2, 3) if TRANS_B else tb
|
||||
tc = torch.matmul(ta, tb)
|
||||
tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
|
||||
tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
|
||||
# compare
|
||||
rtol, atol = {torch.float32: (1e-4, 1e-5),
|
||||
torch.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert torch.allclose(rc, tc, rtol=rtol, atol=atol)
|
17
python/tests/test_conv.py
Normal file
17
python/tests/test_conv.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
|
||||
def test_op():
|
||||
torch.manual_seed(0)
|
||||
DTYPE = torch.float16
|
||||
N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
|
||||
pad, stride, = (1, 1), (1, 1)
|
||||
dilation = (1, 1)
|
||||
a = torch.rand((N , CI, H, W ), dtype=DTYPE, device='cuda') / CI**.5
|
||||
b = torch.rand((CI, R , S, CO), dtype=DTYPE, device='cuda') / CI**.5
|
||||
th_c = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, dilation)
|
||||
tt_c = triton.ops.conv(a, b, pad, stride)
|
||||
rtol, atol = {torch.float32: (1e-4, 1e-5),
|
||||
torch.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)
|
96
python/tests/test_matmul.py
Normal file
96
python/tests/test_matmul.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
import itertools
|
||||
import triton as tt
|
||||
import torch as th
|
||||
|
||||
@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
|
||||
(32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
|
||||
(64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
|
||||
(16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
|
||||
# 2 warp
|
||||
(64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
|
||||
(64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
|
||||
# 4 warp
|
||||
(128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
|
||||
(64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
|
||||
(128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
|
||||
(32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
|
||||
# 8 warp
|
||||
(128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
|
||||
(256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE),
|
||||
(128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE)
|
||||
]
|
||||
for DTYPE in ['float16']
|
||||
for AT in [False, True]
|
||||
for BT in [False, True]
|
||||
]))
|
||||
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
|
||||
DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
|
||||
th.manual_seed(0)
|
||||
tt.ops._matmul.kernel = dict()
|
||||
tt.ops._matmul.TM = [TM]
|
||||
tt.ops._matmul.TN = [TN]
|
||||
tt.ops._matmul.TK = [TK]
|
||||
tt.ops._matmul.num_warps = [NWARP]
|
||||
if M is None: M = TM
|
||||
if N is None: N = TN
|
||||
if K is None: K = TK
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
|
||||
a = a.t() if AT else a
|
||||
b = b.t() if BT else b
|
||||
th_c = th.matmul(a, b)
|
||||
tt_c = tt.ops.matmul(a, b)
|
||||
rtol, atol = {th.float32: (1e-4, 1e-5),
|
||||
th.float16: (1e-2, 1e-3)}[DTYPE]
|
||||
assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def do_bench(fn, flops = 0, warmup = 10, rep = 50):
|
||||
start_event = th.cuda.Event(enable_timing=True)
|
||||
end_event = th.cuda.Event(enable_timing=True)
|
||||
ret = fn()
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
th.cuda.synchronize()
|
||||
start_event.record()
|
||||
for i in range(rep):
|
||||
fn()
|
||||
end_event.record()
|
||||
th.cuda.synchronize()
|
||||
time_ms = start_event.elapsed_time(end_event) / rep
|
||||
return time_ms, flops/time_ms*1e-9, ret
|
||||
|
||||
|
||||
def perf_op(dtype=th.float16, warmup=10, rep=50):
|
||||
AT, BT = False, False
|
||||
configs = [(N, N, N) for N in [128, 8192]]
|
||||
for M, N, K in configs:
|
||||
a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
|
||||
b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
|
||||
if AT: a = a.t()
|
||||
if BT: b = b.t()
|
||||
a = a[::,::]
|
||||
b = b[::,::]
|
||||
TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
|
||||
TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
|
||||
print((M, N, K), TH_MS, TT_MS)
|
8
python/tests/test_softmax.py
Normal file
8
python/tests/test_softmax.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
def test_op(M = 1024, N = 1024, dtype = torch.float32):
|
||||
x = torch.randn(M, N, dtype=dtype, device='cuda')
|
||||
th_y = torch.softmax(x, dim=-1)
|
||||
tt_y = triton.ops.softmax(x)
|
||||
assert torch.allclose(tt_y, th_y)
|
Reference in New Issue
Block a user