diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 93063d064..2eadf34a2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1363,6 +1363,20 @@ def test_constexpr_shape(): kernel[(1,)](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + +def test_constexpr_scalar_shape(): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32)) + kernel[(1,)](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + # ------------- # test if # ------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 4197a3333..fdf9063a7 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -472,6 +472,11 @@ class tensor: other = _to_tensor(other, _builder) return semantic.floordiv(self, other, _builder) + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + @builtin def __mod__(self, other, _builder=None): other = _to_tensor(other, _builder)