Init TritonGPU to LLVM dialect conversion (#32)
* add toLLVM pass * update num-warps setting in mlir
This commit is contained in:
@@ -22,6 +22,7 @@ int main(int argc, char **argv) {
|
|||||||
mlir::registerTritonGPUPasses();
|
mlir::registerTritonGPUPasses();
|
||||||
mlir::test::registerTestAlignmentPass();
|
mlir::test::registerTestAlignmentPass();
|
||||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||||
|
mlir::triton::registerConvertTritonGPUToLLVMPass();
|
||||||
|
|
||||||
// TODO: register Triton & TritonGPU passes
|
// TODO: register Triton & TritonGPU passes
|
||||||
mlir::DialectRegistry registry;
|
mlir::DialectRegistry registry;
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
#ifndef TRITON_CONVERSION_PASSES_H
|
#ifndef TRITON_CONVERSION_PASSES_H
|
||||||
#define TRITON_CONVERSION_PASSES_H
|
#define TRITON_CONVERSION_PASSES_H
|
||||||
|
|
||||||
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@@ -24,4 +24,25 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Pass definition for lowering a TritonGPU module to the LLVM dialect.
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
    let summary = "Convert TritonGPU to LLVM";
    let description = [{
        Lowers the functions of a module to the LLVM dialect: `func` and
        `return` become `llvm.func` and `llvm.return`, Triton pointer types
        become LLVM pointers in the same address space, and scalar arithmetic
        is lowered with the upstream arith-to-LLVM patterns. Every converted
        function is annotated with `nvvm.maxntidx = 32 * num-warps`.
    }];
    let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";

    // NOTE(review): the pass creates LLVM dialect operations, but
    // "mlir::LLVM::LLVMDialect" is not listed here. Adding it needs the LLVM
    // dialect header to be visible wherever the generated pass base classes
    // are included -- confirm against PassDetail.h before changing the list.
    let dependentDialects = ["mlir::arith::ArithmeticDialect",
                             "mlir::StandardOpsDialect",
                             "mlir::scf::SCFDialect",
                             "mlir::triton::TritonDialect",
                             "mlir::triton::gpu::TritonGPUDialect"];

    let options = [
        // Used to compute the `nvvm.maxntidx` attribute of converted functions.
        Option<"numWarps", "num-warps",
               "int32_t", /*default*/"4",
               "number of warps">
    ];
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
29
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Normal file
29
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||||
|
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||||
|
|
||||||
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class ModuleOp;
|
||||||
|
template <typename T> class OperationPass;
|
||||||
|
|
||||||
|
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||||
|
mlir::LLVMTypeConverter &typeConverter;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
|
||||||
|
mlir::LLVMTypeConverter &typeConverter);
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif
|
@@ -1,2 +1,2 @@
|
|||||||
# add_subdirectory(TritonGPUToLLVM)
|
|
||||||
add_subdirectory(TritonToTritonGPU)
|
add_subdirectory(TritonToTritonGPU)
|
||||||
|
add_subdirectory(TritonGPUToLLVM)
|
||||||
|
19
lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Normal file
19
lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Conversion library that implements the TritonGPU -> LLVM dialect lowering.
add_mlir_conversion_library(TritonGPUToLLVM
    TritonGPUToLLVM.cpp

    ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM

    # The generated pass declarations must exist before this target is built.
    DEPENDS TritonConversionPassIncGen

    LINK_COMPONENTS Core

    LINK_LIBS PUBLIC MLIRIR MLIRPass TritonIR TritonGPUIR TritonGPUTransforms
)
|
240
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Normal file
240
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
||||||
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::triton;
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace LLVM {
|
||||||
|
|
||||||
|
static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
|
||||||
|
|
||||||
|
} // namespace LLVM
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// The following code is borrowed from the MLIR project, including the following
|
||||||
|
// functions or classes:
|
||||||
|
// - filterFuncAttributes
|
||||||
|
// - ConvertOpToLLVMPattern
|
||||||
|
// - FuncOpConversion
|
||||||
|
//
|
||||||
|
// The code is hidden in the CPP files in the MLIR repo, and we can't call it
|
||||||
|
// directly. I found such code snippets are refactored and added to LLVMCommon
|
||||||
|
// in the latest MLIR code, but the v14.0.0 version currently used in Triton
|
||||||
|
// doesn't contain the code.
|
||||||
|
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
|
||||||
|
//
|
||||||
|
// The original code:
|
||||||
|
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
|
||||||
|
// All the rights are reserved by LLVM community.
|
||||||
|
|
||||||
|
/// Only retain those attributes that are not constructed by
|
||||||
|
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||||
|
/// attributes.
|
||||||
|
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
||||||
|
bool filterArgAttrs,
|
||||||
|
SmallVectorImpl<NamedAttribute> &result) {
|
||||||
|
for (const auto &attr : attrs) {
|
||||||
|
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
|
||||||
|
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
|
||||||
|
attr.getName() == "std.varargs" ||
|
||||||
|
(filterArgAttrs &&
|
||||||
|
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
|
||||||
|
continue;
|
||||||
|
result.push_back(attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shared implementation for patterns that lower a `FuncOp` to an
/// `LLVM::LLVMFuncOp`.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;

  /// Converts `funcOp` to an `LLVM::LLVMFuncOp` using the LLVMTypeConverter
  /// provided to this legalization pattern.
  ///
  /// The body of `funcOp` is moved into the new function and its block
  /// argument types are converted. `funcOp` itself is not erased; that is
  /// left to the caller.
  ///
  /// Returns a null op when the signature cannot be converted, when
  /// "llvm.linkage" is present but is not an `LLVM::LinkageAttr` (an error is
  /// emitted in that case), or when converting the region types fails.
  LLVM::LLVMFuncOp
  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
                            ConversionPatternRewriter &rewriter) const {
    // Convert the original function arguments. They are converted using the
    // LLVMTypeConverter provided to this legalization pattern.
    auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    // An unconvertible signature is reported as a failed conversion. There is
    // deliberately no assert here: it would abort assert-enabled builds before
    // this graceful path could run.
    if (!llvmType)
      return nullptr;

    // Propagate argument attributes to all converted arguments obtained after
    // converting a given original argument.
    SmallVector<NamedAttribute, 4> attributes;
    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                         attributes);
    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
      SmallVector<Attribute, 4> newArgAttrs(
          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
        auto mapping = result.getInputMapping(i);
        assert(mapping.hasValue() &&
               "unexpected deletion of function argument");
        // One original argument may map to several converted ones; each of
        // them receives the original argument's attribute dictionary.
        for (size_t j = 0; j < mapping->size; ++j)
          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
      }
      attributes.push_back(
          rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
                                rewriter.getArrayAttr(newArgAttrs)));
    }

    // The linkage is passed to the builder explicitly below, so drop it from
    // the generic attribute list. Stop right after erasing: the erase
    // invalidates the iteration.
    for (const auto &pair : llvm::enumerate(attributes)) {
      if (pair.value().getName() == "llvm.linkage") {
        attributes.erase(attributes.begin() + pair.index());
        break;
      }
    }

    // Create an LLVM function, use external linkage by default until MLIR
    // functions have linkage.
    LLVM::Linkage linkage = LLVM::Linkage::External;
    if (funcOp->hasAttr("llvm.linkage")) {
      auto attr =
          funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
      if (!attr) {
        funcOp->emitError()
            << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
        return nullptr;
      }
      linkage = attr.getLinkage();
    }

    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
                                           &result)))
      return nullptr;

    return newFuncOp;
  }
};
|
||||||
|
|
||||||
|
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
||||||
|
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
||||||
|
/// information.
|
||||||
|
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
||||||
|
struct FuncOpConversion : public FuncOpConversionBase {
|
||||||
|
FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
|
||||||
|
: FuncOpConversionBase(converter), NumWarps(numWarps) {
|
||||||
|
mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion(
|
||||||
|
[&](triton::PointerType type) {
|
||||||
|
return convertTritonPointerType(type);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Type convertTritonPointerType(triton::PointerType type) {
|
||||||
|
return LLVM::LLVMPointerType::get(type.getPointeeType(),
|
||||||
|
type.getAddressSpace());
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
||||||
|
if (!newFuncOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto ctx = funcOp->getContext();
|
||||||
|
auto i32 = IntegerType::get(ctx, 32);
|
||||||
|
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
|
||||||
|
// for `nvvm.annotation` metadata.
|
||||||
|
newFuncOp->setAttr("nvvm.maxntidx",
|
||||||
|
rewriter.getIntegerAttr(i32, 32 * NumWarps));
|
||||||
|
|
||||||
|
rewriter.eraseOp(funcOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int NumWarps{-1};
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Lowers an operand-less `ReturnOp` to `LLVM::ReturnOp`, forwarding the
/// attributes of the original op.
struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
  using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently, Triton kernel functions always return nothing.
    // TODO(Superjomn) add support for non-inline device function
    const bool returnsValues = op.getNumOperands() != 0;
    if (returnsValues)
      return rewriter.notifyMatchFailure(
          op, "Only kernel function with nothing returned is supported.");

    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                op->getAttrs());
    return success();
  }
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Adds the function and return lowering patterns of this file to `patterns`.
/// `numWarps` is forwarded to the function pattern; the return pattern only
/// needs the type converter.
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps) {
  // Function lowering; also tags functions with `nvvm.maxntidx`.
  patterns.add<::FuncOpConversion>(typeConverter, numWarps);

  // Return lowering.
  patterns.add<::ReturnOpConversion>(typeConverter);
}
|
||||||
|
|
||||||
|
class ConvertTritonGPUToLLVM
|
||||||
|
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
|
||||||
|
public:
|
||||||
|
ConvertTritonGPUToLLVM() = default;
|
||||||
|
// For manually overwrite the numWarps option
|
||||||
|
explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; }
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ModuleOp mod = getOperation();
|
||||||
|
|
||||||
|
LLVMTypeConverter typeConverter(context);
|
||||||
|
TritonLLVMConversionTarget target(*context, typeConverter);
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
// Add arith's patterns to help convert scalar expression to LLVM.
|
||||||
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
|
patterns);
|
||||||
|
|
||||||
|
populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
||||||
|
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
|
||||||
|
: ConversionTarget(ctx), typeConverter(typeConverter) {
|
||||||
|
addLegalDialect<LLVM::LLVMDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace triton {
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass() {
|
||||||
|
return std::make_unique<::ConvertTritonGPUToLLVM>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace triton
|
||||||
|
} // namespace mlir
|
@@ -11,6 +11,7 @@
|
|||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/Triton/IR/Types.h"
|
#include "triton/Dialect/Triton/IR/Types.h"
|
||||||
@@ -1634,8 +1635,12 @@ void init_triton_ir(py::module &&m) {
|
|||||||
[](mlir::PassManager &self) {
|
[](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) {
|
.def("add_triton_gpu_verifier_pass",
|
||||||
self.addPass(mlir::createTritonGPUVerifier());
|
[](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::createTritonGPUVerifier());
|
||||||
|
})
|
||||||
|
.def("triton_gpu_to_llvm", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
9
test/Conversion/tritongpu_to_llvm.mlir
Normal file
9
test/Conversion/tritongpu_to_llvm.mlir
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8 | FileCheck %s

// The pass sets nvvm.maxntidx to 32 * num-warps, i.e. 32 * 8 = 256 here.
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
// CHECK: attributes {nvvm.maxntidx = 256 : i32}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {

  // CHECK: llvm.return
  return
}
|
Reference in New Issue
Block a user