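// triton-translate: a standalone tool that parses a Triton / TritonGPU MLIR
// module and translates it to LLVM IR (the default) or to PTX.
//
// Example invocation (a sketch: the binary name and file names are
// placeholders; the flags correspond to the llvm::cl options declared in
// tritonTranslateMain below):
//
//   triton-translate kernel.ttgir --target=ptx --sm=80 -o kernel.ptx
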
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
|
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
|
|
|
#include "mlir/IR/AsmState.h"
|
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
|
|
#include "mlir/IR/Dialect.h"
|
|
|
|
#include "mlir/Parser.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Pass/PassManager.h"
|
|
|
|
#include "mlir/Support/FileUtilities.h"
|
|
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
|
|
|
#include "mlir/Target/LLVMIR/Export.h"
|
|
|
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
|
|
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
|
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
|
|
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
|
|
|
#include "triton/driver/llvm.h"
|
|
|
|
#include "llvm/IR/LLVMContext.h"
|
|
|
|
#include "llvm/Support/CommandLine.h"
|
|
|
|
#include "llvm/Support/InitLLVM.h"
|
|
|
|
#include "llvm/Support/SourceMgr.h"
|
|
|
|
#include "llvm/Support/ToolOutputFile.h"
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
namespace mlir {
namespace triton {

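// Opens `inputFilename`, registers the dialects that may appear in Triton
// IR, and parses the buffer into an MLIR module. Returns a null module on
// failure after printing a diagnostic to stderr.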
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
                                     MLIRContext &context) {
  std::string errorMessage;
  auto input = openInputFile(inputFilename, &errorMessage);
  if (!input) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  mlir::DialectRegistry registry;
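  // Register every dialect that may occur in the textual input: the Triton
  // and TritonGPU dialects plus the upstream dialects Triton IR relies on
  // (math, arith, std, scf); MathDialect in particular covers ops such as
  // math.sin emitted for tl.sin.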
  registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
                  mlir::math::MathDialect, arith::ArithmeticDialect,
                  StandardOpsDialect, scf::SCFDialect>();
  context.appendDialectRegistry(registry);

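  // Parse the buffer into a module. All registered dialects are loaded
  // eagerly, and unregistered dialects are tolerated so that inputs using
  // ops outside the registered set can still be parsed.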
  auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
      -> OwningOpRef<ModuleOp> {
    llvm::SourceMgr sourceMgr;
    sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

    context.loadAllAvailableDialects();
    context.allowUnregisteredDialects();

    OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
    if (!module) {
      llvm::errs() << "Failed to parse MLIR file.\n";
      return nullptr;
    }

    return module;
  };

  auto module = processBuffer(std::move(input));
  if (!module) {
    return nullptr;
  }

  return module;
}

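// Command-line entry point: parses options, loads the input module,
// translates it to LLVM IR, and writes either the LLVM IR or the PTX
// generated from it to the requested output.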
LogicalResult tritonTranslateMain(int argc, char **argv,
                                  llvm::StringRef toolName) {
  static llvm::cl::opt<std::string> inputFilename(
      llvm::cl::Positional, llvm::cl::desc("<input file>"),
      llvm::cl::init("-"));

  static llvm::cl::opt<std::string> outputFilename(
      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
      llvm::cl::init("-"));

  static llvm::cl::opt<std::string> targetKind(
      "target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
      llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));

  static llvm::cl::opt<int> SMArch(
      "sm",
      llvm::cl::desc("SM architecture (compute capability), e.g. 80 for sm_80"),
      llvm::cl::init(80));

  static llvm::cl::opt<int> ptxVersion(
      "ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));

  llvm::InitLLVM y(argc, argv);

  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, toolName);

  mlir::MLIRContext context;
  auto module = loadMLIRModule(inputFilename, context);
  if (!module) {
    return failure();
  }

  std::string errorMessage;
  auto output = openOutputFile(outputFilename, &errorMessage);
  if (!output) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  llvm::LLVMContext llvmContext;
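
  // Translate the TritonGPU-level module to LLVM IR for the requested SM
  // architecture.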
  auto llvmir =
      translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
  if (!llvmir) {
    llvm::errs() << "Failed to translate TritonGPU to LLVM IR.\n";
    return failure();
  }

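  // Emit the requested artifact: textual LLVM IR by default, or PTX produced
  // from the LLVM module via the Triton driver.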
  if (targetKind == "llvmir") {
    output->os() << *llvmir << '\n';
  } else if (targetKind == "ptx") {
    output->os() << ::triton::driver::llir_to_ptx(
        llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
  } else {
    llvm::errs() << "Unknown translation target: " << targetKind.getValue()
                 << "\n";
    return failure();
  }

  output->keep();
  return success();
}

} // namespace triton
} // namespace mlir

int main(int argc, char **argv) {
  return failed(mlir::triton::tritonTranslateMain(
      argc, argv, "Triton Translate Testing Tool."));
}