[FRONTEND] Semantic analysis refactor (#491)

Moved dispatch.cc to semantic.py (@ptillet)
Moved integer signedness analysis from C++ to Python (@daadaada); see the sketch below
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
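
For readers skimming the diff, here is a rough idea of what "integer signedness analysis in Python" means in practice. The sketch below is hypothetical: the names (IntType, promote) and the promotion rule are illustrative only, not the actual semantic.py API introduced by this commit.

# Hypothetical illustration (not the actual triton/language/semantic.py code):
# the kind of signedness-aware type promotion that a Python-side semantic
# analysis performs when choosing the result type of a binary op.
from dataclasses import dataclass


@dataclass(frozen=True)
class IntType:
    bitwidth: int   # 8, 16, 32, or 64
    signed: bool    # True for intN, False for uintN

    @property
    def name(self) -> str:
        return ('int' if self.signed else 'uint') + str(self.bitwidth)


def promote(a: IntType, b: IntType) -> IntType:
    # Widen to the larger bitwidth; if either side is unsigned, the result is
    # unsigned (an illustrative rule, not necessarily Triton's exact one).
    return IntType(max(a.bitwidth, b.bitwidth), a.signed and b.signed)


assert promote(IntType(32, True), IntType(64, True)).name == 'int64'
assert promote(IntType(32, True), IntType(32, False)).name == 'uint32'
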
Author: Philippe Tillet
Date: 2022-04-06 16:13:53 -07:00
Committed by: GitHub
Commit: 9f08ecd684 (parent: 2bed6fc850)
19 changed files with 2174 additions and 1745 deletions


@@ -1,5 +1,4 @@
# flake8: noqa: F821,F841
import copy
import itertools
import re
from typing import Optional, Union
@@ -12,7 +11,7 @@ from numpy.random import RandomState
 import triton
 import triton._C.libtriton.triton as _triton
 import triton.language as tl
-from triton.code_gen import TensorWrapper, reinterpret
+from triton.code_gen import JITFunction, TensorWrapper, reinterpret
 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -993,11 +992,17 @@ def test_noop(device='cuda'):
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
spec_type = None
def cache_hook(*args, **kwargs):
nonlocal spec_type
spec_type = kwargs["compile"]["arg_types"][0][1]
JITFunction.cache_hook = cache_hook
@triton.jit
def kernel(VALUE, X):
@@ -1006,11 +1011,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
     x = torch.tensor([3.14159], device='cuda')
     pgm = kernel[(1, )](value, x)
-    # Parse out the type of the 'VALUE' parameter from the Triton IR.
-    triton_ir = pgm.asm['ttir']
-    ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
-    ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
-    assert ir_value_type == value_type
+    JITFunction.cache_hook = None
+    assert spec_type == value_type
 @pytest.mark.parametrize(
@@ -1045,13 +1047,13 @@ def stub(X, alpha, grid_0, grid_1, grid_2):
     tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
-def test_dyn_par(cond=True, device='cuda'):
-    n_pids = 10
-    # pids = torch.arange(n_pids, device=device)
-    # alpha = 2.0
-    # x_ref = pids * alpha
-    x_tri = torch.full((10,), fill_value=-1., device=device)
-    # cond = torch.tensor([cond], device=device)
-    stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
-    print(x_tri)
-    # triton.testing.assert_almost_equal(x_ref, x_tri)
+# def test_dyn_par(cond=True, device='cuda'):
+# n_pids = 10
+# # pids = torch.arange(n_pids, device=device)
+# # alpha = 2.0
+# # x_ref = pids * alpha
+# x_tri = torch.full((10,), fill_value=-1., device=device)
+# # cond = torch.tensor([cond], device=device)
+# stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
+# print(x_tri)
+# # triton.testing.assert_almost_equal(x_ref, x_tri)


@@ -1,4 +1,5 @@
 import os
+import re
 import shutil
 import pytest
@@ -102,3 +103,30 @@ def test_specialize(mode):
     for i in [1, 2, 4, 8, 16, 32]:
         function[(1,)](x, i, BLOCK=512)
     assert counter == target
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
+    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
+    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
+    cache_str = None
+    def get_cache_str(*args, **kwargs):
+        nonlocal cache_str
+        cache_str = kwargs['key'].split('-')
+    triton.code_gen.JITFunction.cache_hook = get_cache_str
+    reset_tmp_dir()
+    x = torch.tensor([3.14159], device='cuda')
+    kernel[(1, )](value, x)
+    triton.code_gen.JITFunction.cache_hook = None
+    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
+    spec_type = None if cache_str_match is None else cache_str_match.group(1)
+    assert spec_type == value_type
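
Both updated tests check value specialization through JITFunction.cache_hook instead of regex-matching the generated Triton IR. As a standalone reference, here is a minimal usage sketch of that hook; it assumes the Triton version of this commit, and the hook kwargs it reads ('key', and 'compile' with 'arg_types') are taken from the tests above and may differ in other versions.

import torch
import triton
from triton.code_gen import JITFunction


@triton.jit
def kernel(VALUE, X):
    pass


captured = {}

def cache_hook(*args, **kwargs):
    # Record what the JIT reports about this compilation: the cache key and
    # the argument types it specialized on.
    captured['key'] = kwargs['key']
    captured['arg_types'] = kwargs['compile']['arg_types']

JITFunction.cache_hook = cache_hook
x = torch.tensor([3.14159], device='cuda')
kernel[(1,)](2**40, x)   # a value this large should be specialized as a 64-bit integer
JITFunction.cache_hook = None
print(captured['key'])
print(captured['arg_types'])
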