[PYTHON] Renamed triton.core -> triton.language (#92)

This commit is contained in:
Philippe Tillet
2021-04-23 17:18:14 -04:00
committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
19 changed files with 355 additions and 243 deletions

View File

@@ -1,5 +1,6 @@
import torch
import triton
import triton.language as tl
import copy
import pytest
import ast
@@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'):
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
@@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, Y, **meta):
off = triton.arange(0, meta['SIZE'])
x = triton.load(X + off)
y = triton.load(Y + off)
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
y = tl.load(Y + off)
z = GENERATE_TEST_HERE
triton.store(Z + off, z)
tl.store(Z + off, z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
# inputs
@@ -144,7 +145,7 @@ def make_ptr_str(name, shape):
stride = 1
for i in reversed(range(rank)):
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
stride *= shape[i]
return f"{name} + {' + '.join(offsets)}"
@@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'):
@triton.jit
def kernel(Z, X, **meta):
SIZE = meta['SIZE']
m = triton.arange(0, SIZE)
n = triton.arange(0, SIZE)
x = triton.load(X_PTR_EXPR)
m = tl.arange(0, SIZE)
n = tl.arange(0, SIZE)
x = tl.load(X_PTR_EXPR)
z = GENERATE_TEST_HERE
triton.store(Z_PTR_EXPR, z)
tl.store(Z_PTR_EXPR, z)
to_replace = {
'X_PTR_EXPR': make_ptr_str('X', shape_x),