Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -47,13 +47,37 @@ def mask_tensor(x, mask, block, value=0):
def allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
if x.dtype != y.dtype:
raise RuntimeError(f'{x.dtype} did not match with {x.dtype}')
if x.shape != y.shape:
raise RuntimeError(f'{x.shape} did not match with {y.shape}')
if x.dtype == torch.bool:
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
tol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err < tol
return err <= tol
def assert_allclose(x, y, tol=1e-2):
assert x.dtype == y.dtype
assert allclose(x, y, tol)
def random(shape, dtype, device):
if isinstance(shape, int):
shape = (shape, )
if dtype == torch.bool:
return torch.randint(0, 2, shape, dtype=dtype, device=device)
if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.randint(1, 32, shape, dtype=dtype, device=device)
if dtype in [torch.float16, torch.float32, torch.float64]:
return torch.randn(shape, dtype=dtype, device=device)
raise RuntimeError(f'Unknown dtype {dtype}')
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.2, 0.8]):