[FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and clearly separate the compilation step from the just-in-time caching logic. This should substantially reduce launch overhead.
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import JITFunction
|
||||
from triton.runtime.jit import JITFunction
|
||||
|
||||
tmpdir = ".tmp"
|
||||
|
||||
@@ -99,16 +99,16 @@ def test_specialize(mode):
|
||||
reset_tmp_dir()
|
||||
x = torch.empty(1, dtype=torch.int32, device='cuda')
|
||||
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
|
||||
target = {'enable': 5, 'disable': 1}[mode]
|
||||
target = {'enable': 3, 'disable': 1}[mode]
|
||||
for i in [1, 2, 4, 8, 16, 32]:
|
||||
function[(1,)](x, i, BLOCK=512)
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
|
||||
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
|
||||
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
|
||||
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
@@ -120,14 +120,14 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
|
||||
def get_cache_str(*args, **kwargs):
|
||||
nonlocal cache_str
|
||||
cache_str = kwargs['key'].split('-')
|
||||
triton.code_gen.JITFunction.cache_hook = get_cache_str
|
||||
cache_str = kwargs["repr"]
|
||||
triton.JITFunction.cache_hook = get_cache_str
|
||||
reset_tmp_dir()
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
kernel[(1, )](value, x)
|
||||
triton.code_gen.JITFunction.cache_hook = None
|
||||
triton.JITFunction.cache_hook = None
|
||||
|
||||
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
|
||||
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
|
||||
spec_type = None if cache_str_match is None else cache_str_match.group(1)
|
||||
assert spec_type == value_type
|
||||
|
||||
|
Reference in New Issue
Block a user