[Frontend] Convert constexpr to value for store and load ops (#1030)
Fixing problem 2 in https://github.com/openai/triton/issues/1017

Co-authored-by: Philippe Tillet <phil@openai.com>
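The failure this commit targets shows up when a kernel passes a literal None (or another constexpr) as the mask/other argument of tl.load or tl.store: inside the Triton frontend such literals arrive wrapped as tl.constexpr, so the old `if mask is not None:` check did not treat them as "no mask". A minimal sketch of the pattern, not taken from the commit (the kernel and tensor names are illustrative); it mirrors the size_diff == 0 case added to the test below:

import torch
import triton
import triton.language as tl

@triton.jit
def _copy(in_ptr, out_ptr, n: tl.constexpr):
    offsets = tl.arange(0, n)
    # The literal None reaches the frontend as constexpr(None); before this
    # commit the load below failed because the mask was never unwrapped.
    x = tl.load(in_ptr + offsets, None)
    tl.store(out_ptr + offsets, x)

x = torch.arange(16, device='cuda', dtype=torch.float32)
y = torch.empty_like(x)
_copy[(1,)](x, y, 16)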
@@ -1267,7 +1267,7 @@ def test_arange(start, device='cuda'):
 # ---------------
 
 
-@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
+@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
 def test_masked_load(dtype_str, size, size_diff, device='cuda'):
     dtype = getattr(torch, dtype_str)
     check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
@@ -1286,18 +1286,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
     def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
         in_offsets = tl.arange(0, out_size)
         # Load inputs.
-        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
+        x = GENERATE_TEST_HERE
         # Store output
         output_offsets = tl.arange(0, out_size)
         tl.store(out_ptr + output_offsets, x)
 
-    _kernel[(1,)](input, output, input_size, output_size)
+    mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
+    kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
+    kernel[(1,)](input, output, input_size, output_size)
 
-    reference_out = input
-    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
+    reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
     triton.testing.allclose(output, reference_out)
 
-# 'bfloat16': torch.bfloat16,
 # Testing masked loads with an intermate copy to shared memory run.
 
 
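The rewritten test builds the load expression as text and splices it into the kernel through the GENERATE_TEST_HERE placeholder, so one test body covers both the masked variant (size_diff > 0) and the unmasked, literal-None variant (size_diff == 0). patch_kernel is an existing helper in test_core.py; the sketch below of what it presumably does (textual substitution on the jitted function's captured source) is an assumption, not part of this commit:

import copy

def patch_kernel(template, to_replace):
    # Presumed behavior of the test helper: copy the jitted function and
    # replace placeholder strings in its source before compilation.
    kernel = copy.deepcopy(template)
    for key, value in to_replace.items():
        kernel.src = kernel.src.replace(key, value)
    return kernel

For size_diff == 0 the placeholder therefore expands to x = tl.load(in_ptr + in_offsets, None), which is exactly the constexpr-None case the core.py changes below must accept.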
@@ -830,9 +830,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
     :type cache_modifier: str, optional
     """
     # mask, other can be constexpr
-    if mask is not None:
+    if _constexpr_to_value(mask) is not None:
         mask = _to_tensor(mask, _builder)
-    if other is not None:
+    if _constexpr_to_value(other) is not None:
         other = _to_tensor(other, _builder)
     cache_modifier = _constexpr_to_value(cache_modifier)
     eviction_policy = _constexpr_to_value(eviction_policy)
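The load() change swaps the plain None check for one that first unwraps a possible constexpr wrapper. _constexpr_to_value is already used a few lines below for cache_modifier and eviction_policy; its definition is not shown in this diff, so the following is only a sketch of its presumed behavior and of why the old check misfired:

import triton.language as tl

def _constexpr_to_value(v):
    # Presumed implementation: unwrap tl.constexpr, otherwise pass through.
    if isinstance(v, tl.constexpr):
        return v.value
    return v

mask = tl.constexpr(None)                     # what a literal None becomes in the frontend
print(mask is not None)                       # True  -> old check kept the "mask" and failed later
print(_constexpr_to_value(mask) is not None)  # False -> new check treats it as no mask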
@@ -856,7 +856,7 @@ def store(pointer, value, mask=None, _builder=None):
     """
     # value can be constexpr
     value = _to_tensor(value, _builder)
-    if mask is not None:
+    if _constexpr_to_value(mask) is not None:
         mask = _to_tensor(mask, _builder)
     return semantic.store(pointer, value, mask, _builder)
 
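store() gets the same treatment for its mask; the value argument was already routed through _to_tensor, which handles constexpr scalars. With both changes in place a kernel can pass literals for value and mask directly. An illustrative example, not taken from the commit (kernel and tensor names are made up):

import torch
import triton
import triton.language as tl

@triton.jit
def _fill_ones(out_ptr, n: tl.constexpr):
    offsets = tl.arange(0, n)
    # 1.0 arrives as a constexpr and is converted by _to_tensor; the literal
    # None mask is now recognized as "no mask" instead of raising an error.
    tl.store(out_ptr + offsets, 1.0, None)

out = torch.zeros(8, device='cuda', dtype=torch.float32)
_fill_ones[(1,)](out, 8)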