241 lines
8.8 KiB
C++
241 lines
8.8 KiB
C++
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
|
|
namespace mlir {
|
|
namespace LLVM {
|
|
|
|
/// Returns the attribute name string "llvm.struct_attrs".
/// NOTE(review): no code visible in this file calls this helper — confirm it
/// is still needed.
static StringRef getStructAttrsAttrName() {
  return StringRef("llvm.struct_attrs");
}
|
|
|
|
} // namespace LLVM
|
|
} // namespace mlir
|
|
|
|
namespace {
|
|
|
|
// The following code is borrowed from the MLIR project, including these
// functions and classes:
// - filterFuncAttributes
// - ConvertOpToLLVMPattern
// - FuncOpConversion
//
// The code is hidden in .cpp files in the MLIR repo, so we can't call it
// directly. These snippets have been refactored and added to LLVMCommon in
// the latest MLIR code, but the v14.0.0 version currently used in Triton
// doesn't contain them.
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
//
// The original code:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
// All the rights are reserved by LLVM community.
|
|
|
|
/// Only retain those attributes that are not constructed by
|
|
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
|
/// attributes.
|
|
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
|
bool filterArgAttrs,
|
|
SmallVectorImpl<NamedAttribute> &result) {
|
|
for (const auto &attr : attrs) {
|
|
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
|
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
|
attr.getName() == "std.varargs" ||
|
|
(filterArgAttrs &&
|
|
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
|
|
continue;
|
|
result.push_back(attr);
|
|
}
|
|
}
|
|
|
|
/// Shared base for patterns that lower a builtin `FuncOp` to
/// `LLVM::LLVMFuncOp`. It only provides the `convertFuncOpToLLVMFuncOp`
/// helper; derived patterns implement `matchAndRewrite` on top of it.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  /// Converts `funcOp` to an `LLVM::LLVMFuncOp` using the LLVMTypeConverter
  /// provided to this legalization pattern.
  ///
  /// The new function takes over the body of `funcOp` (the region is inlined
  /// into it and its block argument types are converted). `funcOp` itself is
  /// NOT erased here; that is left to the caller.
  ///
  /// Returns a null op when the signature cannot be converted, when the
  /// "llvm.linkage" attribute is present but is not an `LLVM::LinkageAttr`
  /// (an error is emitted on `funcOp` in that case), or when converting the
  /// region types fails.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    // An unconvertible signature is an ordinary conversion failure: report it
    // to the caller instead of aborting, so debug and release builds behave
    // the same way.
    if (!llvmType)
      return nullptr;

    // Propagate argument attributes to all converted arguments obtained after
    // converting a given original argument.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        auto mapping = result.getInputMapping(i);
        assert(mapping.hasValue() &&
               "unexpected deletion of function argument");
        // One original argument may expand to several converted ones; each of
        // them receives the original argument's attribute dictionary.
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(
          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
                                rewriter.getArrayAttr(newArgAttrs)));
    }

    // The linkage is passed to the builder explicitly below, so it must not
    // also be forwarded as a plain attribute.
    llvm::erase_if(attributes, [](const NamedAttribute &attr) {
      return attr.getName() == "llvm.linkage";
    });

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (Attribute rawLinkage = funcOp->getAttr("llvm.linkage")) {
      auto attr = rawLinkage.dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }

    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
    // Move the body over, then rewrite its block argument types according to
    // the signature conversion computed above.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};
|
|
|
|
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
|
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
|
/// information.
|
|
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
|
struct FuncOpConversion : public FuncOpConversionBase {
|
|
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
|
|
: FuncOpConversionBase(converter), NumWarps(numWarps) {
|
|
mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion(
|
|
[&](triton::PointerType type) {
|
|
return convertTritonPointerType(type);
|
|
});
|
|
}
|
|
|
|
Type convertTritonPointerType(triton::PointerType type) {
|
|
return LLVM::LLVMPointerType::get(type.getPointeeType(),
|
|
type.getAddressSpace());
|
|
}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
|
if (!newFuncOp)
|
|
return failure();
|
|
|
|
auto ctx = funcOp->getContext();
|
|
auto i32 = IntegerType::get(ctx, 32);
|
|
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
|
|
// for `nvvm.annotation` metadata.
|
|
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
|
|
rewriter.getIntegerAttr(i32, 32 * NumWarps));
|
|
|
|
rewriter.eraseOp(funcOp);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
int NumWarps{-1};
|
|
};
|
|
|
|
/// Lowers an operand-less `ReturnOp` to `LLVM::ReturnOp`, forwarding the
/// original op's attributes. A return that carries operands is rejected with
/// a match failure.
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
  using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adapter,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently, Triton kernel functions always return nothing.
    // TODO(Superjomn) add support for non-inline device function
    const bool returnsValues = op.getNumOperands() != 0;
    if (returnsValues)
      return rewriter.notifyMatchFailure(
          op, "Only kernel function with nothing returned is supported.");

    // No result types and no operands: emit a bare LLVM return.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                op->getAttrs());
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the function and return lowering patterns defined in this file
/// into `patterns`. `numWarps` is forwarded to the function pattern only.
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps) {
  // Same registration order as before: function pattern first, then return.
  patterns.add<::FuncOpConversion>(typeConverter, numWarps)
      .add<::ReturnOpConversion>(typeConverter);
}
|
|
|
|
class ConvertTritonGPUToLLVM
|
|
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
|
public:
|
|
ConvertTritonGPUToLLVM() = default;
|
|
// For manually overwrite the numWarps option
|
|
explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; }
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
ModuleOp mod = getOperation();
|
|
|
|
LLVMTypeConverter typeConverter(context);
|
|
TritonLLVMConversionTarget target(*context, typeConverter);
|
|
|
|
RewritePatternSet patterns(context);
|
|
// Add arith's patterns to help convert scalar expression to LLVM.
|
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
|
patterns);
|
|
|
|
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
|
|
|
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
namespace mlir {
|
|
|
|
/// Builds the conversion target on `ctx`, initializes the `typeConverter`
/// member from the argument and declares the LLVM dialect legal.
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
    MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
    : ConversionTarget(ctx), typeConverter(typeConverter) {
  this->addLegalDialect<LLVM::LLVMDialect>();
}
|
|
|
|
namespace triton {
|
|
|
|
/// Creates a default-constructed instance of the TritonGPU-to-LLVM pass.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<::ConvertTritonGPUToLLVM>();
  return pass;
}
|
|
|
|
} // namespace triton
|
|
} // namespace mlir
|