[FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and clearly separate the compilation step from the just-in-time caching logic. This should substantially reduce launch overhead.
This commit is contained in:
@@ -11,7 +11,7 @@ from numpy.random import RandomState
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import JITFunction, TensorWrapper, reinterpret
|
||||
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
||||
|
||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
@@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
elif (op in ('%', '/') and
|
||||
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
|
||||
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
|
||||
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
|
||||
with pytest.raises(triton.CompilationError) as exc_info:
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
|
||||
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
|
||||
else:
|
||||
@@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
|
||||
else:
|
||||
numpy_expr = None
|
||||
if 'float' in dtype_x + dtype_y:
|
||||
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
|
||||
with pytest.raises(triton.CompilationError) as exc_info:
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
|
||||
# The CompilationError must have been caused by a C++ exception with this text.
|
||||
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
|
||||
@@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
def catch_compilation_error(kernel):
|
||||
try:
|
||||
kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
|
||||
except triton.code_gen.CompilationError as e:
|
||||
except triton.CompilationError as e:
|
||||
np.testing.assert_(True)
|
||||
except BaseException:
|
||||
np.testing.assert_(False)
|
||||
@@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache):
|
||||
assert 'ld.global.cg' not in ptx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("N", [8, 10, 11, 1024])
|
||||
@pytest.mark.parametrize("N", [16, 10, 11, 1024])
|
||||
def test_vectorization(N):
|
||||
src = torch.empty(1024, device='cuda')
|
||||
dst = torch.empty(1024, device='cuda')
|
||||
@@ -1221,10 +1221,8 @@ def test_vectorization(N):
|
||||
tl.store(dst + offsets, x, mask=offsets < N)
|
||||
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
|
||||
ptx = pgm.asm["ptx"]
|
||||
if N % 4 == 0:
|
||||
if N % 16 == 0:
|
||||
assert "ld.global.v4.b32" in ptx
|
||||
elif N % 2 == 0:
|
||||
assert "ld.global.v2.b32" in ptx
|
||||
else:
|
||||
assert "ld.global.b32" in ptx
|
||||
# triton.testing.assert_almost_equal(dst, src[:N])
|
||||
@@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
|
||||
def cache_hook(*args, **kwargs):
|
||||
nonlocal spec_type
|
||||
spec_type = kwargs["compile"]["arg_types"][0][1]
|
||||
spec_type = kwargs["compile"]["signature"][0]
|
||||
JITFunction.cache_hook = cache_hook
|
||||
|
||||
@triton.jit
|
||||
@@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
|
||||
if overflow:
|
||||
with pytest.raises(RuntimeError, match='integer overflow'):
|
||||
with pytest.raises(OverflowError):
|
||||
kernel[(1, )](value, x)
|
||||
else:
|
||||
kernel[(1, )](value, x)
|
||||
|
Reference in New Issue
Block a user