[FRONTEND][BACKEND] Fixed various bugs (#819)

- Fixed bugs on layout conversions for int1 data (we should use int8
internally for int1 data to prevent llvm from using vec<i1> which has
different semantics)
- Fixed semantics of some casts to bool in the frontend
This commit is contained in:
Philippe Tillet
2022-10-28 23:34:14 -07:00
committed by GitHub
parent 82834d34f9
commit 7dfab26a39
5 changed files with 74 additions and 52 deletions

View File

@@ -76,6 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
} // namespace
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -102,6 +103,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
@@ -109,6 +112,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define i32_ty rewriter.getIntegerType(32)
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -2024,7 +2028,12 @@ void ConvertLayoutOpConversion::processReplica(
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
if (isInt1)
elemTy = IntegerType::get(elemTy.getContext(), 8);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
SmallVector<Value> multiDimOffsetFirstElem;
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
@@ -2131,16 +2140,22 @@ void ConvertLayoutOpConversion::processReplica(
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
valVec = insert_element(
vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
vals[elemId + linearCTAId * accumSizePerThread + v] =
extract_element(llvmElemTy, valVec, idx_val(v));
Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
if (isInt1)
currVal =
icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}
}

View File

@@ -170,9 +170,8 @@ void init_triton_ir(py::module &&m) {
.def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
})
});
;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
py::class_<mlir::Region>(m, "region")
@@ -660,13 +659,13 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
// get element type if necessary
mlir::Type srcType = src.getType();
auto srcTensorType = srcType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstType.dyn_cast<mlir::RankedTensorType>();
mlir::Type srcEltType = srcType;
mlir::Type dstEltType = dstType;
if (dstType.isa<mlir::RankedTensorType>()) {
dstEltType =
dstType.cast<mlir::RankedTensorType>().getElementType();
srcEltType =
srcType.cast<mlir::RankedTensorType>().getElementType();
if (dstTensorType && srcTensorType) {
dstEltType = dstTensorType.getElementType();
srcEltType = srcTensorType.getElementType();
}
unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();

View File

@@ -411,41 +411,40 @@ def test_where(dtype):
assert (z == to_numpy(z_tri)).all()
# TODO: wrong result
# def test_where_broadcast():
# @triton.jit
# def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
def test_where_broadcast():
@triton.jit
def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
# mask = tl.load(cond_ptr + yoffsets)
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
# res = tl.where(mask, vals, 0.)
# tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
mask = tl.load(cond_ptr + yoffsets)
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
# @triton.jit
# def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
# xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
# yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
# mask = 0
# vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
# res = tl.where(mask, vals, 0.)
# tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
@triton.jit
def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
mask = 0
vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
res = tl.where(mask, vals, 0.)
tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
# SIZE = 32
# dtype = 'float32'
# rs = RandomState(17)
# x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
# mask = numpy_random(SIZE, 'bool', rs=rs)
# z = np.where(mask, x, 0)
# cond_tri = to_triton(mask, device="cuda")
# x_tri = to_triton(x, device='cuda', dst_type=dtype)
# z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
# where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
# assert (z == to_numpy(z_tri)).all()
# where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
# z = np.where(0, x, 0)
# assert (z == to_numpy(z_tri)).all()
SIZE = 32
dtype = 'float32'
rs = RandomState(17)
x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
mask = numpy_random(SIZE, 'bool', rs=rs)
z = np.where(mask, x, 0)
cond_tri = to_triton(mask, device="cuda")
x_tri = to_triton(x, device='cuda', dst_type=dtype)
z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
assert (z == to_numpy(z_tri)).all()
where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
z = np.where(0, x, 0)
assert (z == to_numpy(z_tri)).all()
# # ---------------
# # test unary ops
@@ -719,7 +718,7 @@ def test_tuples():
# ('bfloat16', 'float32', False),
('float32', 'int32', True),
# TODO:
# ('float32', 'int1', False),
('float32', 'int1', False),
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [

View File

@@ -993,6 +993,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
if extern_libs:
add_external_libs(module, extern_libs)

View File

@@ -649,19 +649,27 @@ def cast(input: tl.tensor,
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
if dst_sca_ty.is_bool():
ty = input.dtype.to_ir(builder)
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
else:
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
# Float to Int
if src_sca_ty.is_floating() and dst_sca_ty.is_int():
# TODO: is this correct?
if dst_sca_ty.is_bool():
return tl.tensor(builder.create_fp_to_ui(input.handle,
ty = input.dtype.to_ir(builder)
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
return not_equal(input, _0, builder)
elif dst_sca_ty.is_int_signed():
return tl.tensor(builder.create_fp_to_si(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
else:
return tl.tensor(builder.create_fp_to_si(input.handle,
return tl.tensor(builder.create_fp_to_ui(input.handle,
dst_ty.to_ir(builder)),
dst_ty)