diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 862687f6f..2d7b1d45e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1267,7 +1267,7 @@ def test_arange(start, device='cuda'): # --------------- -@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]]) +@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]]) def test_masked_load(dtype_str, size, size_diff, device='cuda'): dtype = getattr(torch, dtype_str) check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested @@ -1286,18 +1286,18 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'): def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): in_offsets = tl.arange(0, out_size) # Load inputs. - x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1) + x = GENERATE_TEST_HERE # Store output output_offsets = tl.arange(0, out_size) tl.store(out_ptr + output_offsets, x) - _kernel[(1,)](input, output, input_size, output_size) + mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1,)](input, output, input_size, output_size) - reference_out = input - reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device))) + reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device))) triton.testing.allclose(output, reference_out) -# 'bfloat16': torch.bfloat16, # Testing masked loads with an intermate copy to shared memory run. diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 9750c2237..e4a350b96 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -830,9 +830,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", 'type cache_modifier: str, optional """ # mask, other can be constexpr - if mask is not None: + if _constexpr_to_value(mask) is not None: mask = _to_tensor(mask, _builder) - if other is not None: + if _constexpr_to_value(other) is not None: other = _to_tensor(other, _builder) cache_modifier = _constexpr_to_value(cache_modifier) eviction_policy = _constexpr_to_value(eviction_policy) @@ -856,7 +856,7 @@ def store(pointer, value, mask=None, _builder=None): """ # value can be constexpr value = _to_tensor(value, _builder) - if mask is not None: + if _constexpr_to_value(mask) is not None: mask = _to_tensor(mask, _builder) return semantic.store(pointer, value, mask, _builder)