From 71c35bcf9c181d8655464b003e3a4f1240ce3f80 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 8 Dec 2022 09:07:01 -0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Mark float to integer in Arithmetic Dialect as legal (#963) --- lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp | 1 + python/tests/test_core.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 42150c362..9bf9ecb05 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -114,6 +114,7 @@ void populateArithmeticPatternsAndLegality( GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern>(typeConverter, context); } diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 68f78272b..abd951055 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -677,6 +677,7 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"): kernel[(1,)](z_tri, x_tri, axis, shape0, shape1) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + def test_atomic_cas(): # 1. make sure that atomic_cas changes the original value (Lock) @triton.jit @@ -742,9 +743,11 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # triton kernel @triton.jit def kernel(X, Z, BITCAST: tl.constexpr): - x = tl.load(X) + x_ptr = X + tl.arange(0, 1) + z_ptr = Z + tl.arange(0, 1) + x = tl.load(x_ptr) z = x.to(Z.dtype.element_ty, bitcast=BITCAST) - tl.store(Z, z) + tl.store(z_ptr, z) dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' # triton result