[FRONTEND][BACKEND] Fixed various bugs (#819)
- Fixed bugs in layout conversions for int1 data (we should use int8 internally for int1 data to prevent LLVM from using vec<i1>, which has different semantics)
- Fixed the semantics of some casts to bool in the frontend
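For context on the second fix: a cast to bool must compare the value against zero rather than truncate it to the low bit (which is what an LLVM trunc to i1 does). A minimal numpy sketch of the difference, illustrative only and not part of this commit:

```python
import numpy as np

x = np.array([0, 1, 2, 4], dtype=np.int32)
print(x.astype(bool))        # [False  True  True  True] -- desired "x != 0" semantics
print((x & 1).astype(bool))  # [False  True False False] -- low-bit truncation
```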
@@ -76,6 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 } // namespace

 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
+#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
 #define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
 #define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
 #define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -102,6 +103,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
 #define icmp_eq(...) \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
+#define icmp_ne(...) \
+  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
 #define icmp_slt(...) \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
 #define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
@@ -109,6 +112,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
 #define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
 #define i32_ty rewriter.getIntegerType(32)
+#define i8_ty rewriter.getIntegerType(8)
 #define f32_ty rewriter.getF32Type()
 #define vec_ty(type, num) VectorType::get(num, type)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -2024,7 +2028,12 @@ void ConvertLayoutOpConversion::processReplica(
   for (unsigned d = 0; d < rank; ++d) {
     numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
   }
-  auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
+  auto elemTy = type.getElementType();
+  bool isInt1 = elemTy.isInteger(1);
+  if (isInt1)
+    elemTy = IntegerType::get(elemTy.getContext(), 8);
+  auto llvmElemTy = getTypeConverter()->convertType(elemTy);
+
   SmallVector<Value> multiDimOffsetFirstElem;
   SmallVector<Value> mmaColIdx(2);
   SmallVector<Value> mmaRowIdx(2);
@@ -2131,16 +2140,22 @@ void ConvertLayoutOpConversion::processReplica(
     if (stNotRd) {
       Value valVec = undef(vecTy);
       for (unsigned v = 0; v < vec; ++v) {
-        valVec = insert_element(
-            vecTy, valVec,
-            vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
+        auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
+        if (isInt1)
+          currVal = zext(llvmElemTy, currVal);
+
+        valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
       }
       store(valVec, ptr);
     } else {
      Value valVec = load(ptr);
      for (unsigned v = 0; v < vec; ++v) {
-        vals[elemId + linearCTAId * accumSizePerThread + v] =
-            extract_element(llvmElemTy, valVec, idx_val(v));
+        Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
+        if (isInt1)
+          currVal =
+              icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
+                                   loc, i8_ty, rewriter.getI8IntegerAttr(0)));
+        vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
      }
     }
   }
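The store/load changes above stage int1 values through shared memory as int8: zext before the vectorized store, icmp_ne against 0 after the vectorized load. A minimal numpy sketch of that round-trip, illustrative only and not part of the diff:

```python
import numpy as np

bools = np.array([True, False, True])
staged = bools.astype(np.uint8)  # zext i1 -> i8 before the vector store
restored = staged != 0           # icmp_ne(i8, 0) after the vector load
assert (restored == bools).all()
```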
@@ -170,9 +170,8 @@ void init_triton_ir(py::module &&m) {
       .def("replace_all_uses_with",
            [](mlir::Value &self, mlir::Value &newValue) {
              self.replaceAllUsesWith(newValue);
-           })
+           });

-      ;
   py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");

   py::class_<mlir::Region>(m, "region")
@@ -660,13 +659,13 @@ void init_triton_ir(py::module &&m) {
             auto loc = self.getUnknownLoc();
             // get element type if necessary
             mlir::Type srcType = src.getType();
+            auto srcTensorType = srcType.dyn_cast<mlir::RankedTensorType>();
+            auto dstTensorType = dstType.dyn_cast<mlir::RankedTensorType>();
             mlir::Type srcEltType = srcType;
             mlir::Type dstEltType = dstType;
-            if (dstType.isa<mlir::RankedTensorType>()) {
-              dstEltType =
-                  dstType.cast<mlir::RankedTensorType>().getElementType();
-              srcEltType =
-                  srcType.cast<mlir::RankedTensorType>().getElementType();
+            if (dstTensorType && srcTensorType) {
+              dstEltType = dstTensorType.getElementType();
+              srcEltType = srcTensorType.getElementType();
             }
             unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
             unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();
@@ -411,41 +411,40 @@ def test_where(dtype):
     assert (z == to_numpy(z_tri)).all()


-# TODO: wrong result
-# def test_where_broadcast():
-#     @triton.jit
-#     def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
-#         xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
-#         yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
+def test_where_broadcast():
+    @triton.jit
+    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
+        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]

-#         mask = tl.load(cond_ptr + yoffsets)
-#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
-#         res = tl.where(mask, vals, 0.)
-#         tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
+        mask = tl.load(cond_ptr + yoffsets)
+        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
+        res = tl.where(mask, vals, 0.)
+        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

-#     @triton.jit
-#     def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
-#         xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
-#         yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
-#         mask = 0
-#         vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
-#         res = tl.where(mask, vals, 0.)
-#         tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)
+    @triton.jit
+    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
+        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
+        mask = 0
+        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
+        res = tl.where(mask, vals, 0.)
+        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

-#     SIZE = 32
-#     dtype = 'float32'
-#     rs = RandomState(17)
-#     x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
-#     mask = numpy_random(SIZE, 'bool', rs=rs)
-#     z = np.where(mask, x, 0)
-#     cond_tri = to_triton(mask, device="cuda")
-#     x_tri = to_triton(x, device='cuda', dst_type=dtype)
-#     z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
-#     where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
-#     assert (z == to_numpy(z_tri)).all()
-#     where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
-#     z = np.where(0, x, 0)
-#     assert (z == to_numpy(z_tri)).all()
+    SIZE = 32
+    dtype = 'float32'
+    rs = RandomState(17)
+    x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs)
+    mask = numpy_random(SIZE, 'bool', rs=rs)
+    z = np.where(mask, x, 0)
+    cond_tri = to_triton(mask, device="cuda")
+    x_tri = to_triton(x, device='cuda', dst_type=dtype)
+    z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device='cuda', dst_type=dtype)
+    where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE)
+    assert (z == to_numpy(z_tri)).all()
+    where_scalar_condition[(1,)](x_tri, z_tri, SIZE)
+    z = np.where(0, x, 0)
+    assert (z == to_numpy(z_tri)).all()


 # # ---------------
 # # test unary ops
@@ -719,7 +718,7 @@ def test_tuples():
     # ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
     # TODO:
-    # ('float32', 'int1', False),
+    ('float32', 'int1', False),
 ] + [
     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
 ] + [
@@ -993,6 +993,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
     module = optimize_tritongpu_ir(module, num_stages)
     if output == "ttgir":
         return module.str()
+
     if extern_libs:
         add_external_libs(module, extern_libs)

@@ -649,19 +649,27 @@ def cast(input: tl.tensor,
     if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
        (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
         sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
-        return tl.tensor(builder.create_int_cast(input.handle,
-                                                 dst_ty.to_ir(builder), sign_extend),
-                         dst_ty)
+        if dst_sca_ty.is_bool():
+            ty = input.dtype.to_ir(builder)
+            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
+            return not_equal(input, _0, builder)
+        else:
+            return tl.tensor(builder.create_int_cast(input.handle,
+                                                     dst_ty.to_ir(builder), sign_extend),
+                             dst_ty)

     # Float to Int
     if src_sca_ty.is_floating() and dst_sca_ty.is_int():
-        # TODO: is this correct?
         if dst_sca_ty.is_bool():
-            return tl.tensor(builder.create_fp_to_ui(input.handle,
+            ty = input.dtype.to_ir(builder)
+            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
+            return not_equal(input, _0, builder)
+        elif dst_sca_ty.is_int_signed():
+            return tl.tensor(builder.create_fp_to_si(input.handle,
                                                      dst_ty.to_ir(builder)),
                              dst_ty)
         else:
-            return tl.tensor(builder.create_fp_to_si(input.handle,
+            return tl.tensor(builder.create_fp_to_ui(input.handle,
                                                      dst_ty.to_ir(builder)),
                              dst_ty)

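With the old code, a float-to-bool cast went through fp_to_ui, which converts to an integer first, so values strictly between 0 and 1 collapsed to False. The new path compares against zero instead, matching e.g. numpy's semantics. A small illustrative sketch, not part of the diff:

```python
import numpy as np

x = np.array([0.0, 0.5, -2.0], dtype=np.float32)
print(x.astype(bool))                   # [False  True  True] -- compare-to-zero
print(x.astype(np.int64).astype(bool))  # [False False  True] -- int-convert-first loses 0.5
```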