[Triton-MLIR][BACKEND] Mark float to integer in Arithmetic Dialect as legal (#963)
This commit is contained in:
@@ -114,6 +114,7 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
||||||
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
||||||
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
|
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
|
||||||
|
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
|
||||||
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -677,6 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
|||||||
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
|
||||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
def test_atomic_cas():
|
def test_atomic_cas():
|
||||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -742,9 +743,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
|
|||||||
# triton kernel
|
# triton kernel
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def kernel(X, Z, BITCAST: tl.constexpr):
|
def kernel(X, Z, BITCAST: tl.constexpr):
|
||||||
x = tl.load(X)
|
x_ptr = X + tl.arange(0, 1)
|
||||||
|
z_ptr = Z + tl.arange(0, 1)
|
||||||
|
x = tl.load(x_ptr)
|
||||||
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
|
z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
|
||||||
tl.store(Z, z)
|
tl.store(z_ptr, z)
|
||||||
|
|
||||||
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
|
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
|
||||||
# triton result
|
# triton result
|
||||||
|
Reference in New Issue
Block a user