[FRONTEND][BACKEND] Fixed various bugs (#819)

- Fixed bugs in layout conversions for int1 data (int1 values should be
represented as int8 internally to prevent LLVM from using vec<i1>, which
has different semantics)
- Fixed semantics of some casts to bool in the frontend
This commit is contained in:
Philippe Tillet
2022-10-28 23:34:14 -07:00
committed by GitHub
parent 82834d34f9
commit 7dfab26a39
5 changed files with 74 additions and 52 deletions

View File

@@ -76,6 +76,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
} // namespace
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// NOTE: every shortcut below expands to an expression that names `rewriter`
// and `loc`, so both must be in scope wherever a shortcut is used. All
// arguments are forwarded unchanged after `loc`.
// zext(...): creates an LLVM::ZExtOp.
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
// udiv(...): creates an LLVM::UDivOp.
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
// urem(...): creates an LLVM::URemOp.
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
// add(...): creates an LLVM::AddOp.
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -102,6 +103,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
// NOTE: these shortcuts expand to expressions that name `rewriter` and `loc`,
// so both must be in scope at the point of use.
// store(val, ptr): creates an LLVM::StoreOp from `val` and `ptr`, in that
// order.
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
// icmp_eq(...): creates an LLVM::ICmpOp with predicate `eq`; the arguments
// are forwarded after the predicate.
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
// icmp_ne(...): same as icmp_eq, but with predicate `ne`.
#define icmp_ne(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
// icmp_slt(...): same as icmp_eq, but with predicate `slt`.
#define icmp_slt(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
// select(...): creates an LLVM::SelectOp from the forwarded arguments.
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
@@ -109,6 +112,7 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
// NOTE: unless stated otherwise, these shortcuts name `rewriter` (and, for the
// op builders, `loc`), which must be in scope at the point of use.
// barrier(): creates an mlir::gpu::BarrierOp at `loc`.
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
// undef(...): creates an LLVM::UndefOp from the forwarded arguments.
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
// i32_ty: the 32-bit integer type obtained from `rewriter`.
#define i32_ty rewriter.getIntegerType(32)
// i8_ty: the 8-bit integer type obtained from `rewriter`.
#define i8_ty rewriter.getIntegerType(8)
// f32_ty: the F32 type obtained from `rewriter`.
#define f32_ty rewriter.getF32Type()
// vec_ty(type, num): VectorType of `num` elements of `type`. Note that the
// macro's argument order is the reverse of VectorType::get(num, type).
#define vec_ty(type, num) VectorType::get(num, type)
// void_ty(ctx): the LLVM void type for context `ctx`; does not use `rewriter`
// or `loc`.
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
@@ -2024,7 +2028,12 @@ void ConvertLayoutOpConversion::processReplica(
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto llvmElemTy = getTypeConverter()->convertType(type.getElementType());
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
if (isInt1)
elemTy = IntegerType::get(elemTy.getContext(), 8);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
SmallVector<Value> multiDimOffsetFirstElem;
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
@@ -2131,16 +2140,22 @@ void ConvertLayoutOpConversion::processReplica(
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
valVec = insert_element(
vecTy, valVec,
vals[elemId + linearCTAId * accumSizePerThread + v], idx_val(v));
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
vals[elemId + linearCTAId * accumSizePerThread + v] =
extract_element(llvmElemTy, valVec, idx_val(v));
Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
if (isInt1)
currVal =
icmp_ne(currVal, rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}
}