Init TritonGPU to LLVM dialect conversion (#32)

* add toLLVM pass

* update num-warps setting in mlir
This commit is contained in:
Yan Chunwei
2022-08-04 10:15:45 +08:00
committed by GitHub
parent 3236642e8f
commit b988bae813
9 changed files with 328 additions and 3 deletions

View File

@@ -22,6 +22,7 @@ int main(int argc, char **argv) {
mlir::registerTritonGPUPasses();
mlir::test::registerTestAlignmentPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
namespace mlir {

View File

@@ -24,4 +24,25 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
];
}
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::StandardOpsDialect",
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
#endif

View File

@@ -0,0 +1,29 @@
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
class TritonLLVMConversionTarget : public ConversionTarget {
mlir::LLVMTypeConverter &typeConverter;
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
mlir::LLVMTypeConverter &typeConverter);
};
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
} // namespace triton
} // namespace mlir
#endif

View File

@@ -1,2 +1,2 @@
# add_subdirectory(TritonGPUToLLVM)
add_subdirectory(TritonToTritonGPU)
add_subdirectory(TritonGPUToLLVM)

View File

@@ -0,0 +1,19 @@
add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
DEPENDS
TritonConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
TritonIR
TritonGPUIR
TritonGPUTransforms
)

View File

@@ -0,0 +1,240 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
using namespace mlir;
using namespace mlir::triton;
namespace mlir {
namespace LLVM {
static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
} // namespace LLVM
} // namespace mlir
namespace {
// The following code are borrowed from mlir project including the following
// functions or classes:
// - filterFuncAttributes
// - ConvertOpToLLVMPattern
// - FuncOpConversion
//
// The code are hidden in the CPP files in MLIR repo, and we can't call them
// directly. I found such code snippets are refactored and added to LLVMCommon
// in the latest MLIR code, but the v14.0.0 version currentlly used in Triton
// doesn't contain the code.
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
//
// The original code:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
// All the rights are reserved by LLVM community.
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : attrs) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == "std.varargs" ||
(filterArgAttrs &&
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
continue;
result.push_back(attr);
}
}
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
LLVM::LLVMFuncOp
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
assert(llvmType);
if (!llvmType)
return nullptr;
// Propagate argument attributes to all converted arguments obtained after
// converting a given original argument.
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
attributes);
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
auto mapping = result.getInputMapping(i);
assert(mapping.hasValue() &&
"unexpected deletion of function argument");
for (size_t j = 0; j < mapping->size; ++j)
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
}
attributes.push_back(
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
rewriter.getArrayAttr(newArgAttrs)));
}
for (const auto &pair : llvm::enumerate(attributes)) {
if (pair.value().getName() == "llvm.linkage") {
attributes.erase(attributes.begin() + pair.index());
break;
}
}
// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
LLVM::Linkage linkage = LLVM::Linkage::External;
if (funcOp->hasAttr("llvm.linkage")) {
auto attr =
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
if (!attr) {
funcOp->emitError()
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
return nullptr;
}
linkage = attr.getLinkage();
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
&result)))
return nullptr;
return newFuncOp;
}
};
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
: FuncOpConversionBase(converter), NumWarps(numWarps) {
mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion(
[&](triton::PointerType type) {
return convertTritonPointerType(type);
});
}
Type convertTritonPointerType(triton::PointerType type) {
return LLVM::LLVMPointerType::get(type.getPointeeType(),
type.getAddressSpace());
}
LogicalResult
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
return failure();
auto ctx = funcOp->getContext();
auto i32 = IntegerType::get(ctx, 32);
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntidx",
rewriter.getIntegerAttr(i32, 32 * NumWarps));
rewriter.eraseOp(funcOp);
return success();
}
private:
int NumWarps{-1};
};
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adapter,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel function always return nothing.
// TODO(Superjomn) add support for non-inline device function
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "Only kernel function with nothing returned is supported.");
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};
} // namespace
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps) {
patterns.add<::FuncOpConversion>(typeConverter, numWarps);
patterns.add<::ReturnOpConversion>(typeConverter);
}
class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
public:
ConvertTritonGPUToLLVM() = default;
// For manually overwrite the numWarps option
explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; }
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
LLVMTypeConverter typeConverter(context);
TritonLLVMConversionTarget target(*context, typeConverter);
RewritePatternSet patterns(context);
// Add arith's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
}
};
namespace mlir {
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx), typeConverter(typeConverter) {
addLegalDialect<LLVM::LLVMDialect>();
}
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
return std::make_unique<::ConvertTritonGPUToLLVM>();
}
} // namespace triton
} // namespace mlir

View File

@@ -11,6 +11,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
@@ -1634,8 +1635,12 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCombineOpsPass());
})
.def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUVerifier());
.def("add_triton_gpu_verifier_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUVerifier());
})
.def("triton_gpu_to_llvm", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
});
}

View File

@@ -0,0 +1,9 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// CHECK: attributes {nvvm.maxntidx = 96 : i32}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}