Init TritonGPU to LLVM dialect conversion (#32)
* add toLLVM pass * update num-warps setting in mlir
This commit is contained in:
@@ -22,6 +22,7 @@ int main(int argc, char **argv) {
|
||||
mlir::registerTritonGPUPasses();
|
||||
mlir::test::registerTestAlignmentPass();
|
||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||
mlir::triton::registerConvertTritonGPUToLLVMPass();
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
mlir::DialectRegistry registry;
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES_H
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@@ -24,4 +24,25 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Convert TritonGPU to LLVM";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::StandardOpsDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
29
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Normal file
29
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Normal file
@@ -0,0 +1,29 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
mlir::LLVMTypeConverter &typeConverter;
|
||||
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
|
||||
mlir::LLVMTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
@@ -1,2 +1,2 @@
|
||||
# add_subdirectory(TritonGPUToLLVM)
|
||||
add_subdirectory(TritonToTritonGPU)
|
||||
add_subdirectory(TritonGPUToLLVM)
|
||||
|
19
lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Normal file
19
lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
TritonGPUToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
|
||||
|
||||
DEPENDS
|
||||
TritonConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
TritonGPUTransforms
|
||||
)
|
240
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Normal file
240
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Normal file
@@ -0,0 +1,240 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
|
||||
static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
|
||||
// The following code are borrowed from mlir project including the following
|
||||
// functions or classes:
|
||||
// - filterFuncAttributes
|
||||
// - ConvertOpToLLVMPattern
|
||||
// - FuncOpConversion
|
||||
//
|
||||
// The code are hidden in the CPP files in MLIR repo, and we can't call them
|
||||
// directly. I found such code snippets are refactored and added to LLVMCommon
|
||||
// in the latest MLIR code, but the v14.0.0 version currentlly used in Triton
|
||||
// doesn't contain the code.
|
||||
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
|
||||
//
|
||||
// The original code:
|
||||
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
|
||||
// All the rights are reserved by LLVM community.
|
||||
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||
bool filterArgAttrs,
|
||||
SmallVectorImpl<NamedAttribute> &result) {
|
||||
for (const auto &attr : attrs) {
|
||||
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||
attr.getName() == "std.varargs" ||
|
||||
(filterArgAttrs &&
|
||||
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
|
||||
continue;
|
||||
result.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
||||
protected:
|
||||
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
||||
// to this legalization pattern.
|
||||
LLVM::LLVMFuncOp
|
||||
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Convert the original function arguments. They are converted using the
|
||||
// LLVMTypeConverter provided to this legalization pattern.
|
||||
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
|
||||
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
||||
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
||||
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
||||
assert(llvmType);
|
||||
if (!llvmType)
|
||||
return nullptr;
|
||||
|
||||
// Propagate argument attributes to all converted arguments obtained after
|
||||
// converting a given original argument.
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
|
||||
attributes);
|
||||
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
|
||||
SmallVector<Attribute, 4> newArgAttrs(
|
||||
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
|
||||
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
|
||||
auto mapping = result.getInputMapping(i);
|
||||
assert(mapping.hasValue() &&
|
||||
"unexpected deletion of function argument");
|
||||
for (size_t j = 0; j < mapping->size; ++j)
|
||||
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
|
||||
}
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
|
||||
rewriter.getArrayAttr(newArgAttrs)));
|
||||
}
|
||||
for (const auto &pair : llvm::enumerate(attributes)) {
|
||||
if (pair.value().getName() == "llvm.linkage") {
|
||||
attributes.erase(attributes.begin() + pair.index());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Create an LLVM function, use external linkage by default until MLIR
|
||||
// functions have linkage.
|
||||
LLVM::Linkage linkage = LLVM::Linkage::External;
|
||||
if (funcOp->hasAttr("llvm.linkage")) {
|
||||
auto attr =
|
||||
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
|
||||
if (!attr) {
|
||||
funcOp->emitError()
|
||||
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
|
||||
return nullptr;
|
||||
}
|
||||
linkage = attr.getLinkage();
|
||||
}
|
||||
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
|
||||
/*dsoLocal*/ false, attributes);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
|
||||
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
||||
&result)))
|
||||
return nullptr;
|
||||
|
||||
return newFuncOp;
|
||||
}
|
||||
};
|
||||
|
||||
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
||||
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
||||
/// information.
|
||||
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
||||
struct FuncOpConversion : public FuncOpConversionBase {
|
||||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
|
||||
: FuncOpConversionBase(converter), NumWarps(numWarps) {
|
||||
mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion(
|
||||
[&](triton::PointerType type) {
|
||||
return convertTritonPointerType(type);
|
||||
});
|
||||
}
|
||||
|
||||
Type convertTritonPointerType(triton::PointerType type) {
|
||||
return LLVM::LLVMPointerType::get(type.getPointeeType(),
|
||||
type.getAddressSpace());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
||||
if (!newFuncOp)
|
||||
return failure();
|
||||
|
||||
auto ctx = funcOp->getContext();
|
||||
auto i32 = IntegerType::get(ctx, 32);
|
||||
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
|
||||
// for `nvvm.annotation` metadata.
|
||||
newFuncOp->setAttr("nvvm.maxntidx",
|
||||
rewriter.getIntegerAttr(i32, 32 * NumWarps));
|
||||
|
||||
rewriter.eraseOp(funcOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
int NumWarps{-1};
|
||||
};
|
||||
|
||||
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, OpAdaptor adapter,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
unsigned numArguments = op.getNumOperands();
|
||||
|
||||
// Currently, Triton kernel function always return nothing.
|
||||
// TODO(Superjomn) add support for non-inline device function
|
||||
if (numArguments > 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only kernel function with nothing returned is supported.");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
|
||||
op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps) {
|
||||
patterns.add<::FuncOpConversion>(typeConverter, numWarps);
|
||||
patterns.add<::ReturnOpConversion>(typeConverter);
|
||||
}
|
||||
|
||||
class ConvertTritonGPUToLLVM
|
||||
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||
public:
|
||||
ConvertTritonGPUToLLVM() = default;
|
||||
// For manually overwrite the numWarps option
|
||||
explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; }
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp mod = getOperation();
|
||||
|
||||
LLVMTypeConverter typeConverter(context);
|
||||
TritonLLVMConversionTarget target(*context, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
// Add arith's patterns to help convert scalar expression to LLVM.
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
|
||||
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
||||
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
namespace mlir {
|
||||
|
||||
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||
: ConversionTarget(ctx), typeConverter(typeConverter) {
|
||||
addLegalDialect<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
|
||||
return std::make_unique<::ConvertTritonGPUToLLVM>();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
@@ -11,6 +11,7 @@
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
@@ -1634,8 +1635,12 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
||||
})
|
||||
.def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUVerifier());
|
||||
.def("add_triton_gpu_verifier_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonGPUVerifier());
|
||||
})
|
||||
.def("triton_gpu_to_llvm", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||
});
|
||||
}
|
||||
|
||||
|
9
test/Conversion/tritongpu_to_llvm.mlir
Normal file
9
test/Conversion/tritongpu_to_llvm.mlir
Normal file
@@ -0,0 +1,9 @@
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8
|
||||
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||
// CHECK: attributes {nvvm.maxntidx = 96 : i32}
|
||||
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
|
||||
// CHECK: llvm.return
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user