[Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -1792,17 +1792,15 @@ struct ExtractSliceOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: rewrite Ternary/Binary/Unary as Elementwise
|
||||
|
||||
// A CRTP style of base class.
|
||||
template <typename SourceOp, typename DestOp, typename ConcreteT>
|
||||
class BinaryOpConversionBase
|
||||
class ElementwiseOpConversionBase
|
||||
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
|
||||
explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
@@ -1817,7 +1815,8 @@ public:
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion");
|
||||
assert(resultLayout &&
|
||||
"Unexpected resultLayout in ElementwiseOpConversionBase");
|
||||
unsigned elems = resultLayout.getElemsPerThread(resultShape);
|
||||
Type elemTy =
|
||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
||||
@@ -1825,43 +1824,54 @@ public:
|
||||
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
|
||||
|
||||
auto *concreteThis = static_cast<const ConcreteT *>(this);
|
||||
auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
|
||||
rewriter);
|
||||
auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
|
||||
rewriter);
|
||||
auto operands = getOperands(rewriter, adaptor, elems, loc);
|
||||
SmallVector<Value> resultVals(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
|
||||
rhss[i], loc);
|
||||
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
|
||||
operands[i], loc);
|
||||
}
|
||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
|
||||
protected:
|
||||
SmallVector<SmallVector<Value>>
|
||||
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
|
||||
const unsigned elems, Location loc) const {
|
||||
SmallVector<SmallVector<Value>> operands(elems);
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter);
|
||||
for (int i = 0; i < elems; ++i) {
|
||||
operands[i].push_back(sub_operands[i]);
|
||||
}
|
||||
}
|
||||
return operands;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
struct BinaryOpConversion
|
||||
: public BinaryOpConversionBase<SourceOp, DestOp,
|
||||
BinaryOpConversion<SourceOp, DestOp>> {
|
||||
struct ElementwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<
|
||||
SourceOp, DestOp, ElementwiseOpConversion<SourceOp, DestOp>> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<SourceOp, DestOp,
|
||||
ElementwiseOpConversion<SourceOp, DestOp>>;
|
||||
using Base::Base;
|
||||
using OpAdaptor = typename Base::OpAdaptor;
|
||||
|
||||
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: BinaryOpConversionBase<SourceOp, DestOp,
|
||||
BinaryOpConversion<SourceOp, DestOp>>(
|
||||
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ElementwiseOpConversionBase<SourceOp, DestOp, ElementwiseOpConversion>(
|
||||
typeConverter, benefit) {}
|
||||
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
// An interface to support variant DestOp builder.
|
||||
DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
|
||||
Type elemTy, Value lhs, Value rhs, Location loc) const {
|
||||
return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
|
||||
DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
return rewriter.create<DestOp>(loc, elemTy, operands,
|
||||
adaptor.getAttributes().getValue());
|
||||
}
|
||||
|
||||
// Get the left operand of the op.
|
||||
Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
|
||||
// Get the right operand of the op.
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
|
||||
};
|
||||
|
||||
//
|
||||
@@ -2015,25 +2025,22 @@ struct UnaryOpConversion
|
||||
//
|
||||
|
||||
struct CmpIOpConversion
|
||||
: public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
CmpIOpConversion> {
|
||||
explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: BinaryOpConversionBase(typeConverter, benefit) {}
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
CmpIOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
|
||||
CmpIOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
|
||||
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
Value lhs, Value rhs, Location loc) const {
|
||||
ValueRange operands, Location loc) const {
|
||||
return rewriter.create<LLVM::ICmpOp>(
|
||||
loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
|
||||
loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), operands[0],
|
||||
operands[1]);
|
||||
}
|
||||
|
||||
// Get the left operand of the op.
|
||||
Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
|
||||
// Get the right operand of the op.
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
|
||||
|
||||
static LLVM::ICmpPredicate
|
||||
ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
|
||||
switch (predicate) {
|
||||
@@ -2059,25 +2066,22 @@ struct CmpIOpConversion
|
||||
};
|
||||
|
||||
struct CmpFOpConversion
|
||||
: public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
||||
CmpFOpConversion> {
|
||||
explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: BinaryOpConversionBase(typeConverter, benefit) {}
|
||||
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
||||
CmpFOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
|
||||
CmpFOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
// An interface to support variant DestOp builder.
|
||||
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
|
||||
LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
Value lhs, Value rhs, Location loc) const {
|
||||
ValueRange operands, Location loc) const {
|
||||
return rewriter.create<LLVM::FCmpOp>(
|
||||
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
|
||||
loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), operands[0],
|
||||
operands[1]);
|
||||
}
|
||||
|
||||
// Get the left operand of the op.
|
||||
Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
|
||||
// Get the right operand of the op.
|
||||
Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
|
||||
|
||||
static LLVM::FCmpPredicate
|
||||
ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
|
||||
switch (predicate) {
|
||||
@@ -4081,6 +4085,90 @@ struct InsertSliceAsyncOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtElemwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<
|
||||
triton::ExtElemwiseOp, LLVM::LLVMFuncOp, ExtElemwiseOpConversion> {
|
||||
using Base =
|
||||
ElementwiseOpConversionBase<triton::ExtElemwiseOp, LLVM::LLVMFuncOp,
|
||||
ExtElemwiseOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
StringRef funcName = op.symbol();
|
||||
if (funcName.empty())
|
||||
llvm::errs() << "ExtElemwiseOpConversion";
|
||||
|
||||
Type funcType = getFunctionType(elemTy, operands);
|
||||
LLVM::LLVMFuncOp funcOp =
|
||||
appendOrGetFuncOp(rewriter, op, funcName, funcType);
|
||||
return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
|
||||
}
|
||||
|
||||
private:
|
||||
Type getFunctionType(Type resultType, ValueRange operands) const {
|
||||
SmallVector<Type> operandTypes(operands.getTypes());
|
||||
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
|
||||
}
|
||||
|
||||
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
|
||||
triton::ExtElemwiseOp op,
|
||||
StringRef funcName, Type funcType) const {
|
||||
using LLVM::LLVMFuncOp;
|
||||
|
||||
auto funcAttr = StringAttr::get(op->getContext(), funcName);
|
||||
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
|
||||
if (funcOp)
|
||||
return cast<LLVMFuncOp>(*funcOp);
|
||||
|
||||
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
|
||||
auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
|
||||
ret.getOperation()->setAttr(
|
||||
"libname", StringAttr::get(op->getContext(), op.libname()));
|
||||
ret.getOperation()->setAttr(
|
||||
"libpath", StringAttr::get(op->getContext(), op.libpath()));
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct FDivOpConversion
|
||||
: ElementwiseOpConversionBase<mlir::arith::DivFOp, LLVM::InlineAsmOp,
|
||||
FDivOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<mlir::arith::DivFOp,
|
||||
LLVM::InlineAsmOp, FDivOpConversion>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
|
||||
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
if (32 == bitwidth) {
|
||||
fdiv.o("full").o("f32");
|
||||
auto res = ptxBuilder.newOperand("=r");
|
||||
auto lhs = ptxBuilder.newOperand(operands[0], "r");
|
||||
auto rhs = ptxBuilder.newOperand(operands[1], "r");
|
||||
fdiv(res, lhs, rhs);
|
||||
} else if (64 == bitwidth) {
|
||||
fdiv.o("rn").o("f64");
|
||||
auto res = ptxBuilder.newOperand("=l");
|
||||
auto lhs = ptxBuilder.newOperand(operands[0], "l");
|
||||
auto rhs = ptxBuilder.newOperand(operands[1], "l");
|
||||
fdiv(res, lhs, rhs);
|
||||
} else {
|
||||
assert(0 && bitwidth && "not supported");
|
||||
}
|
||||
|
||||
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
@@ -4093,12 +4181,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||
|
||||
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<TernaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp);
|
||||
#undef POPULATE_TERNARY_OP
|
||||
|
||||
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<BinaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
|
||||
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
|
||||
POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp)
|
||||
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
|
||||
@@ -4122,7 +4211,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<UnaryOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
|
||||
POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
|
||||
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
|
||||
@@ -4135,8 +4224,17 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
|
||||
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
|
||||
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
|
||||
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
|
||||
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
|
||||
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
|
||||
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
|
||||
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
|
||||
#undef POPULATE_UNARY_OP
|
||||
|
||||
patterns.add<FDivOpConversion>(typeConverter, benefit);
|
||||
|
||||
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
|
||||
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
|
||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||
|
@@ -16,6 +16,9 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/tools/sys/getenv.hpp"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
@@ -148,13 +151,80 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> extern_libs;
|
||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||
module.walk([&](LLVM::LLVMFuncOp func) {
|
||||
if (func.isExternal())
|
||||
funcs.push_back(func);
|
||||
});
|
||||
|
||||
for (auto &func : funcs) {
|
||||
if (func.getOperation()->hasAttr("libname")) {
|
||||
auto name =
|
||||
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
|
||||
auto path =
|
||||
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
|
||||
if (name) {
|
||||
std::string lib_name = name.str();
|
||||
extern_libs[lib_name] = path.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
|
||||
auto dict = module.getOperation()
|
||||
->getAttr("triton_gpu.externs")
|
||||
.dyn_cast<DictionaryAttr>();
|
||||
for (auto &attr : dict) {
|
||||
extern_libs[attr.getName().strref().trim().str()] =
|
||||
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
|
||||
}
|
||||
}
|
||||
|
||||
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llvm::SMDiagnostic err;
|
||||
for (auto &lib : extern_libs) {
|
||||
auto ext_mod = llvm::parseIRFile(lib.second, err, *llvmContext);
|
||||
if (!ext_mod) {
|
||||
llvm::errs() << "Failed to load extern lib " << lib.first;
|
||||
return nullptr;
|
||||
}
|
||||
ext_mod->setTargetTriple(llvmir->getTargetTriple());
|
||||
ext_mod->setDataLayout(llvmir->getDataLayout());
|
||||
|
||||
if (llvm::Linker::linkModules(*llvmir, std::move(ext_mod))) {
|
||||
llvm::errs() << "Failed to link extern lib " << lib.first;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return llvmir;
|
||||
}
|
||||
|
||||
void addExternalLibs(mlir::ModuleOp &module,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<std::string> &paths) {
|
||||
if (names.empty() || names.size() != paths.size())
|
||||
return;
|
||||
|
||||
llvm::SmallVector<NamedAttribute, 2> attrs;
|
||||
|
||||
for (size_t i = 0; i < names.size(); ++i) {
|
||||
auto name = StringAttr::get(module->getContext(), names[i]);
|
||||
auto path = StringAttr::get(module->getContext(), paths[i]);
|
||||
NamedAttribute attr(name, path);
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
|
||||
DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
|
||||
module.getOperation()->setAttr("triton_gpu.externs", dict);
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
Reference in New Issue
Block a user