[FRONTEND] Add missing rfloordiv (#598)
* [FRONTEND] Add missing rfloordiv * fix tests
This commit is contained in:
@@ -1363,6 +1363,20 @@ def test_constexpr_shape():
|
|||||||
kernel[(1,)](x_tri)
|
kernel[(1,)](x_tri)
|
||||||
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
|
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
|
||||||
|
|
||||||
|
|
||||||
|
def test_constexpr_scalar_shape():
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def kernel(X, s):
|
||||||
|
off = tl.arange(0, 256)
|
||||||
|
val = off % (256 // s)
|
||||||
|
tl.store(X + off, val)
|
||||||
|
|
||||||
|
x_tri = to_triton(np.empty((256, ), dtype=np.int32))
|
||||||
|
kernel[(1,)](x_tri, 32)
|
||||||
|
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
|
||||||
|
|
||||||
|
|
||||||
# -------------
|
# -------------
|
||||||
# test if
|
# test if
|
||||||
# -------------
|
# -------------
|
||||||
|
@@ -472,6 +472,11 @@ class tensor:
|
|||||||
other = _to_tensor(other, _builder)
|
other = _to_tensor(other, _builder)
|
||||||
return semantic.floordiv(self, other, _builder)
|
return semantic.floordiv(self, other, _builder)
|
||||||
|
|
||||||
|
@builtin
|
||||||
|
def __rfloordiv__(self, other, _builder=None):
|
||||||
|
other = _to_tensor(other, _builder)
|
||||||
|
return semantic.floordiv(other, self, _builder)
|
||||||
|
|
||||||
@builtin
|
@builtin
|
||||||
def __mod__(self, other, _builder=None):
|
def __mod__(self, other, _builder=None):
|
||||||
other = _to_tensor(other, _builder)
|
other = _to_tensor(other, _builder)
|
||||||
|
Reference in New Issue
Block a user