This PR adds basic math ops using `MathDialect`, and `libdevice` ops using `extern_elementwise`. This is needed to compile some of the tutorial code (e.g., `softmax`).

This PR implements only the interface, not the full path to PTX (i.e., from the frontend to TritonGPU-MLIR):
- It currently goes only as far as TritonGPU; the new ops cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied.
  - LLVM 14.x does not define folders for many of the math ops, but 15.x appears to broaden that coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - There is no constant folding for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)

if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```

Running the script above (with at least one command-line argument) prints the following TritonGPU IR:

```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
141 lines
4.2 KiB
C++
141 lines
4.2 KiB
C++
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
|
#include "mlir/Target/LLVMIR/Export.h"
|
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
|
#include "triton/driver/llvm.h"
|
|
#include "llvm/IR/LLVMContext.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/InitLLVM.h"
|
|
#include "llvm/Support/SourceMgr.h"
|
|
#include "llvm/Support/ToolOutputFile.h"
|
|
#include <iostream>
|
|
|
|
namespace mlir {
|
|
namespace triton {
|
|
|
|
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
|
MLIRContext &context) {
|
|
std::string errorMessage;
|
|
auto input = openInputFile(inputFilename, &errorMessage);
|
|
if (!input) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return nullptr;
|
|
}
|
|
|
|
mlir::DialectRegistry registry;
|
|
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
|
|
mlir::math::MathDialect, arith::ArithmeticDialect,
|
|
StandardOpsDialect, scf::SCFDialect>();
|
|
|
|
context.appendDialectRegistry(registry);
|
|
|
|
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
|
|
-> OwningOpRef<ModuleOp> {
|
|
llvm::SourceMgr sourceMgr;
|
|
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
|
|
|
|
context.loadAllAvailableDialects();
|
|
context.allowUnregisteredDialects();
|
|
|
|
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
|
|
if (!module) {
|
|
llvm::errs() << "Parse MLIR file failed.";
|
|
return nullptr;
|
|
}
|
|
|
|
return module;
|
|
};
|
|
|
|
auto module = processBuffer(std::move(input));
|
|
if (!module) {
|
|
return nullptr;
|
|
}
|
|
|
|
mlir::PassManager pm(module->getContext());
|
|
applyPassManagerCLOptions(pm);
|
|
|
|
pm.addPass(createConvertTritonGPUToLLVMPass());
|
|
|
|
if (failed(pm.run(module->getOperation()))) {
|
|
llvm::errs() << "Pass execution failed";
|
|
return nullptr;
|
|
}
|
|
|
|
return module;
|
|
}
|
|
|
|
LogicalResult tritonTranslateMain(int argc, char **argv,
|
|
llvm::StringRef toolName) {
|
|
static llvm::cl::opt<std::string> inputFilename(
|
|
llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
|
llvm::cl::init("-"));
|
|
|
|
static llvm::cl::opt<std::string> outputFilename(
|
|
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
|
llvm::cl::init("-"));
|
|
|
|
static llvm::cl::opt<std::string> targetKind(
|
|
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
|
|
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
|
|
|
|
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
|
|
llvm::cl::init(80));
|
|
|
|
static llvm::cl::opt<int> ptxVersion(
|
|
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
|
|
|
|
llvm::InitLLVM y(argc, argv);
|
|
|
|
registerAsmPrinterCLOptions();
|
|
registerMLIRContextCLOptions();
|
|
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
|
|
|
|
mlir::MLIRContext context;
|
|
auto module = loadMLIRModule(inputFilename, context);
|
|
if (!module) {
|
|
return failure();
|
|
}
|
|
|
|
std::string errorMessage;
|
|
auto output = openOutputFile(outputFilename, &errorMessage);
|
|
if (!output) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return failure();
|
|
}
|
|
|
|
llvm::LLVMContext llvmContext;
|
|
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
|
|
if (!llvmir) {
|
|
llvm::errs() << "Translate to LLVM IR failed";
|
|
}
|
|
|
|
if (targetKind == "llvmir")
|
|
llvm::outs() << *llvmir << '\n';
|
|
else if (targetKind == "ptx")
|
|
llvm::outs() << ::triton::driver::llir_to_ptx(
|
|
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
|
|
|
|
return success();
|
|
}
|
|
|
|
} // namespace triton
|
|
} // namespace mlir
|
|
|
|
int main(int argc, char **argv) {
|
|
return failed(mlir::triton::tritonTranslateMain(
|
|
argc, argv, "Triton Translate Testing Tool."));
|
|
}
|