[FRONTEND] Semantic analysis refactor (#491)
- Moved dispatch.cc to semantic.py (@ptillet)
- Integer signedness analysis was moved from C++ to Python (@daadaada)
- Cleaner frontend types (@daadaada)
- Moved SSA construction to a separate object (@ptillet)

Co-authored-by: Yan Da <dyanab@connect.ust.hk>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import copy
|
||||
import itertools
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
@@ -12,7 +11,7 @@ from numpy.random import RandomState
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.code_gen import TensorWrapper, reinterpret
|
||||
from triton.code_gen import JITFunction, TensorWrapper, reinterpret
|
||||
|
||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
|
||||
@@ -993,11 +992,17 @@ def test_noop(device='cuda'):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
spec_type = None
|
||||
|
||||
def cache_hook(*args, **kwargs):
|
||||
nonlocal spec_type
|
||||
spec_type = kwargs["compile"]["arg_types"][0][1]
|
||||
JITFunction.cache_hook = cache_hook
|
||||
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
@@ -1006,11 +1011,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
pgm = kernel[(1, )](value, x)
|
||||
|
||||
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
||||
triton_ir = pgm.asm['ttir']
|
||||
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
|
||||
assert ir_value_type == value_type
|
||||
JITFunction.cache_hook = None
|
||||
assert spec_type == value_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -1045,13 +1047,13 @@ def stub(X, alpha, grid_0, grid_1, grid_2):
|
||||
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
|
||||
|
||||
|
||||
def test_dyn_par(cond=True, device='cuda'):
|
||||
n_pids = 10
|
||||
# pids = torch.arange(n_pids, device=device)
|
||||
# alpha = 2.0
|
||||
# x_ref = pids * alpha
|
||||
x_tri = torch.full((10,), fill_value=-1., device=device)
|
||||
# cond = torch.tensor([cond], device=device)
|
||||
stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
|
||||
print(x_tri)
|
||||
# triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||
# def test_dyn_par(cond=True, device='cuda'):
|
||||
# n_pids = 10
|
||||
# # pids = torch.arange(n_pids, device=device)
|
||||
# # alpha = 2.0
|
||||
# # x_ref = pids * alpha
|
||||
# x_tri = torch.full((10,), fill_value=-1., device=device)
|
||||
# # cond = torch.tensor([cond], device=device)
|
||||
# stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
|
||||
# print(x_tri)
|
||||
# # triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
@@ -102,3 +103,30 @@ def test_specialize(mode):
|
||||
for i in [1, 2, 4, 8, 16, 32]:
|
||||
function[(1,)](x, i, BLOCK=512)
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
    """Check the type tag recorded in the cache key for an integer argument.

    A no-op kernel is launched with ``value`` as its first argument while a
    cache hook captures the ``key`` keyword it is called with. The last
    ``-``-separated field of that key is matched against
    ``_<type>[multipleof(N)]_float32*[multipleof(16)]`` and ``<type>`` must
    equal ``value_type``. When the field does not match the pattern the
    extracted type is ``None``, which is what the ``(1, None)`` case expects
    (presumably because the value 1 is specialized differently -- confirm
    against the JIT's key construction).

    Args:
        value: Python integer passed as the kernel's ``VALUE`` argument.
        value_type: Expected type tag in the cache key, or ``None`` when no
            tag is expected to match.
        device: Device on which the dummy tensor argument is allocated.
    """

    @triton.jit
    def kernel(VALUE, X):
        pass

    cache_str = None

    def get_cache_str(*args, **kwargs):
        # Record the cache key, split into its '-'-separated fields.
        nonlocal cache_str
        cache_str = kwargs['key'].split('-')

    triton.code_gen.JITFunction.cache_hook = get_cache_str
    try:
        reset_tmp_dir()
        x = torch.tensor([3.14159], device=device)
        kernel[(1, )](value, x)
    finally:
        # The hook is class-level state: always restore it so a failure in
        # this test cannot leak the hook into subsequent tests.
        triton.code_gen.JITFunction.cache_hook = None

    # Fail with a clear message instead of a TypeError on `None[-1]`.
    assert cache_str is not None, "cache hook was never invoked"

    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
    spec_type = None if cache_str_match is None else cache_str_match.group(1)
    assert spec_type == value_type
|
||||
|
Reference in New Issue
Block a user