[GENERAL] Various improvements:

* Sparse einsum in triton.ops.einsum
* Hacky support for fixed-tile-size atomic-add
* Various bugfixes in parser
This commit is contained in:
Philippe Tillet
2020-10-25 11:55:58 -07:00
parent 444907589d
commit 049ab989b5
16 changed files with 574 additions and 331 deletions

109
python/examples/test.py Normal file
View File

@@ -0,0 +1,109 @@
import triton
import numpy
import torch
import itertools
torch.manual_seed(0)
numpy.random.seed(0)
def to_sparse(expr, data, layout, shape, block):
# shape of result
sparse = None
shape_ret = []
for i, d in enumerate(expr):
if d.isupper() and sparse is None:
sparse = i
shape_ret.append(int(layout.sum()))
if d.isupper():
shape_ret.append(block[d])
else:
shape_ret.append(shape[i])
# iterator
steps = [block[d] if d.isupper() else 1 for d in expr]
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
# create result
ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device)
blockid = 0
nzblockid = 0
for curr in itertools.product(*it):
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
blockid = 0
nzblockid = 0
data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
ret_slice.insert(sparse, nzblockid)
if int(layout.view(-1)[blockid]):
ret[ret_slice] = data[data_slice]
nzblockid += 1
blockid += 1
return ret
def to_dense(expr, data, layout, shape, block):
sparse = None
for i, d in enumerate(expr):
if d.isupper() and sparse is None:
sparse = i
ret = torch.zeros(*shape, dtype=data.dtype, device=data.device)
steps = [block[d] if d.isupper() else 1 for d in expr]
it = [range(0, shape[i], steps[i]) for i in range(len(expr))]
blockid = 0
nzblockid = 0
for curr in itertools.product(*it):
if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]):
blockid = 0
nzblockid = 0
ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))]
data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))]
data_slice.insert(sparse, nzblockid)
if int(layout.view(-1)[blockid]):
ret[ret_slice] = data[data_slice]
nzblockid += 1
blockid += 1
return ret
def test_expr(expr, shape, blocks):
# decompose expr
expr_a, expr_bc = expr.split(",")
expr_b, expr_c = expr_bc.split("->")
# check with argument is sparse
sparse_a = any(x.isupper() for x in expr_a)
sparse_b = any(x.isupper() for x in expr_b)
sparse_c = any(x.isupper() for x in expr_c)
# allocate data
shape_a = [shape[d.lower()] for d in expr_a]
shape_b = [shape[d.lower()] for d in expr_b]
shape_c = [shape[d.lower()] for d in expr_c]
ref_a = torch.rand(*shape_a, device='cuda')
ref_b = torch.rand(*shape_b, device='cuda')
ref_c = torch.zeros(*shape_c, device='cuda')
# layouts
layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()]
layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()]
layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()]
layout_a = torch.randint(0, 2, layout_a, device='cuda')
layout_b = torch.randint(0, 2, layout_b, device='cuda')
layout_c = torch.randint(0, 2, layout_c, device='cuda')
# triton computation
triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a
triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b
layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c}
triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks)
torch.cuda.synchronize()
# reference computation
ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a
ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b
ref_c = torch.einsum(expr.lower(), ref_a, ref_b)
if sparse_c:
ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks)
torch.cuda.synchronize()
print((ref_c - triton_c).abs().max())
# shape characteristics
test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32})
test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32})
test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32})

View File

@@ -56,8 +56,8 @@ class _dot(torch.autograd.Function):
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
int rxm[TM] = ridx * TM + 0 ... TM;
int rxn[TN] = ridy * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -95,7 +95,7 @@ class _dot(torch.autograd.Function):
if dtype not in _dot.kernel:
defines = {
'TYPE' : dtype,
'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [64, 128],
'TN' : [64, 128],
@@ -107,14 +107,12 @@ class _dot(torch.autograd.Function):
# allocate output
M, K = a.shape
K, N = b.shape
c = triton.empty([M,N], dtype=dtype)
c = torch.empty([M,N], dtype=dtype, device=a.device)
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0),
grid=grid, bench=100)
print(2*M*N*K/(time*1e-6)*1e-9)
a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c
@@ -126,8 +124,10 @@ M, N, K = 2048, 2048, 2048
a = torch.rand((M, K)).cuda()
b = torch.rand((K, N)).cuda()
#a[:] = 1
#b[:] = 1
#zc = torch.matmul(a,b)
zc = torch.matmul(a,b)
zc_ = dot(a,b)
#print(torch.allclose(zc, zc_))
print(torch.allclose(zc, zc_))