[FRONTEND][BACKEND] Fixed various bugs (#819)
- Fixed bugs on layout conversions for int1 data (we should use int8 internally for int1 data to prevent llvm from using vec<i1> which has different semantics) - Fixed semantics of some casts to bool in the frontend
This commit is contained in:
@@ -76,6 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
} // namespace
|
||||
|
||||
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
|
||||
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
|
||||
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
|
||||
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
|
||||
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
|
||||
@@ -102,6 +103,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
|
||||
#define icmp_eq(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
|
||||
#define icmp_ne(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
|
||||
#define icmp_slt(...) \
|
||||
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
|
||||
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
|
||||
@@ -109,6 +112,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
|
||||
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
|
||||
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define i8_ty rewriter.getIntegerType(8)
|
||||
#define f32_ty rewriter.getF32Type()
|
||||
#define vec_ty(type, num) VectorType::get(num, type)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
@@ -2024,7 +2028,12 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
|
||||
}
|
||||
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
|
||||
auto elemTy = type.getElementType();
|
||||
bool isInt1 = elemTy.isInteger(1);
|
||||
if (isInt1)
|
||||
elemTy = IntegerType::get(elemTy.getContext(), 8);
|
||||
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
|
||||
|
||||
SmallVector<Value> multiDimOffsetFirstElem;
|
||||
SmallVector<Value> mmaColIdx(2);
|
||||
SmallVector<Value> mmaRowIdx(2);
|
||||
@@ -2131,16 +2140,22 @@ void ConvertLayoutOpConversion::processReplica(
|
||||
if (stNotRd) {
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned v = 0; v < vec; ++v) {
|
||||
valVec = insert_element(
|
||||
vecTy, valVec,
|
||||
vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
|
||||
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
|
||||
if (isInt1)
|
||||
currVal = zext(llvmElemTy, currVal);
|
||||
|
||||
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
|
||||
}
|
||||
store(valVec, ptr);
|
||||
} else {
|
||||
Value valVec = load(ptr);
|
||||
for (unsigned v = 0; v < vec; ++v) {
|
||||
vals[elemId + linearCTAId * accumSizePerThread + v] =
|
||||
extract_element(llvmElemTy, valVec, idx_val(v));
|
||||
Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
|
||||
if (isInt1)
|
||||
currVal =
|
||||
icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
|
||||
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user