From b988bae813d9694d2119c9414b7b9b287ad2cb93 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 4 Aug 2022 10:15:45 +0800 Subject: [PATCH] Init TritonGPU to LLVM dialect conversion (#32) * add toLLVM pass * update num-warps setting in mlir --- bin/triton-opt.cpp | 1 + include/triton/Conversion/Passes.h | 1 + include/triton/Conversion/Passes.td | 21 ++ .../TritonGPUToLLVM/TritonGPUToLLVM.h | 29 +++ lib/Conversion/CMakeLists.txt | 2 +- lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 19 ++ .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 240 ++++++++++++++++++ python/src/triton.cc | 9 +- test/Conversion/tritongpu_to_llvm.mlir | 9 + 9 files changed, 328 insertions(+), 3 deletions(-) create mode 100644 include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h create mode 100644 lib/Conversion/TritonGPUToLLVM/CMakeLists.txt create mode 100644 lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp create mode 100644 test/Conversion/tritongpu_to_llvm.mlir diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 4942214cc..d476135d8 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -22,6 +22,7 @@ int main(int argc, char **argv) { mlir::registerTritonGPUPasses(); mlir::test::registerTestAlignmentPass(); mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::registerConvertTritonGPUToLLVMPass(); // TODO: register Triton & TritonGPU passes mlir::DialectRegistry registry; diff --git a/include/triton/Conversion/Passes.h b/include/triton/Conversion/Passes.h index 8cf53bc1c..07aff36a4 100644 --- a/include/triton/Conversion/Passes.h +++ b/include/triton/Conversion/Passes.h @@ -1,6 +1,7 @@ #ifndef TRITON_CONVERSION_PASSES_H #define TRITON_CONVERSION_PASSES_H +#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" namespace mlir { diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index 72c330ba0..c365e0f02 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -24,4 +24,25 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO ]; } + +def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonGPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithmeticDialect", + "mlir::StandardOpsDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps"> + ]; +} + #endif diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h new file mode 100644 index 000000000..adbd2ef52 --- /dev/null +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -0,0 +1,29 @@ +#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_ +#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Transforms/DialectConversion.h" +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +class TritonLLVMConversionTarget : public ConversionTarget { + mlir::LLVMTypeConverter &typeConverter; + +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx, + mlir::LLVMTypeConverter &typeConverter); +}; + +namespace triton { + +std::unique_ptr> createConvertTritonGPUToLLVMPass(); + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index a08349513..143a4375a 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,2 +1,2 @@ -# add_subdirectory(TritonGPUToLLVM) add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..358604975 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(TritonGPUToLLVM + TritonGPUToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM + + DEPENDS + TritonConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TritonIR + TritonGPUIR + TritonGPUTransforms +) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp new file mode 100644 index 000000000..7ccdf7bc0 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -0,0 +1,240 @@ +#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace LLVM { + +static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } + +} // namespace LLVM +} // namespace mlir + +namespace { + +// The following code are borrowed from mlir project including the following +// functions or classes: +// - filterFuncAttributes +// - ConvertOpToLLVMPattern +// - FuncOpConversion +// +// The code are hidden in the CPP files in MLIR repo, and we can't call them +// directly. I found such code snippets are refactored and added to LLVMCommon +// in the latest MLIR code, but the v14.0.0 version currentlly used in Triton +// doesn't contain the code. +// TODO(Superjomn) Remove the code when mlir v15.0 is included. +// +// The original code: +// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219 +// All the rights are reserved by LLVM community. + +/// Only retain those attributes that are not constructed by +/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument +/// attributes. +static void filterFuncAttributes(ArrayRef attrs, + bool filterArgAttrs, + SmallVectorImpl &result) { + for (const auto &attr : attrs) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && + attr.getName() == FunctionOpInterface::getArgDictAttrName())) + continue; + result.push_back(attr); + } +} + +struct FuncOpConversionBase : public ConvertOpToLLVMPattern { +protected: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided + // to this legalization pattern. + LLVM::LLVMFuncOp + convertFuncOpToLLVMFuncOp(FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Convert the original function arguments. They are converted using the + // LLVMTypeConverter provided to this legalization pattern. + auto varargsAttr = funcOp->getAttrOfType("std.varargs"); + TypeConverter::SignatureConversion result(funcOp.getNumArguments()); + auto llvmType = getTypeConverter()->convertFunctionSignature( + funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); + assert(llvmType); + if (!llvmType) + return nullptr; + + // Propagate argument attributes to all converted arguments obtained after + // converting a given original argument. + SmallVector attributes; + filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true, + attributes); + if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { + SmallVector newArgAttrs( + llvmType.cast().getNumParams()); + for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { + auto mapping = result.getInputMapping(i); + assert(mapping.hasValue() && + "unexpected deletion of function argument"); + for (size_t j = 0; j < mapping->size; ++j) + newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; + } + attributes.push_back( + rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), + rewriter.getArrayAttr(newArgAttrs))); + } + for (const auto &pair : llvm::enumerate(attributes)) { + if (pair.value().getName() == "llvm.linkage") { + attributes.erase(attributes.begin() + pair.index()); + break; + } + } + + // Create an LLVM function, use external linkage by default until MLIR + // functions have linkage. + LLVM::Linkage linkage = LLVM::Linkage::External; + if (funcOp->hasAttr("llvm.linkage")) { + auto attr = + funcOp->getAttr("llvm.linkage").dyn_cast(); + if (!attr) { + funcOp->emitError() + << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr"; + return nullptr; + } + linkage = attr.getLinkage(); + } + auto newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), llvmType, linkage, + /*dsoLocal*/ false, attributes); + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, + &result))) + return nullptr; + + return newFuncOp; + } +}; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; +struct FuncOpConversion : public FuncOpConversionBase { + FuncOpConversion(LLVMTypeConverter &converter, int numWarps) + : FuncOpConversionBase(converter), NumWarps(numWarps) { + mlir::ConvertToLLVMPattern::getTypeConverter()->addConversion( + [&](triton::PointerType type) { + return convertTritonPointerType(type); + }); + } + + Type convertTritonPointerType(triton::PointerType type) { + return LLVM::LLVMPointerType::get(type.getPointeeType(), + type.getAddressSpace()); + } + + LogicalResult + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + if (!newFuncOp) + return failure(); + + auto ctx = funcOp->getContext(); + auto i32 = IntegerType::get(ctx, 32); + // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("nvvm.maxntidx", + rewriter.getIntegerAttr(i32, 32 * NumWarps)); + + rewriter.eraseOp(funcOp); + return success(); + } + +private: + int NumWarps{-1}; +}; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + unsigned numArguments = op.getNumOperands(); + + // Currently, Triton kernel function always return nothing. + // TODO(Superjomn) add support for non-inline device function + if (numArguments > 0) { + return rewriter.notifyMatchFailure( + op, "Only kernel function with nothing returned is supported."); + } + + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + return success(); + } +}; + +} // namespace + +void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, int numWarps) { + patterns.add<::FuncOpConversion>(typeConverter, numWarps); + patterns.add<::ReturnOpConversion>(typeConverter); +} + +class ConvertTritonGPUToLLVM + : public ConvertTritonGPUToLLVMBase { +public: + ConvertTritonGPUToLLVM() = default; + // For manually overwrite the numWarps option + explicit ConvertTritonGPUToLLVM(int numWarps) { this->numWarps = numWarps; } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + LLVMTypeConverter typeConverter(context); + TritonLLVMConversionTarget target(*context, typeConverter); + + RewritePatternSet patterns(context); + // Add arith's patterns to help convert scalar expression to LLVM. + mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, + patterns); + + populateTritonToLLVMPatterns(typeConverter, patterns, numWarps); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +namespace mlir { + +TritonLLVMConversionTarget::TritonLLVMConversionTarget( + MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) + : ConversionTarget(ctx), typeConverter(typeConverter) { + addLegalDialect(); +} + +namespace triton { + +std::unique_ptr> createConvertTritonGPUToLLVMPass() { + return std::make_unique<::ConvertTritonGPUToLLVM>(); +} + +} // namespace triton +} // namespace mlir diff --git a/python/src/triton.cc b/python/src/triton.cc index 9a7dbbd29..91c5ef739 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -11,6 +11,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -1634,8 +1635,12 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCombineOpsPass()); }) - .def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) { - self.addPass(mlir::createTritonGPUVerifier()); + .def("add_triton_gpu_verifier_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPUVerifier()); + }) + .def("triton_gpu_to_llvm", [](mlir::PassManager &self) { + self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); }); } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir new file mode 100644 index 000000000..33cea027e --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -0,0 +1,9 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm=num-warps=8 + +// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr) +// CHECK: attributes {nvvm.maxntidx = 96 : i32} +func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + + // CHECK: llvm.return + return +}