[FRONTEND][BACKEND] Fixed various bugs (#819)

- Fixed bugs in layout conversions for int1 data (int1 data should be
handled as int8 internally, to prevent LLVM from using vec<i1>, which has
different semantics)
- Fixed the semantics of some casts to bool in the frontend
Philippe Tillet, committed by GitHub
2022-10-28 23:34:14 -07:00
commit 7dfab26a39 (parent 82834d34f9)
5 changed files with 74 additions and 52 deletions
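The failure mode both fixes target is a boolean (int1) tensor crossing a
layout conversion, for example a mask fed to tl.where. A minimal kernel in
that shape (a sketch mirroring the test_where_broadcast test re-enabled
below, not new API):

import triton
import triton.language as tl

@triton.jit
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
    yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
    mask = tl.load(cond_ptr + yoffsets)  # int1 mask: hits the layout-conversion bug
    vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
    res = tl.where(mask, vals, 0.)
    tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)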


@@ -76,6 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 } // namespace
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
+#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -102,6 +103,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
 #define icmp_eq(...) \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
+#define icmp_ne(...) \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
 #define icmp_slt(...) \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
 #define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
@@ -109,6 +112,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
+#define i8_ty rewriter.getIntegerType(8)
 #define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -2024,7 +2028,12 @@ void ConvertLayoutOpConversion::processReplica(
   for (unsigned d = 0; d < rank; ++d) {
     numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
   }
-  auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
+  auto elemTy = type.getElementType();
+  bool isInt1 = elemTy.isInteger(1);
+  if (isInt1)
+    elemTy = IntegerType::get(elemTy.getContext(), 8);
+  auto llvmElemTy = getTypeConverter()->convertType(elemTy);
   SmallVector<Value> multiDimOffsetFirstElem;
   SmallVector<Value> mmaColIdx(2);
   SmallVector<Value> mmaRowIdx(2);
@@ -2131,16 +2140,22 @@ void ConvertLayoutOpConversion::processReplica(
       if (stNotRd) {
         Value valVec = undef(vecTy);
         for (unsigned v = 0; v < vec; ++v) {
-          valVec = insert_element(
-              vecTy, valVec,
-              vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
+          auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
+          if (isInt1)
+            currVal = zext(llvmElemTy, currVal);
+          valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
         }
         store(valVec, ptr);
       } else {
         Value valVec = load(ptr);
         for (unsigned v = 0; v < vec; ++v) {
-          vals[elemId + linearCTAId * accumSizePerThread + v] =
-              extract_element(llvmElemTy, valVec, idx_val(v));
+          Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
+          if (isInt1)
+            currVal =
+                icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
+                                     loc, i8_ty, rewriter.getI8IntegerAttr(0)));
+          vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
         }
       }
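The store/load paths above round-trip i1 values through i8: zext widens on
the way into the staging vector, and icmp_ne against 0 recovers the boolean
on the way out. A minimal Python sketch of that invariant (numpy stands in
for the generated LLVM ops; this is an illustration, not the lowering itself):

import numpy as np

bits = np.array([True, False, True])
stored = bits.astype(np.int8)      # store path: zext i1 -> i8 gives [1, 0, 1]
loaded = stored != np.int8(0)      # load path: icmp_ne recovers [True, False, True]
assert (loaded == bits).all()      # the round trip preserves the values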


@@ -170,9 +170,8 @@ void init_triton_ir(py::module &&m) {
       .def("replace_all_uses_with",
           [](mlir::Value &self, mlir::Value &newValue) {
             self.replaceAllUsesWith(newValue);
-          })
-      ;
+          });
 
  py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
  py::class_<mlir::Region>(m, "region")
@@ -660,13 +659,13 @@ void init_triton_ir(py::module &&m) {
            auto loc = self.getUnknownLoc();
            // get element type if necessary
            mlir::Type srcType = src.getType();
+           auto srcTensorType = srcType.dyn_cast<mlir::RankedTensorType>();
+           auto dstTensorType = dstType.dyn_cast<mlir::RankedTensorType>();
            mlir::Type srcEltType = srcType;
            mlir::Type dstEltType = dstType;
-           if (dstType.isa<mlir::RankedTensorType>()) {
-             dstEltType =
-                 dstType.cast<mlir::RankedTensorType>().getElementType();
-             srcEltType =
-                 srcType.cast<mlir::RankedTensorType>().getElementType();
+           if (dstTensorType && srcTensorType) {
+             dstEltType = dstTensorType.getElementType();
+             srcEltType = srcTensorType.getElementType();
            }
            unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
            unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();
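The dyn_cast rewrite above is also a latent-bug fix: the old code tested only
dstType before unconditionally casting srcType, which asserts when the source
is a scalar and the destination is a tensor. A hypothetical Python analogue
of the safer pattern (RankedTensorType here is a stand-in class for
illustration, not a real Python binding):

from dataclasses import dataclass

@dataclass
class RankedTensorType:  # stand-in for the MLIR type, illustration only
    element_type: str

def element_types(src_ty, dst_ty):
    # Probe both operands first (the dyn_cast pattern); scalars fall through.
    src_elt, dst_elt = src_ty, dst_ty
    if isinstance(src_ty, RankedTensorType) and isinstance(dst_ty, RankedTensorType):
        src_elt, dst_elt = src_ty.element_type, dst_ty.element_type
    return src_elt, dst_elt

# Scalar source with tensor destination no longer crashes; both pass through.
assert element_types("f32", RankedTensorType("i32")) == ("f32", RankedTensorType("i32"))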


@@ -411,41 +411,40 @@ def test_where(dtype):
     assert (z == to_numpy(z_tri)).all()
 
-# TODO: wrong result
-# def test_where_broadcast():
-#     @triton.jit
-#     def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
-#         xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
-#         yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
-#         mask = tl.load(cond_ptr + yoffsets)
-#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
-#         res = tl.where(mask, vals, 0.)
-#         tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
-#     @triton.jit
-#     def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
-#         xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
-#         yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
-#         mask = 0
-#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
-#         res = tl.where(mask, vals, 0.)
-#         tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
-#     SIZE = 32
-#     dtype = 'float32'
-#     rs = RandomState(17)
-#     x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
-#     mask = numpy_random(SIZE, 'bool', rs=rs)
-#     z = np.where(mask, x, 0)
-#     cond_tri = to_triton(mask, device="cuda")
-#     x_tri = to_triton(x, device='cuda', dst_type=dtype)
-#     z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
-#     where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
-#     assert (z == to_numpy(z_tri)).all()
-#     where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
-#     z = np.where(0, x, 0)
-#     assert (z == to_numpy(z_tri)).all()
+def test_where_broadcast():
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
+        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
+
+        mask = tl.load(cond_ptr + yoffsets)
+        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
+        res = tl.where(mask, vals, 0.)
+        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
+
+    @triton.jit
+    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
+        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
+        mask = 0
+        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
+        res = tl.where(mask, vals, 0.)
+        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
+
+    SIZE = 32
+    dtype = 'float32'
+    rs = RandomState(17)
+    x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
+    mask = numpy_random(SIZE, 'bool', rs=rs)
+    z = np.where(mask, x, 0)
+    cond_tri = to_triton(mask, device="cuda")
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
+    where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
+    assert (z == to_numpy(z_tri)).all()
+    where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
+    z = np.where(0, x, 0)
+    assert (z == to_numpy(z_tri)).all()
 
 # ---------------
 # test unary ops
@@ -719,7 +718,7 @@ def test_tuples():
     # ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
     # TODO:
-    # ('float32', 'int1', False),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [


@@ -993,6 +993,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
     module = optimize_tritongpu_ir(module, num_stages)
     if output == "ttgir":
         return module.str()
+
     if extern_libs:
         add_external_libs(module, extern_libs)


@@ -649,19 +649,27 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
             (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
         sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
-        return tl.tensor(builder.create_int_cast(input.handle,
-                                                 dst_ty.to_ir(builder), sign_extend),
-                         dst_ty)
+        if dst_sca_ty.is_bool():
+            ty = input.dtype.to_ir(builder)
+            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
+            return not_equal(input, _0, builder)
+        else:
+            return tl.tensor(builder.create_int_cast(input.handle,
+                                                     dst_ty.to_ir(builder), sign_extend),
+                             dst_ty)
 
     # Float to Int
     if src_sca_ty.is_floating() and dst_sca_ty.is_int():
-        # TODO: is this correct?
         if dst_sca_ty.is_bool():
-            return tl.tensor(builder.create_fp_to_ui(input.handle,
-                                                     dst_ty.to_ir(builder)),
-                             dst_ty)
+            ty = input.dtype.to_ir(builder)
+            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
+            return not_equal(input, _0, builder)
+        elif dst_sca_ty.is_int_signed():
+            return tl.tensor(builder.create_fp_to_si(input.handle,
+                                                     dst_ty.to_ir(builder)),
+                             dst_ty)
         else:
-            return tl.tensor(builder.create_fp_to_si(input.handle,
+            return tl.tensor(builder.create_fp_to_ui(input.handle,
                                                      dst_ty.to_ir(builder)),
                              dst_ty)
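The frontend change makes int-to-bool and float-to-bool casts compare against
zero instead of truncating, and untangles the fp_to_si/fp_to_ui calls so that
signed and unsigned integer targets each get the right conversion. A numpy
sketch of why truncation was wrong for bool (an illustration only, not the
Triton implementation):

import numpy as np

x = np.array([0.0, 0.5, 2.0], dtype=np.float32)

old = x.astype(np.uint32).astype(bool)  # old path, fp_to_ui then truncate: [False, False, True]
new = x != 0.0                          # new path, compare against zero:   [False, True,  True]

# 0.5 is truthy under C-like semantics but truncates to 0, so the old cast
# misclassified it.
assert list(new) == [False, True, True]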