[Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only interface till PTX (so from frontend to TritonGPU-MLIR) - Currently till TritonGPU. It cannot be lowered to PTX now. - No special optimizations (e.g., constant folding etc) are applied. - 14.x does not define folders for many operators for math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td - No constant folding etc for `libdevice` ops. ```py import triton import triton.language as tl import sys @triton.jit def add_kernel( x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr, ): offsets = tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offsets) x = tl.sin(x) output = tl.libdevice.sin(x) output = tl.libdevice.fdiv_rn(output, output) output = tl.libdevice.fmaf_rd(output, output, output) tl.store(y_ptr + offsets, output) if __name__ == "__main__" and len(sys.argv) >= 2: signature = "*fp32,*fp32" constants = {'BLOCK_SIZE': 1024} output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir") print(output) ``` -> ```llvm #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) { %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked> %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked> %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> %4 = math.sin %3 : tensor<1024xf32, #blocked> %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked> %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked> %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked> %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked> %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked> tt.store %9, %7 : tensor<1024xf32, #blocked> return } } ```
This commit is contained in:
@@ -1542,7 +1542,16 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
|
||||
ptr, val, mask);
|
||||
})
|
||||
|
||||
// External
|
||||
.def("create_external_elementwise",
|
||||
[](mlir::OpBuilder &self, const std::string &libName,
|
||||
const std::string &libPath, const std::string &symbol,
|
||||
std::vector<mlir::Value> &argList,
|
||||
mlir::Type retType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::ExtElemwiseOp>(
|
||||
loc, retType, argList, libName, libPath, symbol);
|
||||
})
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id",
|
||||
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
||||
@@ -1563,12 +1572,32 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
|
||||
allowTF32);
|
||||
})
|
||||
// .def("create_exp", &ir::builder::create_exp, ret::reference)
|
||||
// .def("create_cos", &ir::builder::create_cos, ret::reference)
|
||||
// .def("create_sin", &ir::builder::create_sin, ret::reference)
|
||||
// .def("create_log", &ir::builder::create_log, ret::reference)
|
||||
.def("create_exp",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::ExpOp>(loc, val);
|
||||
})
|
||||
.def("create_cos",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::CosOp>(loc, val);
|
||||
})
|
||||
.def("create_sin",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::SinOp>(loc, val);
|
||||
})
|
||||
.def("create_log",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::LogOp>(loc, val);
|
||||
})
|
||||
.def("create_sqrt",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::math::SqrtOp>(loc, val);
|
||||
})
|
||||
// .def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
.def("create_reduce",
|
||||
[](mlir::OpBuilder &self, mlir::Value &operand,
|
||||
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
|
||||
|
Reference in New Issue
Block a user