[PYTHON][TESTS][DOC] Various improvements to the API and code quality:

* Simplified the `triton.kernel` API to achieve lower latency (a usage sketch follows this list):
  > `.data_ptr()` must now be passed as the kernel argument; there is no more implicit conversion from `torch.tensor`
  > compilation options are now constant attributes, i.e., `opt.d('VAR')` becomes `opt.VAR`
  > `torch.device` must now be passed explicitly to `triton.kernel` (no longer inferred from `torch.tensor` arguments)
* C++ tests moved to `python/tests/`
* C++ tutorial created in `tutorials/`
* Python tutorial created in `python/tutorials/`
* Version changed to 1.0alpha
* No longer copying C++ headers into the Python package
* Added `python/triton/ops/` package for pre-written Triton ops
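
A minimal sketch of a call site under the new API. The Triton-C source, the `defines` dictionary, and the grid lambda below are illustrative assumptions (the tutorials added in this commit are the authoritative reference); only the calling convention highlighted in the comments reflects the changes listed above:

    import torch
    import triton

    # hypothetical Triton-C kernel; the Python-side calling convention is the point here
    src = """
    __global__ void add(float* z, float* x, float* y, int N) {
        int pid = get_program_id(0);
        int off[TILE] = pid * TILE + 0 ... TILE;
        float* pz[TILE] = z + off;
        float* px[TILE] = x + off;
        float* py[TILE] = y + off;
        bool check[TILE] = off < N;
        *?(check)pz = *?(check)px + *?(check)py;
    }
    """

    device = torch.device('cuda')
    # torch.device is passed explicitly; it is no longer inferred from tensor arguments
    kernel = triton.kernel(src, device=device, defines={'TILE': 1024})

    x = torch.randn(4096, device=device)
    y = torch.randn(4096, device=device)
    z = torch.empty_like(x)

    # tensors are passed as raw .data_ptr() integers (no implicit conversion), and
    # compilation options are read as attributes: opt.TILE instead of opt.d('TILE')
    grid = lambda opt: ((x.numel() + opt.TILE - 1) // opt.TILE, )
    kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), x.numel(), grid=grid)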
Author: Philippe Tillet
Date:   2021-01-29 17:27:16 -05:00
Parent: a5a477c36b
Commit: 269ebc12e5
63 changed files with 2255 additions and 3883 deletions


@@ -0,0 +1,50 @@
import itertools
import torch
import triton as tt
import pytest


def sparsify_tensor(x, mask, block):
    # gather the non-zero blocks of a dense (Z, H, M, N) tensor into (Z, nnz, block, block)
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret


def mask_tensor(x, mask, block, value=0):
    # fill the blocks that are absent from the layout with `value`
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret


@pytest.mark.parametrize(
    "MODE, TRANS_A, TRANS_B, BLOCK",
    [
        (mode, at, bt, block)
        for mode in ['sdd', 'dsd', 'dds']
        for at in [False, True]
        for bt in [False, True]
        for block in [16, 32, 64]
    ],
)
def test_op(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda')
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda')
    shape = {'sdd': (M, N), 'dsd': (a.shape[2], a.shape[3]), 'dds': (b.shape[2], b.shape[3])}[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # triton result
    op = tt.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = sparsify_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    rb = sparsify_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    rc = op(ra, rb)
    # torch result
    ta = mask_tensor(a, layout, BLOCK) if MODE == 'dsd' else a
    tb = mask_tensor(b, layout, BLOCK) if MODE == 'dds' else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = mask_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    tc = sparsify_tensor(tc, layout, BLOCK) if MODE == 'sdd' else tc
    # compare
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(rc, tc, rtol=rtol, atol=atol)

python/tests/test_conv.py (new file)

@@ -0,0 +1,17 @@
import torch
import triton


def test_op():
    torch.manual_seed(0)
    DTYPE = torch.float16
    # NCHW activations, (CI, R, S, CO) weights
    N, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
    pad, stride = (1, 1), (1, 1)
    dilation = (1, 1)
    a = torch.rand((N, CI, H, W), dtype=DTYPE, device='cuda') / CI**.5
    b = torch.rand((CI, R, S, CO), dtype=DTYPE, device='cuda') / CI**.5
    # torch.nn.functional.conv2d expects (CO, CI, R, S) weights, hence the permute
    th_c = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), None, stride, pad, dilation)
    tt_c = triton.ops.conv(a, b, pad, stride)
    rtol, atol = {torch.float32: (1e-4, 1e-5),
                  torch.float16: (1e-2, 1e-3)}[DTYPE]
    assert torch.allclose(tt_c, th_c, atol=atol, rtol=rtol)


@@ -0,0 +1,96 @@
import pytest
import itertools
import triton as tt
import torch as th


@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[
    [
        # 1 warp
        (16, 16, 16, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 16, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 16, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 32, 1, None, None, None, AT, BT, DTYPE),
        (32, 16, 32, 1, None, None, None, AT, BT, DTYPE),
        (16, 32, 32, 1, None, None, None, AT, BT, DTYPE),
        (16, 16, 64, 1, None, None, None, AT, BT, DTYPE),
        (64, 16, 64, 1, None, None, None, AT, BT, DTYPE),
        (16, 64, 64, 1, None, None, None, AT, BT, DTYPE),
        # 2 warps
        (64, 32, 64, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 64, 2, None, None, None, AT, BT, DTYPE),
        (64, 32, 16, 2, None, None, None, AT, BT, DTYPE),
        (32, 64, 16, 2, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 2, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 2, None, None, None, AT, BT, DTYPE),
        # 4 warps
        (128, 64, 16, 4, None, None, None, AT, BT, DTYPE),
        (64, 128, 16, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 32, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 32, 4, None, None, None, AT, BT, DTYPE),
        (128, 32, 64, 4, None, None, None, AT, BT, DTYPE),
        (32, 128, 64, 4, None, None, None, AT, BT, DTYPE),
        # 8 warps
        (128, 256, 16, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 16, 8, None, None, None, AT, BT, DTYPE),
        (256, 128, 32, 8, None, None, None, AT, BT, DTYPE),
        # variable input sizes
        (128, 128, 32, 4, 256, 256, 256, AT, BT, DTYPE),
        (128, 128, 32, 4, 384, 128, 640, AT, BT, DTYPE),
        (128, 128, 32, 4, 107, 233, 256, AT, BT, DTYPE),
        (128, 128, 32, 4, 107, 233, 311, AT, BT, DTYPE),
    ]
    for DTYPE in ['float16']
    for AT in [False, True]
    for BT in [False, True]
]))
def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE]
    th.manual_seed(0)
    # reset the kernel cache and pin the tile sizes / warp count to this configuration
    tt.ops._matmul.kernel = dict()
    tt.ops._matmul.TM = [TM]
    tt.ops._matmul.TN = [TN]
    tt.ops._matmul.TK = [TK]
    tt.ops._matmul.num_warps = [NWARP]
    if M is None: M = TM
    if N is None: N = TN
    if K is None: K = TK
    a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
    b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
    a = a.t() if AT else a
    b = b.t() if BT else b
    th_c = th.matmul(a, b)
    tt_c = tt.ops.matmul(a, b)
    rtol, atol = {th.float32: (1e-4, 1e-5),
                  th.float16: (1e-2, 1e-3)}[DTYPE]
    assert th.allclose(tt_c, th_c, atol=atol, rtol=rtol)


def do_bench(fn, flops=0, warmup=10, rep=50):
    # time `fn` with CUDA events and return (time in ms, TFLOPS, return value)
    start_event = th.cuda.Event(enable_timing=True)
    end_event = th.cuda.Event(enable_timing=True)
    ret = fn()
    for i in range(warmup):
        fn()
    th.cuda.synchronize()
    start_event.record()
    for i in range(rep):
        fn()
    end_event.record()
    th.cuda.synchronize()
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms, flops / time_ms * 1e-9, ret


def perf_op(dtype=th.float16, warmup=10, rep=50):
    # compare torch and triton matmul latency on square problems
    AT, BT = False, False
    configs = [(N, N, N) for N in [128, 8192]]
    for M, N, K in configs:
        a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
        b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
        if AT: a = a.t()
        if BT: b = b.t()
        a = a[::, ::]
        b = b[::, ::]
        TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops=M * N * K * 2, warmup=warmup, rep=rep)
        TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops=M * N * K * 2, warmup=warmup, rep=rep)
        print((M, N, K), TH_MS, TT_MS)


@@ -0,0 +1,8 @@
import torch
import triton


def test_op(M=1024, N=1024, dtype=torch.float32):
    x = torch.randn(M, N, dtype=dtype, device='cuda')
    th_y = torch.softmax(x, dim=-1)
    tt_y = triton.ops.softmax(x)
    assert torch.allclose(tt_y, th_y)