From 7dfab26a3973c26b083f4e6802e34b7bf7b4c4c8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 28 Oct 2022 23:34:14 -0700 Subject: [PATCH] [FRONTEND][BACKEND] Fixed various bugs (#819) - Fixed bugs on layout conversions for int1 data (we should use int8 internally for int1 data to prevent llvm from using vec which has different semantics) - Fixed semantics of some casts to bool in the frontend --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 27 ++++++-- python/src/triton.cc | 13 ++-- python/tests/test_core.py | 65 +++++++++---------- python/triton/compiler.py | 1 + python/triton/language/semantic.py | 20 ++++-- 5 files changed, 74 insertions(+), 52 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 22c5ab720..d1fcdb92f 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -76,6 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, } // namespace // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +#define zext(...) rewriter.create(loc, __VA_ARGS__) #define udiv(...) rewriter.create(loc, __VA_ARGS__) #define urem(...) rewriter.create(loc, __VA_ARGS__) #define add(...) rewriter.create(loc, __VA_ARGS__) @@ -102,6 +103,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define store(val, ptr) rewriter.create(loc, val, ptr) #define icmp_eq(...) \ rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) +#define icmp_ne(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) #define icmp_slt(...) \ rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) #define select(...) rewriter.create(loc, __VA_ARGS__) @@ -109,6 +112,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define i32_ty rewriter.getIntegerType(32) +#define i8_ty rewriter.getIntegerType(8) #define f32_ty rewriter.getF32Type() #define vec_ty(type, num) VectorType::get(num, type) #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) @@ -2024,7 +2028,12 @@ void ConvertLayoutOpConversion::processReplica( for (unsigned d = 0; d < rank; ++d) { numCTAs[d] = ceil(type.getShape()[d], shapePerCTA[d]); } - auto llvmElemTy = getTypeConverter()->convertType(type.getElementType()); + auto elemTy = type.getElementType(); + bool isInt1 = elemTy.isInteger(1); + if (isInt1) + elemTy = IntegerType::get(elemTy.getContext(), 8); + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + SmallVector multiDimOffsetFirstElem; SmallVector mmaColIdx(2); SmallVector mmaRowIdx(2); @@ -2131,16 +2140,22 @@ void ConvertLayoutOpConversion::processReplica( if (stNotRd) { Value valVec = undef(vecTy); for (unsigned v = 0; v < vec; ++v) { - valVec = insert_element( - vecTy, valVec, - vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v)); + auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; + if (isInt1) + currVal = zext(llvmElemTy, currVal); + + valVec = insert_element(vecTy, valVec, currVal, idx_val(v)); } store(valVec, ptr); } else { Value valVec = load(ptr); for (unsigned v = 0; v < vec; ++v) { - vals[elemId + linearCTAId * accumSizePerThread + v] = - extract_element(llvmElemTy, valVec, idx_val(v)); + Value currVal = extract_element(llvmElemTy, valVec, idx_val(v)); + if (isInt1) + currVal = + icmp_ne(currVal, rewriter.create( + loc, i8_ty, rewriter.getI8IntegerAttr(0))); + vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; } } } diff --git a/python/src/triton.cc b/python/src/triton.cc index 243283006..66ae425e6 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -170,9 +170,8 @@ void init_triton_ir(py::module &&m) { .def("replace_all_uses_with", [](mlir::Value &self, mlir::Value &newValue) { self.replaceAllUsesWith(newValue); - }) + }); - ; py::class_(m, "block_arguement"); py::class_(m, "region") @@ -660,13 +659,13 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); // get element type if necessary mlir::Type srcType = src.getType(); + auto srcTensorType = srcType.dyn_cast(); + auto dstTensorType = dstType.dyn_cast(); mlir::Type srcEltType = srcType; mlir::Type dstEltType = dstType; - if (dstType.isa()) { - dstEltType = - dstType.cast().getElementType(); - srcEltType = - srcType.cast().getElementType(); + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); } unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 910520fc5..56b1f36db 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -411,41 +411,40 @@ def test_where(dtype): assert (z == to_numpy(z_tri)).all() -# TODO: wrong result -# def test_where_broadcast(): -# @triton.jit -# def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): -# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] -# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] +def test_where_broadcast(): + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] -# mask = tl.load(cond_ptr + yoffsets) -# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) -# res = tl.where(mask, vals, 0.) -# tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) -# @triton.jit -# def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): -# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] -# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] -# mask = 0 -# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) -# res = tl.where(mask, vals, 0.) -# tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = 0 + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) -# SIZE = 32 -# dtype = 'float32' -# rs = RandomState(17) -# x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) -# mask = numpy_random(SIZE, 'bool', rs=rs) -# z = np.where(mask, x, 0) -# cond_tri = to_triton(mask, device="cuda") -# x_tri = to_triton(x, device='cuda', dst_type=dtype) -# z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype) -# where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE) -# assert (z == to_numpy(z_tri)).all() -# where_scalar_condition[(1,)](x_tri, z_tri, SIZE) -# z = np.where(0, x, 0) -# assert (z == to_numpy(z_tri)).all() + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device="cuda") + x_tri = to_triton(x, device='cuda', dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype) + where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1,)](x_tri, z_tri, SIZE) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() # # --------------- # # test unary ops @@ -719,7 +718,7 @@ def test_tuples(): # ('bfloat16', 'float32', False), ('float32', 'int32', True), # TODO: - # ('float32', 'int1', False), + ('float32', 'int1', False), ] + [ (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] ] + [ diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 35512fabe..f684ff691 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -993,6 +993,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat module = optimize_tritongpu_ir(module, num_stages) if output == "ttgir": return module.str() + if extern_libs: add_external_libs(module, extern_libs) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 878e07a8f..748470957 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -649,19 +649,27 @@ def cast(input: tl.tensor, if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() - return tl.tensor(builder.create_int_cast(input.handle, - dst_ty.to_ir(builder), sign_extend), - dst_ty) + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, + dst_ty.to_ir(builder), sign_extend), + dst_ty) # Float to Int if src_sca_ty.is_floating() and dst_sca_ty.is_int(): - # TODO: is this correct? if dst_sca_ty.is_bool(): - return tl.tensor(builder.create_fp_to_ui(input.handle, + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) else: - return tl.tensor(builder.create_fp_to_si(input.handle, + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)