[Frontend] Convert constexpr to value for store and load ops (#1030)

Fixing problem 2 in https://github.com/openai/triton/issues/1017

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Keren Zhou
2023-01-05 14:40:16 -05:00
committed by GitHub
parent 2193bee94e
commit 4023149ee3
2 changed files with 9 additions and 9 deletions

View File

@ -1267,7 +1267,7 @@ def test_arange(start, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@ -1286,18 +1286,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
x = GENERATE_TEST_HERE
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
_kernel[(1,)](input, output, input_size, output_size)
mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
kernel[(1,)](input, output, input_size, output_size)
reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.

View File

@ -830,9 +830,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
'type cache_modifier: str, optional
"""
# mask, other can be constexpr
if mask is not None:
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
if other is not None:
if _constexpr_to_value(other) is not None:
other = _to_tensor(other, _builder)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
@ -856,7 +856,7 @@ def store(pointer, value, mask=None, _builder=None):
"""
# value can be constexpr
value = _to_tensor(value, _builder)
if mask is not None:
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
return semantic.store(pointer, value, mask, _builder)