[Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
ben-zhang-609
2022-10-28 13:26:29 +08:00
committed by GitHub
parent 3b80801dff
commit 3685194456
9 changed files with 616 additions and 60 deletions

View File

@@ -1792,17 +1792,15 @@ struct ExtractSliceOpConversion
}
};
// TODO: rewrite Ternary/Binary/Unary as Elementwise
// A CRTP style of base class.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class BinaryOpConversionBase
class ElementwiseOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
@@ -1817,7 +1815,8 @@ public:
auto resultLayout =
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
assert(resultLayout &&
"Unexpected resultLayout in ElementwiseOpConversionBase");
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
@@ -1825,43 +1824,54 @@ public:
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
rewriter);
auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
rewriter);
auto operands = getOperands(rewriter, adaptor, elems, loc);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
rhss[i], loc);
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
operands[i], loc);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
protected:
SmallVector<SmallVector<Value>>
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
const unsigned elems, Location loc) const {
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
for (int i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
}
return operands;
}
};
template <typename SourceOp, typename DestOp>
struct BinaryOpConversion
: public BinaryOpConversionBase<SourceOp, DestOp,
BinaryOpConversion<SourceOp, DestOp>> {
struct ElementwiseOpConversion
: public ElementwiseOpConversionBase<
SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
using Base =
ElementwiseOpConversionBase<SourceOp, DestOp,
ElementwiseOpConversion<SourceOp, DestOp>>;
using Base::Base;
using OpAdaptor = typename Base::OpAdaptor;
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: BinaryOpConversionBase<SourceOp, DestOp,
BinaryOpConversion<SourceOp, DestOp>>(
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
typeConverter, benefit) {}
using OpAdaptor = typename SourceOp::Adaptor;
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
Type elemTy, Value lhs, Value rhs, Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, operands,
adaptor.getAttributes().getValue());
}
// Get the left operand of the op.
Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
// Get the right operand of the op.
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
};
//
@@ -2015,25 +2025,22 @@ struct UnaryOpConversion
//
struct CmpIOpConversion
: public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
CmpIOpConversion> {
explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: BinaryOpConversionBase(typeConverter, benefit) {}
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
CmpIOpConversion> {
using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
CmpIOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder.
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
Value lhs, Value rhs, Location loc) const {
ValueRange operands, Location loc) const {
return rewriter.create<LLVM::ICmpOp>(
loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), operands[0],
operands[1]);
}
// Get the left operand of the op.
Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
// Get the right operand of the op.
Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
static LLVM::ICmpPredicate
ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
switch (predicate) {
@@ -2059,25 +2066,22 @@ struct CmpIOpConversion
};
struct CmpFOpConversion
: public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
CmpFOpConversion> {
explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: BinaryOpConversionBase(typeConverter, benefit) {}
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
CmpFOpConversion> {
using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
CmpFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder.
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
Value lhs, Value rhs, Location loc) const {
ValueRange operands, Location loc) const {
return rewriter.create<LLVM::FCmpOp>(
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
operands[1]);
}
// Get the left operand of the op.
Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
// Get the right operand of the op.
Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
static LLVM::FCmpPredicate
ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
switch (predicate) {
@@ -4081,6 +4085,90 @@ struct InsertSliceAsyncOpConversion
}
};
struct ExtElemwiseOpConversion
: public ElementwiseOpConversionBase<
triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
ExtElemwiseOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
StringRef funcName = op.symbol();
if (funcName.empty())
llvm::errs() << "ExtElemwiseOpConversion";
Type funcType = getFunctionType(elemTy, operands);
LLVM::LLVMFuncOp funcOp =
appendOrGetFuncOp(rewriter, op, funcName, funcType);
return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
}
private:
Type getFunctionType(Type resultType, ValueRange operands) const {
SmallVector<Type> operandTypes(operands.getTypes());
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
triton::ExtElemwiseOp op,
StringRef funcName, Type funcType) const {
using LLVM::LLVMFuncOp;
auto funcAttr = StringAttr::get(op->getContext(), funcName);
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
if (funcOp)
return cast<LLVMFuncOp>(*funcOp);
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
ret.getOperation()->setAttr(
"libname", StringAttr::get(op->getContext(), op.libname()));
ret.getOperation()->setAttr(
"libpath", StringAttr::get(op->getContext(), op.libpath()));
return ret;
}
};
struct FDivOpConversion
: ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
FDivOpConversion> {
using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp,
LLVM::InlineAsmOp, FDivOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder ptxBuilder;
auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
if (32 == bitwidth) {
fdiv.o("full").o("f32");
auto res = ptxBuilder.newOperand("=r");
auto lhs = ptxBuilder.newOperand(operands[0], "r");
auto rhs = ptxBuilder.newOperand(operands[1], "r");
fdiv(res, lhs, rhs);
} else if (64 == bitwidth) {
fdiv.o("rn").o("f64");
auto res = ptxBuilder.newOperand("=l");
auto lhs = ptxBuilder.newOperand(operands[0], "l");
auto rhs = ptxBuilder.newOperand(operands[1], "l");
fdiv(res, lhs, rhs);
} else {
assert(0 && bitwidth && "not supported");
}
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
return ret;
}
};
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -4093,12 +4181,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
@@ -4122,7 +4211,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
@@ -4135,8 +4224,17 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
#undef POPULATE_UNARY_OP
patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,