[Triton-MLIR][BACKEND] Mark float to integer in Arithmetic Dialect as legal (#963)

This commit is contained in:
Keren Zhou
2022-12-08 09:07:01 -08:00
committed by GitHub
parent c7cf9c6a32
commit 71c35bcf9c
2 changed files with 6 additions and 2 deletions

View File

@@ -114,6 +114,7 @@ void populateArithmeticPatternsAndLegality(
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>, GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>, GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>, GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context); GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
} }

View File

@@ -677,6 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
kernel[(1,)](z_tri, x_tri, axis, shape0, shape1) kernel[(1,)](z_tri, x_tri, axis, shape0, shape1)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
def test_atomic_cas(): def test_atomic_cas():
# 1. make sure that atomic_cas changes the original value (Lock) # 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit @triton.jit
@@ -742,9 +743,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# triton kernel # triton kernel
@triton.jit @triton.jit
def kernel(X, Z, BITCAST: tl.constexpr): def kernel(X, Z, BITCAST: tl.constexpr):
x = tl.load(X) x_ptr = X + tl.arange(0, 1)
z_ptr = Z + tl.arange(0, 1)
x = tl.load(x_ptr)
z = x.to(Z.dtype.element_ty, bitcast=BITCAST) z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
tl.store(Z, z) tl.store(z_ptr, z)
dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
# triton result # triton result