[triton-mlir][BACKEND] Support masked load/store (#657)
This PR: - fixes some bugs to support masked load/store, - refines the frontend and supports the `and` and `or` syntax in masks (by extending the BoolOp handling in the Python ast visitor), e.g. `tl.store(..., mask=offset<n and other_conditions)`, - adds `arith.cmpI` and `arith.cmpF` op conversion in the backend (required by masks), - adds more test cases in vecadd.
This commit is contained in:
@@ -145,7 +145,7 @@ std::string PTXInstrExecution::dump() const {
|
||||
if (!pred->repr)
|
||||
os << "@" << pred->dump() << " ";
|
||||
else
|
||||
os << pred->repr(pred->idx);
|
||||
os << pred->repr(pred->idx) << " ";
|
||||
}
|
||||
|
||||
std::string instrRepr = strJoin(instr->instrParts, ".");
|
||||
|
@@ -114,16 +114,16 @@ namespace type = mlir::triton::type;
|
||||
|
||||
class TritonGPUToLLVMTypeConverter;
|
||||
|
||||
// TODO: keep these before we have better debug log utilities
|
||||
// TODO[goostavz]: Remove these methods after we have better debug log utilities
|
||||
template <typename T>
|
||||
void print_array(ArrayRef<T> array, const std::string &str) {
|
||||
std::cout << str << ": ";
|
||||
void printArray(ArrayRef<T> array, const std::string &info) {
|
||||
std::cout << info << ": ";
|
||||
for (const T &e : array)
|
||||
std::cout << e << ",";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
template <typename T> void print_scalar(const T &e, const std::string &str) {
|
||||
std::cout << str << ": " << e << std::endl;
|
||||
template <typename T> void printScalar(const T &e, const std::string &info) {
|
||||
std::cout << info << ": " << e << std::endl;
|
||||
}
|
||||
|
||||
// FuncOpConversion/FuncOpConversionBase is borrowed from
|
||||
@@ -808,14 +808,22 @@ struct StoreOpConversion
|
||||
auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
|
||||
assert(ptrElems.size() == valueElems.size());
|
||||
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVectorizeSize(ptr, layout);
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
}
|
||||
auto maskOrder = mask.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
// Determine the vectorization size
|
||||
size_t vec = getVectorizeSize(ptr, layout);
|
||||
auto maskAxis = getAxisInfo(mask);
|
||||
size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
@@ -1376,13 +1384,15 @@ struct ExtractSliceOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
// A CRTP style of base class.
|
||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||
class BinaryOpConversionBase
|
||||
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
@@ -1403,13 +1413,16 @@ public:
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
auto lhss =
|
||||
this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
|
||||
auto rhss =
|
||||
this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
|
||||
|
||||
auto *concreteThis = static_cast<const ConcreteT *>(this);
|
||||
auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
|
||||
elems, rewriter);
|
||||
auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
|
||||
elems, rewriter);
|
||||
SmallVector<Value> resultVals(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
|
||||
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
|
||||
rhss[i], loc);
|
||||
}
|
||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
@@ -1417,6 +1430,123 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
struct BinaryOpConversion
|
||||
: public BinaryOpConversionBase<SourceOp, DestOp,
|
||||
BinaryOpConversion<SourceOp, DestOp>> {
|
||||
|
||||
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: BinaryOpConversionBase<SourceOp, DestOp,
|
||||
BinaryOpConversion<SourceOp, DestOp>>(
|
||||
typeConverter, benefit) {}
|
||||
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
// An interface to support variant DestOp builder.
|
||||
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, Value lhs, Value rhs, Location loc) const {
|
||||
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
|
||||
}
|
||||
|
||||
// Get the left operand of the op.
|
||||
Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
|
||||
// Get the right operand of the op.
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
|
||||
};
|
||||
|
||||
struct CmpIOpConversion
|
||||
: public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
CmpIOpConversion> {
|
||||
explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: BinaryOpConversionBase(typeConverter, benefit) {}
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
Value lhs, Value rhs, Location loc) const {
|
||||
return rewriter.create<LLVM::ICmpOp>(
|
||||
loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
|
||||
}
|
||||
|
||||
// Get the left operand of the op.
|
||||
Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
|
||||
// Get the right operand of the op.
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
|
||||
|
||||
static LLVM::ICmpPredicate
|
||||
ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
|
||||
switch (predicate) {
|
||||
#define __PRED_ENUM(item__) \
|
||||
case arith::CmpIPredicate::item__: \
|
||||
return LLVM::ICmpPredicate::item__
|
||||
|
||||
__PRED_ENUM(eq);
|
||||
__PRED_ENUM(ne);
|
||||
__PRED_ENUM(sgt);
|
||||
__PRED_ENUM(sge);
|
||||
__PRED_ENUM(slt);
|
||||
__PRED_ENUM(sle);
|
||||
__PRED_ENUM(ugt);
|
||||
__PRED_ENUM(uge);
|
||||
__PRED_ENUM(ult);
|
||||
__PRED_ENUM(ule);
|
||||
|
||||
#undef __PRED_ENUM
|
||||
}
|
||||
return LLVM::ICmpPredicate::eq;
|
||||
}
|
||||
};
|
||||
|
||||
struct CmpFOpConversion
|
||||
: public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
||||
CmpFOpConversion> {
|
||||
explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: BinaryOpConversionBase(typeConverter, benefit) {}
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
Value lhs, Value rhs, Location loc) const {
|
||||
return rewriter.create<LLVM::FCmpOp>(
|
||||
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
|
||||
}
|
||||
|
||||
// Get the left operand of the op.
|
||||
Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
|
||||
// Get the right operand of the op.
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
|
||||
|
||||
static LLVM::FCmpPredicate
|
||||
ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
|
||||
switch (predicate) {
|
||||
#define __PRED_ENUM(item__, item1__) \
|
||||
case arith::CmpFPredicate::item__: \
|
||||
return LLVM::FCmpPredicate::item1__
|
||||
|
||||
__PRED_ENUM(OEQ, oeq);
|
||||
__PRED_ENUM(ONE, one);
|
||||
__PRED_ENUM(OGT, ogt);
|
||||
__PRED_ENUM(OGE, oge);
|
||||
__PRED_ENUM(OLT, olt);
|
||||
__PRED_ENUM(OLE, ole);
|
||||
__PRED_ENUM(ORD, ord);
|
||||
__PRED_ENUM(UEQ, ueq);
|
||||
__PRED_ENUM(UGT, ugt);
|
||||
__PRED_ENUM(ULT, ult);
|
||||
__PRED_ENUM(ULE, ule);
|
||||
__PRED_ENUM(UNE, une);
|
||||
__PRED_ENUM(UNO, uno);
|
||||
__PRED_ENUM(AlwaysTrue, _true);
|
||||
__PRED_ENUM(AlwaysFalse, _false);
|
||||
|
||||
#undef __PRED_ENUM
|
||||
}
|
||||
return LLVM::FCmpPredicate::_true;
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertLayoutOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
|
||||
public:
|
||||
@@ -3011,6 +3141,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
benefit);
|
||||
patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
|
||||
benefit);
|
||||
|
||||
patterns.add<BinaryOpConversion<arith::AndIOp, LLVM::AndOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<BinaryOpConversion<arith::OrIOp, LLVM::OrOp>>(typeConverter,
|
||||
benefit);
|
||||
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
|
Reference in New Issue
Block a user