[Triton] Support math and libdevice ops (#91)

This PR adds basic math ops using `MathDialect` and `libdevice` ops using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). It implements only the interface from the frontend down to TritonGPU-MLIR:
- Lowering currently stops at TritonGPU; it cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied.
  - MLIR 14.x does not define folders for many math ops, but 15.x appears to increase coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - There is no constant folding, etc., for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```
This compiles to the following TritonGPU IR (TTGIR):
```mlir
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
Commit 3c635449e5 (parent 328b87aec6), authored by Shintaro Iwasaki on 2022-09-01 16:34:27 -07:00 and committed by GitHub. 18 changed files with 1938 additions and 51 deletions.


@@ -1542,7 +1542,16 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
ptr, val, mask);
})
// External
.def("create_external_elementwise",
[](mlir::OpBuilder &self, const std::string &libName,
const std::string &libPath, const std::string &symbol,
std::vector<mlir::Value> &argList,
mlir::Type retType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::ExtElemwiseOp>(
loc, retType, argList, libName, libPath, symbol);
})
// Built-in instruction
.def("create_get_program_id",
[](mlir::OpBuilder &self, int axis) -> mlir::Value {
@@ -1563,12 +1572,32 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
allowTF32);
})
// .def("create_exp", &ir::builder::create_exp, ret::reference)
// .def("create_cos", &ir::builder::create_cos, ret::reference)
// .def("create_sin", &ir::builder::create_sin, ret::reference)
// .def("create_log", &ir::builder::create_log, ret::reference)
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::ExpOp>(loc, val);
})
.def("create_cos",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::CosOp>(loc, val);
})
.def("create_sin",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::SinOp>(loc, val);
})
.def("create_log",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::LogOp>(loc, val);
})
.def("create_sqrt",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::math::SqrtOp>(loc, val);
})
// .def("create_trans", &ir::builder::create_trans, ret::reference)
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
.def("create_reduce",
[](mlir::OpBuilder &self, mlir::Value &operand,
mlir::triton::RedOp redOp, int axis) -> mlir::Value {
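
On the Python side, `tl.sin` and friends reach these new builder bindings through thin wrappers in `triton.language.semantic`. Below is a minimal sketch of that layer, assuming it follows the same pattern as the existing unary ops (the actual `semantic.py` hunk is not reproduced in this excerpt):

```py
import triton.language as tl

# Sketch only (assumed names/signatures): forward the MLIR value handle to the
# new OpBuilder binding. math.sin preserves the operand's element type and
# shape, so the result keeps x.type.
def sin(x: tl.tensor, builder) -> tl.tensor:
    return tl.tensor(builder.create_sin(x.handle), x.type)
```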


@@ -0,0 +1,33 @@
import triton
import triton.language as tl
@triton.jit
def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
offsets = tl.arange(0, BLOCK_SIZE)
x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
x4 = tl.load(x4_ptr + offsets, mask=offsets < n)
y1 = tl.sin(x1)
y2 = tl.libdevice.sin(x2)
y3 = tl.libdevice.fdiv_rn(x3, x3)
y4 = tl.libdevice.fmaf_rd(x4, x4, x4)
tl.store(x1_ptr + offsets, y1, mask=offsets < n)
tl.store(x2_ptr + offsets, y2, mask=offsets < n)
tl.store(x3_ptr + offsets, y3, mask=offsets < n)
tl.store(x4_ptr + offsets, y4, mask=offsets < n)
def test_empty_kernel_cubin_compile():
kernel = triton.compile(math_kernel,
"*fp32,*fp32,*fp32,*fp32,i32",
device=0,
constants={"BLOCK_SIZE": 256},
output="ttgir") # "cubin"
assert kernel
# TODO: Check if the values are correct.
# TODO: Cover all the math operators
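
The numerical check mentioned in the TODO could eventually look like the sketch below. This is only illustrative: it assumes `torch` is available and that the kernel can actually be launched, which requires the PTX/cubin lowering this PR does not yet provide.

```py
import torch

def check_math_kernel(n=200, BLOCK_SIZE=256):
    xs = [torch.rand(BLOCK_SIZE, device="cuda") for _ in range(4)]
    # Reference values, computed before the kernel overwrites its inputs in place.
    expected = [torch.sin(xs[0]), torch.sin(xs[1]),
                xs[2] / xs[2],            # __nv_fdiv_rn(x, x)
                xs[3] * xs[3] + xs[3]]    # __nv_fmaf_rd(x, x, x)
    math_kernel[(1,)](*xs, n, BLOCK_SIZE=BLOCK_SIZE)
    for out, ref in zip(xs, expected):
        # Only the first n elements are stored (mask=offsets < n); use a loose
        # tolerance because fmaf_rd rounds toward negative infinity.
        torch.testing.assert_close(out[:n], ref[:n], rtol=1e-5, atol=1e-5)
```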


@@ -671,7 +671,8 @@ class CodeGenerator(ast.NodeVisitor):
results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
return tuple(results)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
sys.modules[fn.__module__] is triton.language.core or \
isinstance(fn, triton.language.extern.ExternalFunction):
return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg


@@ -1,4 +1,4 @@
# flake8: noqa: F401
from . import core, random
from . import core, extern, libdevice, random
from .core import *
from .random import *


@@ -234,11 +234,15 @@ class pointer_type(dtype):
class block_type(dtype):
def __init__(self, element_ty: dtype, shape: List[int]):
def __init__(self, element_ty: dtype, shape: List):
self.element_ty = element_ty
# FIXME:
# block_type's shape is a list of int
# while tensor's shape is a list of constexpr
# Note that block_type's shape is a list of int
# while tensor's shape is a list of constexpr.
assert shape
if isinstance(shape[0], constexpr):
shape = [s.value for s in shape]
self.shape = shape
self.numel = 1
for s in self.shape:


@@ -0,0 +1,104 @@
from __future__ import annotations # remove after python 3.11
from . import core, semantic
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
'''
Dispatch a function to a library
:param func: the function to dispatch
:param lib_name: the name of the library
:param lib_path: the path of the library
:param args: the arguments of the function
    :param arg_type_symbol_dict: a dict mapping a tuple of argument dtypes to the (symbol, return dtype) to dispatch to
:param ret_shape: the shape of the return value
:param _builder: the builder
:return: the return value of the function
'''
if len(arg_type_symbol_dict) == 0:
raise ValueError("arg_type_symbol_dict is empty")
num_args = len(list(arg_type_symbol_dict.keys())[0])
if len(args) != num_args:
raise ValueError(f"length of input args does not match."
f"Expect {len(args)}, got {num_args}")
arg_types = []
arg_list = []
for arg in args:
if isinstance(arg, core.tensor):
arg_types.append(arg.dtype)
arg_list.append(arg.handle)
else:
arg_types.append(type(arg))
arg_list.append(arg)
arg_types = tuple(arg_types)
if arg_types not in arg_type_symbol_dict:
raise ValueError(f"input arg type does not match."
f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
else:
symbol = arg_type_symbol_dict[arg_types][0]
ret_type = arg_type_symbol_dict[arg_types][1]
if ret_shape:
ret_type = core.block_type(ret_type, ret_shape)
return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
'''
Dispatch an elementwise function to a library
:param lib_name: the name of the library
:param lib_path: the path of the library
:param args: the arguments of the function
    :param arg_type_symbol_dict: a dict mapping a tuple of argument dtypes to the (symbol, return dtype) to dispatch to
:param _builder: the builder
:return: the return value of the function
'''
dispatch_args = args.copy()
if len(args) == 1:
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
ret_shape = dispatch_args[0].shape
elif len(args) == 2:
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
dispatch_args[0], dispatch_args[1], _builder)
ret_shape = dispatch_args[0].shape
else:
for i in range(len(dispatch_args)):
dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i in range(len(dispatch_args)):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_external_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
class ExternalFunction:
'''
A wrapper for external functions
'''
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return self.fn(*args, **kwargs)
def extern(fn):
'''
A decorator for external functions
'''
return ExternalFunction(fn)
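
The `libdevice.py` wrappers themselves are suppressed below as too large, but on top of `extern` they would look roughly like the following sketch (the path constant and the exact symbol table are assumptions for illustration, not the actual file contents):

```py
import os
from . import core, extern  # as a module inside triton/language/

# Assumed location of the bundled libdevice bitcode; the TTGIR above shows the
# absolute path that ends up baked into the ext_elemwise ops.
LIBDEVICE_PATH = os.path.join(os.path.dirname(__file__), "libdevice.10.bc")

@extern.extern
def sin(arg0, _builder=None):
    # arg_type_symbol_dict maps a tuple of argument dtypes to the libdevice
    # symbol to call and the dtype it returns.
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0],
                              {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")),
                               (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64"))},
                              _builder=_builder)
```

With a wrapper like this, `tl.libdevice.sin(x)` inside an `@triton.jit` function is an `ExternalFunction`, which the `CodeGenerator` change above dispatches with the `_builder` argument injected.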

File diff suppressed because it is too large.