[triton-mlir][BACKEND] Support masked load/store (#657)

This PR does

- fix some bugs to support masked load/store,
- refine the frontend, and support the `and` and `or` syntax in masks (by
extending BoolOp handling in the Python AST visitor), e.g. `tl.store(...,
mask=offset<n and other_conditions)`,
- add `arith.cmpI` and `arith.cmpF` op conversions in the backend (required by
masks),
- add more test cases in vecadd.
This commit is contained in:
Yan Chunwei
2022-10-10 13:29:53 +08:00
committed by GitHub
parent ccc5ab6ac9
commit 555f94f9b9
9 changed files with 396 additions and 74 deletions

View File

@@ -145,7 +145,7 @@ std::string PTXInstrExecution::dump() const {
if (!pred->repr)
os << "@" << pred->dump() << " ";
else
os << pred->repr(pred->idx);
os << pred->repr(pred->idx) << " ";
}
std::string instrRepr = strJoin(instr->instrParts, ".");

View File

@@ -114,16 +114,16 @@ namespace type = mlir::triton::type;
class TritonGPUToLLVMTypeConverter;
// TODO: keep these before we have better debug log utilities
// TODO[goostavz]: Remove these methods after we have better debug log utilities
template <typename T>
void print_array(ArrayRef<T> array, const std::string &str) {
std::cout << str << ": ";
void printArray(ArrayRef<T> array, const std::string &info) {
std::cout << info << ": ";
for (const T &e : array)
std::cout << e << ",";
std::cout << std::endl;
}
// Debug helper: print a single value, prefixed by `info`, to stdout.
// NOTE(review): this span fused the pre-rename (`print_scalar`) and renamed
// (`printScalar`) versions of the function from a diff; only the renamed
// version is kept here.
template <typename T> void printScalar(const T &e, const std::string &info) {
  std::cout << info << ": " << e << std::endl;
}
// FuncOpConversion/FuncOpConversionBase is borrowed from
@@ -808,14 +808,22 @@ struct StoreOpConversion
auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
assert(ptrElems.size() == valueElems.size());
// Determine the vectorization size
size_t vec = getVectorizeSize(ptr, layout);
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
assert(valueElems.size() == maskElems.size());
}
auto maskOrder = mask.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<BlockedEncodingAttr>()
.getOrder();
// Determine the vectorization size
size_t vec = getVectorizeSize(ptr, layout);
auto maskAxis = getAxisInfo(mask);
size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
vec = std::min(vec, maskAlign);
}
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
@@ -1376,13 +1384,15 @@ struct ExtractSliceOpConversion
}
};
template <typename SourceOp, typename DestOp>
class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
// A CRTP style of base class.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class BinaryOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
@@ -1403,13 +1413,16 @@ public:
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto lhss =
this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
auto rhss =
this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
elems, rewriter);
auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
rhss[i], loc);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
@@ -1417,6 +1430,123 @@ public:
}
};
// Default elementwise binary lowering: builds DestOp with the plain
// two-operand builder and reads operands through the generated accessors.
template <typename SourceOp, typename DestOp>
struct BinaryOpConversion
    : public BinaryOpConversionBase<SourceOp, DestOp,
                                    BinaryOpConversion<SourceOp, DestOp>> {
  using Base = BinaryOpConversionBase<SourceOp, DestOp,
                                      BinaryOpConversion<SourceOp, DestOp>>;
  using Base::Base; // inherit the (typeConverter, benefit) constructor
  using OpAdaptor = typename SourceOp::Adaptor;

  // CRTP hook: create the destination op for one element.
  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
                      Type elemTy, Value lhs, Value rhs, Location loc) const {
    return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
  }

  // CRTP hook: left operand of the source op.
  Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
  // CRTP hook: right operand of the source op.
  Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
};
// Lowers triton_gpu.cmpi elementwise to llvm.icmp via the CRTP binary base.
struct CmpIOpConversion
    : public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
                                    CmpIOpConversion> {
  explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
                            PatternBenefit benefit = 1)
      : BinaryOpConversionBase(typeConverter, benefit) {}

  // CRTP hook: build the llvm.icmp op, translating the arith predicate.
  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
                            ConversionPatternRewriter &rewriter, Type elemTy,
                            Value lhs, Value rhs, Location loc) const {
    return rewriter.create<LLVM::ICmpOp>(
        loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
  }

  // CRTP hook: left operand of the source op.
  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
  // CRTP hook: right operand of the source op.
  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }

  // Map an arith.cmpi predicate to the matching llvm.icmp predicate.
  // The enumerator names coincide, so the mapping is one-to-one.
  // NOTE(review): "Predicte" is a pre-existing typo; kept to preserve the
  // member's name for any external users.
  static LLVM::ICmpPredicate
  ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
    switch (predicate) {
// PRED_ENUM (no leading/trailing "__"): identifiers containing a double
// underscore are reserved for the implementation ([lex.name]).
#define PRED_ENUM(item)                                                        \
  case arith::CmpIPredicate::item:                                             \
    return LLVM::ICmpPredicate::item
      PRED_ENUM(eq);
      PRED_ENUM(ne);
      PRED_ENUM(sgt);
      PRED_ENUM(sge);
      PRED_ENUM(slt);
      PRED_ENUM(sle);
      PRED_ENUM(ugt);
      PRED_ENUM(uge);
      PRED_ENUM(ult);
      PRED_ENUM(ule);
#undef PRED_ENUM
    }
    // All enumerators are handled above; reached only for a corrupted value.
    return LLVM::ICmpPredicate::eq;
  }
};
// Lowers triton_gpu.cmpf elementwise to llvm.fcmp via the CRTP binary base.
struct CmpFOpConversion
    : public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
                                    CmpFOpConversion> {
  explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
                            PatternBenefit benefit = 1)
      : BinaryOpConversionBase(typeConverter, benefit) {}

  // CRTP hook: build the llvm.fcmp op, translating the arith predicate.
  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
                            ConversionPatternRewriter &rewriter, Type elemTy,
                            Value lhs, Value rhs, Location loc) const {
    return rewriter.create<LLVM::FCmpOp>(
        loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
  }

  // CRTP hook: left operand of the source op.
  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
  // CRTP hook: right operand of the source op.
  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }

  // Map an arith.cmpf predicate to the matching llvm.fcmp predicate.
  // NOTE(review): "Predicte" is a pre-existing typo; kept to preserve the
  // member's name for any external users.
  static LLVM::FCmpPredicate
  ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
    switch (predicate) {
// PRED_ENUM (no leading/trailing "__"): identifiers containing a double
// underscore are reserved for the implementation ([lex.name]).
#define PRED_ENUM(item, llvmItem)                                              \
  case arith::CmpFPredicate::item:                                             \
    return LLVM::FCmpPredicate::llvmItem
      PRED_ENUM(OEQ, oeq);
      PRED_ENUM(ONE, one);
      PRED_ENUM(OGT, ogt);
      PRED_ENUM(OGE, oge);
      PRED_ENUM(OLT, olt);
      PRED_ENUM(OLE, ole);
      PRED_ENUM(ORD, ord);
      PRED_ENUM(UEQ, ueq);
      PRED_ENUM(UGT, ugt);
      PRED_ENUM(UGE, uge); // BUGFIX: UGE was missing and fell through to _true
      PRED_ENUM(ULT, ult);
      PRED_ENUM(ULE, ule);
      PRED_ENUM(UNE, une);
      PRED_ENUM(UNO, uno);
      PRED_ENUM(AlwaysTrue, _true);
      PRED_ENUM(AlwaysFalse, _false);
#undef PRED_ENUM
    }
    // All enumerators are handled above; reached only for a corrupted value.
    return LLVM::FCmpPredicate::_true;
  }
};
struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
@@ -3011,6 +3141,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AndIOp, LLVM::AndOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::OrIOp, LLVM::OrOp>>(typeConverter,
benefit);
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);