[FRONTEND] Complete rewrite of the runtime (#644)

This PR completely rewrites the Triton runtime to make it leaner and to clearly
separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
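For context on the diffs below, here is a minimal sketch (not part of this commit) of the reshuffled entry points: JITFunction, TensorWrapper and reinterpret now live in triton.runtime.jit, compilation failures surface as triton.CompilationError, and a launched kernel still exposes its compiled assembly through pgm.asm["ptx"]. The copy kernel itself is illustrative only.

# Minimal sketch, assuming only the renames visible in this diff.
# The copy kernel is illustrative, not code taken from the PR.
import torch

import triton
import triton.language as tl
from triton.runtime.jit import JITFunction  # new path; was: from triton.code_gen import JITFunction


@triton.jit
def _copy(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + offsets, mask=offsets < N)
    tl.store(dst + offsets, x, mask=offsets < N)


src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda')
# First launch compiles and populates the cache; repeat launches with the same
# specialization are meant to hit the new, leaner caching path.
pgm = _copy[(1,)](dst, src, N=16, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]  # compiled PTX is still reachable from the launched program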
Authored by Philippe Tillet on 2022-09-18 08:51:48 -07:00, committed by GitHub
parent 889d9e34a1, commit 4a77dfb042
17 changed files with 1198 additions and 780 deletions

View File

@@ -11,7 +11,7 @@ from numpy.random import RandomState
 import triton
 import triton._C.libtriton.triton as _triton
 import triton.language as tl
-from triton.code_gen import JITFunction, TensorWrapper, reinterpret
+from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
 
 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -273,7 +273,7 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     elif (op in ('%', '/') and
           ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
            (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
-        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+        with pytest.raises(triton.CompilationError) as exc_info:
             _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
         assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
     else:
@@ -311,7 +311,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     else:
         numpy_expr = None
     if 'float' in dtype_x + dtype_y:
-        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+        with pytest.raises(triton.CompilationError) as exc_info:
             _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
         # The CompilationError must have been caused by a C++ exception with this text.
         assert re.match('invalid operands of type', str(exc_info.value.__cause__))
@@ -500,7 +500,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
     def catch_compilation_error(kernel):
         try:
             kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0])
-        except triton.code_gen.CompilationError as e:
+        except triton.CompilationError as e:
             np.testing.assert_(True)
         except BaseException:
             np.testing.assert_(False)
@@ -1209,7 +1209,7 @@ def test_load_cache_modifier(cache):
         assert 'ld.global.cg' not in ptx
 
 
-@pytest.mark.parametrize("N", [8, 10, 11, 1024])
+@pytest.mark.parametrize("N", [16, 10, 11, 1024])
 def test_vectorization(N):
     src = torch.empty(1024, device='cuda')
     dst = torch.empty(1024, device='cuda')
@@ -1221,10 +1221,8 @@ def test_vectorization(N):
         tl.store(dst + offsets, x, mask=offsets < N)
     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
     ptx = pgm.asm["ptx"]
-    if N % 4 == 0:
+    if N % 16 == 0:
         assert "ld.global.v4.b32" in ptx
-    elif N % 2 == 0:
-        assert "ld.global.v2.b32" in ptx
     else:
         assert "ld.global.b32" in ptx
     # triton.testing.assert_almost_equal(dst, src[:N])
@@ -1292,7 +1290,7 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
     def cache_hook(*args, **kwargs):
         nonlocal spec_type
-        spec_type = kwargs["compile"]["arg_types"][0][1]
+        spec_type = kwargs["compile"]["signature"][0]
     JITFunction.cache_hook = cache_hook
 
     @triton.jit
@@ -1319,7 +1317,7 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
     x = torch.tensor([3.14159], device='cuda')
 
     if overflow:
-        with pytest.raises(RuntimeError, match='integer overflow'):
+        with pytest.raises(OverflowError):
             kernel[(1, )](value, x)
     else:
         kernel[(1, )](value, x)
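Note on the two hunks above: the JITFunction.cache_hook payload now carries a "compile" dict whose "signature" maps argument index to a type string (replacing the old "arg_types" tuples), and scalar arguments that do not fit in 64 bits raise Python's built-in OverflowError. A hedged sketch of a hook using only the keys these tests exercise:

# Sketch of a cache hook under the new payload layout; only the keys used by
# the tests above ("compile" -> "signature") are assumed here.
from triton.runtime.jit import JITFunction

def log_arg0_type(*args, **kwargs):
    # e.g. 'i32' or 'u64' for a specialized integer argument, per the tests
    print("arg 0 specialized as:", kwargs["compile"]["signature"][0])

JITFunction.cache_hook = log_arg0_type
# ... launch any @triton.jit kernel here; the hook fires when it gets compiled ...
JITFunction.cache_hook = None  # detach the hook afterwards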

View File

@@ -78,10 +78,9 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
     configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
     kernel = triton.ops._matmul.kernel
-    decorators = kernel.kernel_decorators
-    kernel.kernel_decorators = []
-    triton.autotune(configs, [])(kernel)
-    kernel.kernel_decorators += decorators[1:]
+    kernel.configs = configs
+    # kernel.run = kernel.run.run.run
     # get matrix shape
     M = BLOCK_M if M is None else M
     N = BLOCK_N if N is None else N
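The matmul test above also shows the simpler autotuner surface: instead of rebuilding the kernel_decorators stack, the config list is pinned by assigning kernel.configs directly. A hedged example mirroring the replacement lines (the concrete kwargs values are illustrative, not from the diff):

# Sketch based on the new lines above; triton.Config and
# triton.ops._matmul.kernel come from the diff, the kwargs values do not.
import triton

kwargs = {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}
configs = [triton.Config(kwargs=kwargs, num_warps=4, num_stages=2, pre_hook=None)]

kernel = triton.ops._matmul.kernel
kernel.configs = configs  # the autotuner now only considers this single config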

View File

@@ -7,7 +7,7 @@ import torch
 import triton
 import triton.language as tl
-from triton.code_gen import JITFunction
+from triton.runtime.jit import JITFunction
 
 tmpdir = ".tmp"
@@ -99,16 +99,16 @@ def test_specialize(mode):
     reset_tmp_dir()
     x = torch.empty(1, dtype=torch.int32, device='cuda')
     function = {'enable': kernel, 'disable': kernel_nospec}[mode]
-    target = {'enable': 5, 'disable': 1}[mode]
+    target = {'enable': 3, 'disable': 1}[mode]
     for i in [1, 2, 4, 8, 16, 32]:
         function[(1,)](x, i, BLOCK=512)
     assert counter == target
 
 
 @pytest.mark.parametrize("value, value_type", [
-    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
-    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
-    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
+    (-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
 ])
 def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
     def get_cache_str(*args, **kwargs):
         nonlocal cache_str
-        cache_str = kwargs['key'].split('-')
-    triton.code_gen.JITFunction.cache_hook = get_cache_str
+        cache_str = kwargs["repr"]
+    triton.JITFunction.cache_hook = get_cache_str
     reset_tmp_dir()
     x = torch.tensor([3.14159], device='cuda')
     kernel[(1, )](value, x)
-    triton.code_gen.JITFunction.cache_hook = None
+    triton.JITFunction.cache_hook = None
 
-    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
+    cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
     spec_type = None if cache_str_match is None else cache_str_match.group(1)
     assert spec_type == value_type
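The final hunk switches the hook from parsing the raw cache key to matching the kernel repr; per the regex above, the repr is assumed to contain "VALUE: <type>" for the specialized scalar argument. A small helper built only from that regex (the example repr string is hypothetical):

# Helper reusing the exact regex and "repr" key from the test above;
# the example repr string below is made up for illustration.
import re

def spec_type_from_repr(repr_str):
    m = re.match(r".*VALUE: (\w+).*", repr_str)
    return None if m is None else m.group(1)

print(spec_type_from_repr("kernel[... VALUE: u32, X: fp32* ...]"))  # -> 'u32'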