Files
triton/bin/triton-translate.cpp
Shintaro Iwasaki 3c635449e5 [Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only the path from the frontend down to TritonGPU MLIR; lowering to PTX is not included:
- Lowering currently stops at TritonGPU; the result cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied.
  - LLVM 14.x does not define folders for many math ops, but 15.x seems to increase coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - No constant folding is implemented for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```
->
```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
2022-09-01 16:34:27 -07:00

141 lines
4.2 KiB
C++

#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/llvm.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <iostream>
namespace mlir {
namespace triton {
/// Reads the MLIR file `inputFilename`, parses it into a module owned by
/// `context`, and runs the TritonGPU-to-LLVM conversion pass over it.
///
/// The Triton, TritonGPU, Math, Arithmetic, Standard and SCF dialects are
/// appended to `context` before parsing, and unregistered dialects are
/// allowed.
///
/// Returns the converted module, or a null reference (after printing a
/// message to stderr) if the file cannot be opened, parsing fails, or the
/// pass pipeline fails.
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
                                     MLIRContext &context) {
  // Step 1: open the input file.
  std::string openError;
  std::unique_ptr<llvm::MemoryBuffer> fileBuffer =
      openInputFile(inputFilename, &openError);
  if (!fileBuffer) {
    llvm::errs() << openError << "\n";
    return nullptr;
  }

  // Step 2: make every dialect the input may use known to the context.
  mlir::DialectRegistry dialects;
  dialects.insert<TritonDialect, triton::gpu::TritonGPUDialect,
                  mlir::math::MathDialect, arith::ArithmeticDialect,
                  StandardOpsDialect, scf::SCFDialect>();
  context.appendDialectRegistry(dialects);

  // Step 3: parse. The source manager takes ownership of the buffer and is
  // confined to this scope, so it is released as soon as parsing is done.
  OwningOpRef<ModuleOp> parsed;
  {
    llvm::SourceMgr sources;
    sources.AddNewSourceBuffer(std::move(fileBuffer), SMLoc());
    context.loadAllAvailableDialects();
    context.allowUnregisteredDialects();
    parsed = parseSourceFile(sources, &context);
  }
  if (!parsed) {
    llvm::errs() << "Parse MLIR file failed.";
    return nullptr;
  }

  // Step 4: lower TritonGPU to the LLVM dialect.
  mlir::PassManager pipeline(parsed->getContext());
  applyPassManagerCLOptions(pipeline);
  pipeline.addPass(createConvertTritonGPUToLLVMPass());
  if (failed(pipeline.run(parsed->getOperation()))) {
    llvm::errs() << "Pass execution failed";
    return nullptr;
  }

  return parsed;
}
/// Driver for the triton-translate tool.
///
/// Parses the command line (`<input file>`, `-o`, `-target`, `-sm`,
/// `-ptx-version`), loads and lowers the input module via `loadMLIRModule`,
/// translates it to LLVM IR, and writes either the LLVM IR (`-target=llvmir`,
/// the default) or the PTX produced by `llir_to_ptx` (`-target=ptx`) to the
/// file selected by `-o` (default `-`).
///
/// \param argc, argv  Command-line arguments, forwarded to LLVM's parser.
/// \param toolName    Overview string passed to `ParseCommandLineOptions`.
/// \returns `success()` when the requested output was written; `failure()` if
///          the module could not be loaded, the output file could not be
///          opened, translation to LLVM IR failed, or the target is unknown.
LogicalResult tritonTranslateMain(int argc, char **argv,
                                  llvm::StringRef toolName) {
  static llvm::cl::opt<std::string> inputFilename(
      llvm::cl::Positional, llvm::cl::desc("<input file>"),
      llvm::cl::init("-"));

  static llvm::cl::opt<std::string> outputFilename(
      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
      llvm::cl::init("-"));

  static llvm::cl::opt<std::string> targetKind(
      "target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
      llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));

  static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
                                   llvm::cl::init(80));

  static llvm::cl::opt<int> ptxVersion(
      "ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));

  llvm::InitLLVM y(argc, argv);
  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, toolName);

  mlir::MLIRContext context;
  auto module = loadMLIRModule(inputFilename, context);
  if (!module) {
    return failure();
  }

  std::string errorMessage;
  auto output = openOutputFile(outputFilename, &errorMessage);
  if (!output) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  llvm::LLVMContext llvmContext;
  auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
  if (!llvmir) {
    llvm::errs() << "Translate to LLVM IR failed";
    // Bail out here: everything below dereferences `llvmir`.
    return failure();
  }

  // Emit through the opened output file so that `-o` is honored.
  if (targetKind == "llvmir") {
    output->os() << *llvmir << '\n';
  } else if (targetKind == "ptx") {
    output->os() << ::triton::driver::llir_to_ptx(
        llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
  } else {
    // Reject unsupported targets instead of succeeding with no output.
    llvm::errs() << "Unknown translation target: " << targetKind.getValue()
                 << "\n";
    return failure();
  }

  // Keep the output file now that it has been written successfully.
  output->keep();
  return success();
}
} // namespace triton
} // namespace mlir
int main(int argc, char **argv) {
return failed(mlir::triton::tritonTranslateMain(
argc, argv, "Triton Translate Testing Tool."));
}