Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)

This reverts commit 539961072c.
Author: Philippe Tillet
Date: 2022-03-24 17:16:50 -07:00
Committed by: GitHub
Parent: ea6d1f1b85
Commit: 76a9ee50a8
19 changed files with 1670 additions and 2044 deletions


@@ -1,4 +1,5 @@
# flake8: noqa: F821,F841
import copy
import itertools
import re
from typing import Optional, Union
@@ -584,6 +585,7 @@ def test_f8_f16_roundtrip():
    f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    print(f16.dtype, f8_output.dtype)
    copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
    assert torch.all(f8_tensor == f8_output_tensor)
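
The hunk above invokes copy_kernel, which is defined earlier in the same test file and falls outside the diff context. For orientation, a minimal sketch of what such an element-wise copy kernel looks like in Triton; the name and signature are assumptions matching the call site, not the file's actual definition:

import triton
import triton.language as tl


@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance copies one BLOCK_SIZE-wide slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail block
    x = tl.load(input_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x, mask=mask)

Because f8_output is the output tensor reinterpreted as tl.float8, the idea is that the store converts the loaded f16 values back to float8 on the way out, which is what the roundtrip assertion checks.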
@@ -991,6 +993,27 @@ def test_noop(device='cuda'):
    kernel[(1, )](x)


@pytest.mark.parametrize("value, value_type", [
    (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

    @triton.jit
    def kernel(VALUE, X):
        pass

    x = torch.tensor([3.14159], device='cuda')
    pgm = kernel[(1, )](value, x)

    # Parse out the type of the 'VALUE' parameter from the Triton IR.
    triton_ir = pgm.asm['ttir']
    ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
    ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
    assert ir_value_type == value_type


@pytest.mark.parametrize(
    "value, overflow",
    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]

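The value/value_type parametrization above encodes the specialization rule the test expects the frontend to apply to scalar integer arguments: the value 1 is specialized away entirely (hence None), and every other integer gets the narrowest of i32, u32, i64, u64 that can represent it. A standalone sketch of that rule, inferred from the test table rather than taken from Triton's source:

def expected_value_type(value: int):
    # 1 is specialized to a compile-time constant, so no argument type
    # shows up in the IR for it.
    if value == 1:
        return None
    if -2**31 <= value < 2**31:
        return 'i32'
    if 2**31 <= value < 2**32:
        return 'u32'
    if -2**63 <= value < 2**63:
        return 'i64'
    return 'u64'

For example, expected_value_type(2**31) returns 'u32' and expected_value_type(-2**63) returns 'i64', matching the table; values outside the 64-bit range are exercised separately by the overflow parametrization at which the hunk is truncated.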

@@ -1,5 +1,4 @@
import os
import re
import shutil
import pytest
@@ -103,30 +102,3 @@ def test_specialize(mode):
    for i in [1, 2, 4, 8, 16, 32]:
        function[(1,)](x, i, BLOCK=512)
    assert counter == target


@pytest.mark.parametrize("value, value_type", [
    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

    @triton.jit
    def kernel(VALUE, X):
        pass

    cache_str = None

    def get_cache_str(*args, **kwargs):
        nonlocal cache_str
        cache_str = kwargs['key'].split('-')

    triton.code_gen.JITFunction.cache_hook = get_cache_str
    reset_tmp_dir()
    x = torch.tensor([3.14159], device='cuda')
    kernel[(1, )](value, x)
    triton.code_gen.JITFunction.cache_hook = None

    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
    spec_type = None if cache_str_match is None else cache_str_match.group(1)
    assert spec_type == value_type
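
The removed test drives the same check through JITFunction.cache_hook, which Triton calls with keyword arguments that include the compilation cache key; the key's segments record how each argument was specialized. A hedged sketch of using the hook purely to inspect keys (the kernel, value, and tensor here are placeholders, not part of the diff):

import torch
import triton


@triton.jit
def kernel(VALUE, X):
    pass


def log_cache_key(*args, **kwargs):
    # 'key' encodes the specialization of each argument
    # (integer dtype, multiple-of hints, pointer dtype, ...).
    print(kwargs['key'])


triton.code_gen.JITFunction.cache_hook = log_cache_key
kernel[(1,)](42, torch.tensor([0.0], device='cuda'))
triton.code_gen.JITFunction.cache_hook = None

The regex in the removed test takes the last '-'-separated segment of that key and extracts the integer specialization type (e.g. int32 vs. uint32) from it.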