[FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562)

This commit is contained in:
Keren Zhou
2022-07-13 15:52:21 -07:00
committed by GitHub
parent 971f5782b4
commit 4912916c11
24 changed files with 2634 additions and 64 deletions

View File

@@ -98,7 +98,7 @@ class CMakeBuild(build_ext):
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
# python directories
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
python_include_dirs = [distutils.sysconfig.get_python_inc()]
cmake_args = [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",

View File

@@ -1,5 +1,6 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/target.h"
#include "triton/codegen/extern_lib.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
@@ -19,7 +20,6 @@
#include <stdexcept>
#include <string>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace py = pybind11;
@@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){
// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
int num_warps, int num_stages) {
int num_warps, int num_stages, py::dict& extern_libs) {
size_t len = PyList_Size(args.ptr());
params.reserve(8*len); // 8 max bytes by argument
char* params_ptr = &params[0];
@@ -256,6 +256,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
for (auto item : extern_libs) {
cache_key += "-" + item.first.cast<std::string>();
cache_key += "_" + item.second.cast<std::string>();
}
}
//
@@ -288,7 +293,7 @@ void init_triton_runtime(py::module &&m) {
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
py::function add_to_cache, py::object grid){
py::dict extern_libs, py::function add_to_cache, py::object grid){
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
@@ -296,13 +301,14 @@ void init_triton_runtime(py::module &&m) {
std::string params;
size_t params_size;
py::dict constants;
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
params_size, constants, _num_warps, _num_stages, extern_libs);
// get cached binary
py::str key(cache_key);
py::bool_ noop = false;
if(!bin_cache.contains(key)) {
noop = add_to_cache(key, args, device, num_warps, num_stages);
noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs);
}
if (noop)
return (py::object)py::none();
@@ -467,11 +473,10 @@ std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::st
// ---------------------------------------
// CUDA
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
asm_map_t &asm_map){
int n_shared_bytes;
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
const std::string &name, ir::module &ir, uint64_t device, int num_warps,
int num_stages, asm_map_t &asm_map,
const triton::codegen::ExternLibMap &extern_lib_map) {
py::gil_scoped_release allow_threads;
llvm::LLVMContext ctx;
// device properties
@@ -483,7 +488,9 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
@@ -502,14 +509,16 @@ std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name,
}
// HIP
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
asm_map_t &asm_map){
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
const std::string &name, ir::module &ir, uint64_t device, int num_warps,
int num_stages, asm_map_t &asm_map,
const triton::codegen::ExternLibMap &extern_lib_map) {
llvm::LLVMContext ctx;
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::amd_cl_target target;
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
auto llvm = triton::codegen::add_passes_to_emit_bin(
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
@@ -523,7 +532,9 @@ std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name
void init_triton_codegen(py::module &&m) {
m.def(
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
"compile_ttir",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
int num_stages, py::dict& extern_libs) {
std::string name = ir.get_function_list()[0]->get_name();
// record asm as we generate
asm_map_t asm_map;
@@ -531,11 +542,20 @@ void init_triton_codegen(py::module &&m) {
ir.print(ttir);
asm_map["ttir"] = py::cast(ttir.str());
llvm::LLVMContext ctx;
// construct extern lib map
triton::codegen::ExternLibMap extern_lib_map;
for (auto item : extern_libs) {
auto name = item.first.cast<std::string>();
auto path = item.second.cast<std::string>();
extern_lib_map.emplace(
name, triton::codegen::create_extern_lib(name, path));
}
if(backend == CUDA)
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
if(backend == ROCM)
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
}, py::return_value_policy::take_ownership);
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
},
py::return_value_policy::take_ownership);
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
py::gil_scoped_release allow_threads;
if(backend == CUDA)
@@ -931,7 +951,8 @@ void init_triton_ir(py::module &&m) {
// Utilities
.def("create_clock", &ir::builder::create_clock, ret::reference)
.def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
// Extern instruction
.def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference)
// Built-in instruction
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)

View File

@@ -1300,3 +1300,49 @@ def test_num_warps_pow2():
_kernel[(1,)](dst=dst, num_warps=1)
_kernel[(1,)](dst=dst, num_warps=2)
_kernel[(1,)](dst=dst, num_warps=4)
# -------------
# test extern
# -------------
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', ''),
('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'),
('float64', 'libdevice.norm4d', '')])
def test_libdevice(dtype_str, expr, lib_path):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = GENERATE_TEST_HERE
tl.store(Y + tl.arange(0, BLOCK), y)
shape = (128, )
rs = RandomState(17)
# limit the range of integers so that the sum does not overflow
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
if expr == 'libdevice.ffs':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'})
y_ref = np.zeros(shape, dtype=x.dtype)
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
elif expr == 'libdevice.pow':
# numpy does not allow negative factors in power, so we use abs()
x = np.abs(x)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'})
y_ref = np.power(x, x)
elif expr == 'libdevice.norm4d':
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'})
y_ref = np.sqrt(4 * np.power(x, 2))
x_tri = to_triton(x)
# triton result
y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda')
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
if expr == 'libdevice.ffs':
np.testing.assert_equal(y_ref, to_numpy(y_tri))
else:
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)

View File

@@ -689,7 +689,7 @@ class CodeGenerator(ast.NodeVisitor):
ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type)
return ret
# built-in function
if sys.modules[fn.__module__] is triton.language.core:
if sys.modules[fn.__module__] is triton.language.core or isinstance(fn, triton.language.extern.ExternalFunction):
ret = fn(*args, _builder=self.builder, **kws)
if fn in self.value_constructor.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
@@ -933,7 +933,7 @@ class Kernel:
self.fn = fn
self.cache_key = {}
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs):
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
@@ -953,9 +953,10 @@ class Kernel:
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False)
return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages,
extern_libs=extern_libs, is_manual_warmup=False)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs):
assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2."
# handle arguments passed by name
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
@@ -985,7 +986,7 @@ class Kernel:
cache_key = self.cache_key[device]
stream = current_cuda_stream(device)
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names,
device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache,
device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache,
grid)
@@ -1242,7 +1243,7 @@ class JITFunction:
def warmup(self, compile):
return self._warmup(**compile, is_manual_warmup=True)
def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup):
def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup):
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# create cache directory
@@ -1264,7 +1265,7 @@ class JITFunction:
with open(bin_cache_path, 'rb') as f:
binary = pickle.load(f)["binary"]
compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages)
compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs)
if JITFunction.cache_hook is not None:
name = self.__name__
info = key.split('-')[-3:]
@@ -1293,7 +1294,7 @@ class JITFunction:
self.bin_cache[key] = LoadedBinary(device, binary)
return False
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs):
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
@@ -1316,7 +1317,7 @@ class JITFunction:
backend = _triton.runtime.backend.CUDA
else:
backend = _triton.runtime.backend.ROCM
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs)
max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
if shared_mem > max_shared_memory:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")

View File

@@ -1,4 +1,4 @@
# flake8: noqa: F401
from . import core, random
from . import core, extern, libdevice, random
from .core import *
from .random import *

View File

@@ -248,8 +248,10 @@ class block_type(dtype):
# while tensor's shape is a list of constexpr
self.shape = shape
self.numel = 1
for s in self.shape:
self.numel *= s
for i, s in enumerate(self.shape):
if isinstance(s, constexpr):
self.shape[i] = s.value
self.numel *= self.shape[i]
self.name = self.__str__()

View File

@@ -0,0 +1,107 @@
from __future__ import annotations # remove after python 3.11
from . import core, semantic
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
'''
Dispatch a function to a library
:param func: the function to dispatch
:param lib_name: the name of the library
:param lib_path: the path of the library
:param args: the arguments of the function
:param arg_type_symbol_dict: the type of the arguments
:param ret_shape: the shape of the return value
:param _builder: the builder
:return: the return value of the function
'''
if len(arg_type_symbol_dict) == 0:
raise ValueError("arg_type_symbol_dict is empty")
num_args = len(list(arg_type_symbol_dict.keys())[0])
if len(args) != num_args:
raise ValueError(f"length of input args does not match."
f"Expect {len(args)}, got {num_args}")
arg_types = []
arg_list = []
for arg in args:
if isinstance(arg, core.tensor):
arg_types.append(arg.dtype)
arg_list.append(arg.handle)
else:
arg_types.append(type(arg))
arg_list.append(arg)
arg_types = tuple(arg_types)
if arg_types not in arg_type_symbol_dict:
raise ValueError(f"input arg type does not match."
f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
else:
symbol = arg_type_symbol_dict[arg_types][0]
ret_type = arg_type_symbol_dict[arg_types][1]
ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type
return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
'''
Dispatch an elementwise function to a library
:param lib_name: the name of the library
:param lib_path: the path of the library
:param args: the arguments of the function
:param arg_type_symbol_dict: the type of the arguments
:param _builder: the builder
:return: the return value of the function
'''
dispatch_args = args.copy()
if len(args) == 1:
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
ret_shape = dispatch_args[0].shape
elif len(args) == 2:
dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
dispatch_args[0], dispatch_args[1], _builder)
ret_shape = dispatch_args[0].shape
else:
for i in range(len(dispatch_args)):
dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
broadcast_arg = dispatch_args[0]
# Get the broadcast shape over all the arguments
for i in range(len(dispatch_args)):
_, broadcast_arg = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
# Change the shape of each argument based on the broadcast shape
for i in range(len(dispatch_args)):
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
dispatch_args[i], broadcast_arg, _builder)
ret_shape = broadcast_arg.shape
func = getattr(_builder, "create_extern_elementwise")
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
class ExternalFunction:
'''
A wrapper for external functions
'''
def __init__(self, fn):
self.fn = fn
def __call__(self, *args, **kwargs):
if '_builder' not in kwargs or \
kwargs['_builder'] is None:
raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
return self.fn(*args, **kwargs)
def extern(fn):
'''
A decorator for external functions
'''
return ExternalFunction(fn)

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,340 @@
import argparse
import subprocess
from abc import ABC, abstractmethod
class Symbol:
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
'''
A symbol is a function declaration.
:param name: name of the symbol
:param op_name: name of the operation
:param ret_type: return type of the operation
:param arg_names: names of the arguments
:param arg_types: types of the arguments
'''
self._name = name
self._op_name = op_name
self._ret_type = ret_type
self._arg_names = arg_names
self._arg_types = arg_types
@property
def name(self):
return self._name
@property
def op_name(self):
return self._op_name
@property
def ret_type(self):
return self._ret_type
@property
def arg_names(self):
return self._arg_names
@property
def arg_types(self):
return self._arg_types
def convert_type(type_str):
if type_str == "i32":
return "int32"
elif type_str == "u32":
return "uint32"
elif type_str == "i64":
return "int64"
elif type_str == "u64":
return "uint64"
elif type_str == "float":
return "fp32"
elif type_str == "double":
return "fp64"
else:
# ignore other types, such as pointer types
return None
def to_unsigned(type_str):
if type_str == "int32":
return "uint32"
elif type_str == "int64":
return "uint64"
else:
return type_str
class ExternLibrary(ABC):
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
'''
Abstract class for extern library.
:param name: name of the library
:param path: path of the library
:param format: whether to format the generated stub file
'''
self._name = name
self._path = path
self._symbols = {}
self._format = True
self._grouping = grouping
@property
def name(self):
return self._name
@property
def path(self):
return self._path
@property
def symbols(self):
return self._symbols
@property
def grouping(self):
return self._grouping
@abstractmethod
def parse_symbols(self, input_file):
pass
@abstractmethod
def _output_stubs(self) -> str:
pass
def generate_stub_file(self, output_dir):
file_str = self._output_stubs()
if file_str is None or len(file_str) == 0:
raise Exception("file_str is empty")
output_file = f"{output_dir}/{self._name}.py"
with open(output_file, "w") as f:
f.write(file_str)
f.close()
if self._format:
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file],
stdout=subprocess.PIPE).communicate()
subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
class Libdevice(ExternLibrary):
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
:param path: path of the libdevice library
'''
super().__init__("libdevice", path)
self._symbol_groups = {}
def _extract_symbol(self, line):
# Extract symbols from line in the following format:
# "define [internal] <ret_type> @<name>(<arg_types>,)"
entries = line.split("@")
ret_str = entries[0]
func_str = entries[1]
# Get ret_type, skip internal symbols
ret_strs = ret_str.split()
if ret_strs[1] == "internal":
return None
ret_type = convert_type(ret_strs[1])
if ret_type is None:
return None
# Get function name
func_strs = func_str.split("(")
func_name = func_strs[0].replace("@", "")
op_name = func_name.replace("__nv_", "")
# Get arg_types
arg_strs = func_strs[1].split(",")
arg_types = []
arg_names = []
for i, arg_str in enumerate(arg_strs):
arg_type = convert_type(arg_str.split()[0])
if arg_type is None:
return None
arg_name = 'arg' + str(i)
arg_types.append(arg_type)
arg_names.append(arg_name)
if op_name == "sad":
# Special case for sad, where the last argument is an unsigned int
arg_types[-1] = to_unsigned(arg_types[-1])
elif op_name.startswith("u"):
# LLVM does not differentiate between signed and unsigned integer type.
# We have to convert the types to unsigned
ret_type = to_unsigned(ret_type)
for i, arg_type in enumerate(arg_types):
arg_types[i] = to_unsigned(arg_type)
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
def _group_symbols(self):
symbol_set = {}
for symbol in self._symbols.values():
op_name = symbol.op_name
symbol_set[op_name] = symbol
# The following cases are grouped together:
# op_name, <u/ull/ll>op_name<ll/f/i>
for symbol in self._symbols.values():
op_name = symbol.op_name
if "max" in op_name:
op_name = "max"
elif "min" in op_name:
op_name = "min"
elif "abs" in op_name:
op_name = "abs"
elif "pow" in op_name and "fast" in op_name:
op_name = "pow"
elif "round" in op_name:
if "llround" in op_name:
op_name = "llround"
else:
op_name = "round"
elif "rint" in op_name:
if "llrint" in op_name:
op_name = "llrint"
else:
op_name = "rint"
elif op_name.startswith("ull"):
if "2" not in op_name:
# e.g., ullmax->max
op_name = op_name[3:]
else:
# e.g., ull2double->ll2double
op_name = op_name[1:]
elif op_name.startswith("u"):
if "2" not in op_name:
# e.g., uhadd->hadd
op_name = op_name[1:]
else:
# e.g., uint2double_rn->int2double_rn
op_name = op_name[1:]
elif op_name.startswith("ll"):
if "2" not in op_name:
# e.g., llmax->max
op_name = op_name[2:]
elif op_name.endswith("ll"):
op_name = op_name[:-2]
elif op_name.endswith("f"):
op_name = op_name[:-1]
if op_name in symbol_set:
# Update op_name only if there's an existing symbol
symbol._op_name = op_name
else:
op_name = symbol._op_name
if op_name in self._symbol_groups:
self._symbol_groups[op_name].append(symbol)
else:
self._symbol_groups[op_name] = [symbol]
def parse_symbols(self, input_file):
if len(self.symbols) > 0:
return
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
for line in output:
symbol = self._extract_symbol(line)
if symbol is None:
continue
self._symbols[symbol.name] = symbol
self._group_symbols()
def _output_stubs(self):
# Generate python functions in the following format:
# @extern.extern
# def <op_name>(<args>, _builder=None):
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
# return extern.dispatch("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
import_str = "from . import core, extern\n"
import_str += "import os\n"
header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n"
func_str = ""
for symbols in self._symbol_groups.values():
func_str += "@extern.extern\n"
func_name_str = f"def {symbols[0].op_name}("
for arg_name in symbols[0].arg_names:
func_name_str += f"{arg_name}, "
func_name_str += "_builder=None):\n"
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
for arg_name in symbols[0].arg_names:
return_str += f"{arg_name}, "
return_str += "], \n"
arg_type_symbol_dict_str = "{"
for symbol in symbols:
arg_type_symbol_dict_str += "("
for arg_type in symbol.arg_types:
arg_type_symbol_dict_str += f"core.dtype(\"{arg_type}\"),"
ret_type = f"core.dtype(\"{symbol.ret_type}\")"
arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
arg_type_symbol_dict_str += "}"
return_str += arg_type_symbol_dict_str
return_str += ", _builder)\n"
func_str += func_name_str + return_str + "\n"
file_str = import_str + header_str + func_str
return file_str
class LLVMDisassembler:
def __init__(self, path):
'''
Invoke llvm-dis to disassemble the given file.
:param path: path to llvm-dis
'''
self._path = path
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path):
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
@property
def ll_file(self):
return self._ll_file
@property
def path(self):
return self._path
extern_libs = ["libdevice"]
def build(llvm_dis_path, lib_path, lib_name, output_dir):
'''
Interface function to build the library file.
:param llvm_dis_path: path to the llvm-dis binary
:param lib_path: path to the external library file
:param lib_name: name of the library
:param output_dir: path to the output directory
'''
if lib_name == "libdevice":
extern_lib = Libdevice(lib_path)
else:
raise Exception(f"Unknown extern library: {lib_name}")
llvm_disassembler = LLVMDisassembler(llvm_dis_path)
llvm_disassembler.disasm(lib_path)
extern_lib.parse_symbols(llvm_disassembler.ll_file)
extern_lib.generate_stub_file(output_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis")
parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library")
parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library")
parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/")
args = parser.parse_args()
build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)

View File

@@ -0,0 +1,74 @@
"""
Libdevice function
===============
Triton can invoke a custom function from an external library.
In this example, we will use the `libdevice` library to apply `asin` on a tensor.
Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions.
In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Using triton, you can simply call `tl.libdevice.asinf`.
triton automatically selects the correct underlying device function to invoke based on input and output types.
"""
# %%
# asin Kernel
# --------------------------
import torch
import triton
import triton.language as tl
@triton.jit
def asin_kernel(
x_ptr,
y_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x = tl.libdevice.asin(x)
tl.store(y_ptr + offsets, x, mask=mask)
# %%
# Using the default libdevice library path
# --------------------------
# We can use the default libdevice library path encoded in `triton/language/libdevice.py`
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
output_triton = torch.zeros(size, device='cuda')
output_torch = torch.asin(x)
assert x.is_cuda and output_triton.is_cuda
n_elements = output_torch.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)
# %%
# Customize the libdevice library path
# --------------------------
# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel.
output_triton = torch.empty_like(x)
asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024,
extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'})
print(output_torch)
print(output_triton)
print(
f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}'
)