[PYTHON] Renamed triton.core -> triton.language (#92)
This commit is contained in:
committed by
Philippe Tillet
parent
41410012e8
commit
bfc0a7587d
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import copy
|
||||
import pytest
|
||||
import ast
|
||||
@@ -37,10 +38,10 @@ def _test_unary(dtype_x, expr, device='cuda'):
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
def kernel(Z, X, **meta):
|
||||
off = triton.arange(0, meta['SIZE'])
|
||||
x = triton.load(X + off)
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
x = tl.load(X + off)
|
||||
z = GENERATE_TEST_HERE
|
||||
triton.store(Z + off, z)
|
||||
tl.store(Z + off, z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
|
||||
# inputs
|
||||
@@ -59,11 +60,11 @@ def _test_binary(dtype_x, dtype_y, expr, device='cuda'):
|
||||
# define the kernel / launch-grid
|
||||
@triton.jit
|
||||
def kernel(Z, X, Y, **meta):
|
||||
off = triton.arange(0, meta['SIZE'])
|
||||
x = triton.load(X + off)
|
||||
y = triton.load(Y + off)
|
||||
off = tl.arange(0, meta['SIZE'])
|
||||
x = tl.load(X + off)
|
||||
y = tl.load(Y + off)
|
||||
z = GENERATE_TEST_HERE
|
||||
triton.store(Z + off, z)
|
||||
tl.store(Z + off, z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
|
||||
# inputs
|
||||
@@ -144,7 +145,7 @@ def make_ptr_str(name, shape):
|
||||
stride = 1
|
||||
for i in reversed(range(rank)):
|
||||
idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)])
|
||||
offsets += [f'triton.arange(0, {shape[i]})[{idx}]*{stride}']
|
||||
offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}']
|
||||
stride *= shape[i]
|
||||
return f"{name} + {' + '.join(offsets)}"
|
||||
|
||||
@@ -164,11 +165,11 @@ def test_index1d(expr, device='cuda'):
|
||||
@triton.jit
|
||||
def kernel(Z, X, **meta):
|
||||
SIZE = meta['SIZE']
|
||||
m = triton.arange(0, SIZE)
|
||||
n = triton.arange(0, SIZE)
|
||||
x = triton.load(X_PTR_EXPR)
|
||||
m = tl.arange(0, SIZE)
|
||||
n = tl.arange(0, SIZE)
|
||||
x = tl.load(X_PTR_EXPR)
|
||||
z = GENERATE_TEST_HERE
|
||||
triton.store(Z_PTR_EXPR, z)
|
||||
tl.store(Z_PTR_EXPR, z)
|
||||
|
||||
to_replace = {
|
||||
'X_PTR_EXPR': make_ptr_str('X', shape_x),
|
Reference in New Issue
Block a user