[LANG] Added support for device functions (#484)
This commit is contained in:
@@ -79,7 +79,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def build_extension(self, ext):
|
||||
llvm_include_dir, llvm_library_dir = get_llvm()
|
||||
self.debug = True
|
||||
# self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# create build directories
|
||||
build_suffix = 'debug' if self.debug else 'release'
|
||||
|
@@ -659,6 +659,8 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<ir::type>(m, "type")
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("get_int_width", &ir::type::get_integer_bitwidth)
|
||||
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("make_ptr", &ir::pointer_type::get, ret::reference)
|
||||
@@ -695,6 +697,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||
.def("is_struct", &ir::type::is_struct_ty)
|
||||
|
||||
.def("repr", &ir::type::repr)
|
||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||
@@ -704,23 +707,37 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
|
||||
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type");
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type")
|
||||
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
|
||||
.def_property_readonly("arg_tys", [](ir::function_type* self){
|
||||
return std::vector<ir::type*>(self->params_begin(), self->params_end());
|
||||
});
|
||||
|
||||
py::class_<ir::integer_type, ir::type>(m, "integer_type");
|
||||
|
||||
py::class_<ir::block_type, ir::type>(m, "block_type")
|
||||
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
||||
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
||||
|
||||
py::class_<ir::struct_type, ir::type>(m, "struct_type")
|
||||
.def("get", &ir::struct_type::get, ret::reference)
|
||||
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
|
||||
|
||||
py::class_<ir::value_constructor>(m, "value_constructor")
|
||||
.def(py::init<ir::builder&>())
|
||||
.def("seal_block", &ir::value_constructor::seal_block)
|
||||
.def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value)
|
||||
.def("set_type", &ir::value_constructor::set_type)
|
||||
.def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference)
|
||||
.def("get_values", &ir::value_constructor::get_values, ret::reference)
|
||||
.def("set_values", &ir::value_constructor::set_values);
|
||||
|
||||
py::class_<ir::module>(m, "module")
|
||||
.def(py::init<std::string, ir::builder &>())
|
||||
.def("has_function", &ir::module::has_function)
|
||||
.def("get_function", &ir::module::get_function, ret::reference)
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||
.def("seal_block", &ir::module::seal_block)
|
||||
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
|
||||
.def("set_type", &ir::module::set_type)
|
||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
||||
.def("get_values", &ir::module::get_values, ret::reference)
|
||||
.def("set_values", &ir::module::set_values)
|
||||
.def("get_types", &ir::module::get_types, ret::reference)
|
||||
.def("set_types", &ir::module::set_types)
|
||||
.def("reset_ret_ty", &ir::module::reset_ret_ty)
|
||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||
|
||||
using eattr = ir::attribute_kind_t;
|
||||
@@ -734,29 +751,45 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("not_implemented", eattr::not_implemented);
|
||||
|
||||
py::class_<ir::attribute>(m, "attribute")
|
||||
.def(py::init<eattr, int>());
|
||||
.def(py::init<eattr, int>())
|
||||
.def_property_readonly("value", &ir::attribute::get_value);
|
||||
|
||||
py::class_<ir::function>(m, "function")
|
||||
.def_property_readonly("args", &ir::function::args)
|
||||
.def_property_readonly("attrs", &ir::function::attrs)
|
||||
.def("add_attr", &ir::function::add_attr);
|
||||
.def("set_is_kernel", &ir::function::set_is_kernel)
|
||||
.def("add_attr", &ir::function::add_attr)
|
||||
.def("has_attr", &ir::function::has_attr)
|
||||
.def("get_attrs", &ir::function::get_attributes);
|
||||
|
||||
py::class_<ir::argument, ir::value>(m, "argument");
|
||||
py::class_<ir::argument, ir::value>(m, "argument")
|
||||
.def_property_readonly("parent", &ir::argument::get_parent, ret::reference)
|
||||
.def_property_readonly("arg_no", &ir::argument::get_arg_no);
|
||||
|
||||
py::class_<ir::basic_block, ir::value>(m, "basic_block")
|
||||
.def("create", &ir::basic_block::create, ret::reference)
|
||||
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
|
||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||
|
||||
py::class_<ir::builder::iterator>(m, "bb_iterator");
|
||||
|
||||
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
||||
.def(py::init<ir::context &>())
|
||||
// getters
|
||||
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
|
||||
// control flow
|
||||
.def("call", &ir::builder::create_call, ret::reference)
|
||||
.def("launch", &ir::builder::create_launch, ret::reference)
|
||||
.def("br", &ir::builder::create_br, ret::reference)
|
||||
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
.def("ret", &ir::builder::create_ret, ret::reference)
|
||||
.def("get_insert_point", &ir::builder::get_insert_point)
|
||||
.def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point)
|
||||
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
||||
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
||||
// struct
|
||||
.def("insert_value", &ir::builder::create_insert_value, ret::reference)
|
||||
.def("extract_value", &ir::builder::create_extract_value, ret::reference)
|
||||
// constants
|
||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
||||
|
@@ -585,7 +585,6 @@ def test_f8_f16_roundtrip():
|
||||
|
||||
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
print(f16.dtype, f8_output.dtype)
|
||||
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
@@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
|
||||
|
||||
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
||||
triton_ir = pgm.asm['ttir']
|
||||
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
|
||||
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
|
||||
assert ir_value_type == value_type
|
||||
|
||||
|
||||
@@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
|
||||
kernel[(1, )](value, x)
|
||||
else:
|
||||
kernel[(1, )](value, x)
|
||||
# -------------------------
|
||||
# test dynamic parallelism
|
||||
# -------------------------
|
||||
|
||||
|
||||
@triton.jit
def mult(x, alpha):
    # Each program instance stores `alpha` at `x` offset by its program id on axis 0.
    offset = tl.program_id(0)
    tl.store(x + offset, alpha)
|
||||
|
||||
|
||||
@triton.jit
def stub(X, alpha, grid_0, grid_1, grid_2):
    # Hands `mult` to tl.launch with (X, alpha) as its arguments and a
    # three-entry grid taken from the kernel parameters.
    tl.launch(
        mult,
        [X, alpha],
        [grid_0, grid_1, grid_2],
    )
|
||||
|
||||
|
||||
def test_dyn_par(cond=True, device='cuda'):
    """Smoke test for `tl.launch`: run `stub` once and print the output buffer.

    :param cond: currently unused (its only use is commented out below).
    :param device: torch device on which the output tensor is allocated.

    NOTE(review): there is no assertion yet -- the reference computation and
    the comparison are still commented out, so this only checks that the
    launch path runs.
    """
    n_pids = 10
    # pids = torch.arange(n_pids, device=device)
    # alpha = 2.0
    # x_ref = pids * alpha
    # Size the buffer from n_pids (one slot per requested program along grid
    # axis 0) instead of repeating the literal 10.
    x_tri = torch.full((n_pids,), fill_value=-1., device=device)
    # cond = torch.tensor([cond], device=device)
    stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
    print(x_tri)
    # triton.testing.assert_almost_equal(x_ref, x_tri)
|
||||
|
@@ -21,6 +21,41 @@ import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
def mangle_ty(type):
    """Return the string tag used to encode an IR type in a mangled name.

    Encoding, checked in this order:
      * pointer  -> 'P' followed by the tag of ``type.element``
      * integer  -> 'i' followed by ``type.get_int_width()``
      * floats   -> 'fp8', 'fp16', 'bf16', 'fp32', 'fp64'
      * void     -> 'V'
      * block    -> '<scalar tag>S<dims joined by "_">S'

    :param type: object exposing the ``is_*`` predicates used below (plus
        ``element``, ``scalar``, ``shape`` and ``get_int_width`` where relevant).
    :raises AssertionError: if none of the predicates matches.
    """
    if type.is_ptr():
        return 'P' + mangle_ty(type.element)
    if type.is_int():
        return 'i' + str(type.get_int_width())
    # Floating-point kinds map one-to-one onto fixed tags.
    float_tags = (
        ('is_fp8', 'fp8'),
        ('is_fp16', 'fp16'),
        ('is_bf16', 'bf16'),
        ('is_fp32', 'fp32'),
        ('is_fp64', 'fp64'),
    )
    for predicate, tag in float_tags:
        if getattr(type, predicate)():
            return tag
    if type.is_void():
        return 'V'
    if type.is_block():
        elt = mangle_ty(type.scalar)
        shape = '_'.join(map(str, type.shape))
        return f'{elt}S{shape}S'
    # Raise explicitly: a bare `assert False` disappears under `python -O`,
    # which would let an unsupported type fall through and return None.
    raise AssertionError(f"Unsupported type: {type!r}")
|
||||
|
||||
|
||||
def mangle_fn(name, arg_tys, constants):
    """Build the mangled name ``<name>__<arg tags>__<constants>`` for a function.

    :param name: base function name.
    :param arg_tys: iterable of argument types; each is encoded with `mangle_ty`
        and the tags are joined with '_'.
    :param constants: mapping from argument index to constant value. Entries are
        emitted in ascending index order as ``<index>c<key>``, where the key is
        ``__name__`` for a `JITFunction` and ``repr(value)`` otherwise.
    :return: the mangled name as a string.
    """
    def constant_key(value):
        # JIT functions are identified by name; any other constant by its repr.
        return value.__name__ if isinstance(value, JITFunction) else repr(value)

    # doesn't mangle ret type, which must be a function of arg tys
    mangled_arg_names = '_'.join(mangle_ty(ty) for ty in arg_tys)
    mangled_constants = '_'.join(
        f'{i}c{constant_key(constants[i])}' for i in sorted(constants)
    )
    # '.' and "'" (from float and string reprs) are rewritten to '_d_' / '_sq_'.
    # NOTE(review): other repr characters (e.g. '-', spaces, brackets) are left
    # untouched -- confirm that is acceptable wherever the name is consumed.
    mangled_constants = mangled_constants.replace('.', '_d_').replace("'", '_sq_')
    return f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def get_value(self, name):
|
||||
# search node.id in local scope
|
||||
@@ -36,7 +71,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if isinstance(ret, triton.language.block):
|
||||
handle = self.module.get_value(name)
|
||||
handle = self.value_constructor.get_value(name)
|
||||
return triton.language.block(handle)
|
||||
return ret
|
||||
|
||||
@@ -44,8 +79,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if isinstance(value, _triton.ir.value):
|
||||
value = triton.language.block(value)
|
||||
if isinstance(value, triton.language.block):
|
||||
self.module.set_value(name, value.handle)
|
||||
self.module.set_type(name, value.handle.type)
|
||||
self.value_constructor.set_value(name, value.handle)
|
||||
self.value_constructor.set_type(name, value.handle.type)
|
||||
self.lscope[name] = value
|
||||
|
||||
def is_triton_object(self, value):
|
||||
@@ -58,16 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = _triton.ir.module('', self.builder)
|
||||
self.value_constructor = _triton.ir.value_constructor(self.builder)
|
||||
self.module = _triton.ir.module('', self.builder) if module is None else module
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.kwargs = kwargs
|
||||
self.last_node = None
|
||||
self.is_kernel = is_kernel
|
||||
self.builtins = {
|
||||
'range': range,
|
||||
'min': triton.language.minimum,
|
||||
@@ -92,9 +128,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret = self.visit(node.value)
|
||||
if ret is None:
|
||||
return self.builder.ret_void()
|
||||
return ret
|
||||
if isinstance(ret, _triton.ir.value):
|
||||
ret = self.builder.ret(ret)
|
||||
return ret
|
||||
if isinstance(ret, triton.language.block):
|
||||
ret = ret.handle
|
||||
if isinstance(ret, triton.language.constexpr):
|
||||
ret = triton.language.core._to_ir(ret, self.builder)
|
||||
# TODO: should return tl.block
|
||||
return self.builder.ret(ret)
|
||||
|
||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
@@ -107,45 +151,44 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# store keyword arguments in local scope
|
||||
self.lscope[kwarg_names] = self.kwargs
|
||||
# initialize function
|
||||
if inline:
|
||||
pass
|
||||
else:
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
if i in self.attributes:
|
||||
is_ptr = fn.args[idx].type.is_ptr()
|
||||
attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(fn.args[idx])
|
||||
idx += 1
|
||||
fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants)
|
||||
fn = self.module.get_or_insert_function(fn_name, self.prototype)
|
||||
fn.set_is_kernel(self.is_kernel)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
if i in self.attributes:
|
||||
is_ptr = fn.args[idx].type.is_ptr()
|
||||
attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(fn.args[idx])
|
||||
idx += 1
|
||||
|
||||
insert_pt = self.builder.get_insert_block()
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
self.builder.set_insert_block(entry)
|
||||
self.value_constructor.seal_block(entry)
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
if inline:
|
||||
self.visit_compound_statement(node.body)
|
||||
return self.last_ret
|
||||
else:
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
self.module.seal_block(entry)
|
||||
self.builder.set_insert_block(entry)
|
||||
# visit function body
|
||||
self.visit_compound_statement(node.body)
|
||||
# finalize function
|
||||
# visit function body
|
||||
has_ret = self.visit_compound_statement(node.body)
|
||||
# finalize
|
||||
if not has_ret:
|
||||
self.builder.ret_void()
|
||||
else:
|
||||
self.module.reset_ret_ty(fn_name, self.last_ret.type)
|
||||
# self.module.reset_ret_type(node.name)
|
||||
self.builder.set_insert_block(insert_pt)
|
||||
|
||||
def visit_arguments(self, node):
|
||||
arg_names = []
|
||||
@@ -186,6 +229,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
if isinstance(values[0], _triton.ir.value):
|
||||
struct = values[0]
|
||||
ty = struct.type
|
||||
if ty.is_struct():
|
||||
values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)]
|
||||
assert len(values) == len(names)
|
||||
for name, value in zip(names, values):
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
@@ -215,6 +264,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_Tuple(self, node):
    """Visit a tuple literal.

    Each element is visited first. When the tuple has more than one element and
    every visited element is exactly a ``triton.language.block``, the handles
    are packed into a struct value built with ``insert_value`` and that IR
    value is returned. In every other case (including an empty tuple) a plain
    Python tuple of the visited elements is returned.
    """
    args = [self.visit(x) for x in node.elts]
    # Do not index args[0] unconditionally: `()` is a valid tuple literal and
    # would otherwise raise IndexError before reaching the fallback below.
    is_block_tuple = len(args) > 1 and \
        all(type(arg) == triton.language.block for arg in args)
    if not is_block_tuple:
        return tuple(args)
    # tuple of values -- create a struct
    handles = [arg.handle for arg in args]
    tys = [handle.type for handle in handles]
    struct_ty = _triton.ir.struct_type.get(tys, True)
    ret = _triton.ir.undef.get(struct_ty)
    for i, handle in enumerate(handles):
        ret = self.builder.insert_value(ret, handle, i)
    return ret
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
@@ -254,9 +314,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
self.module.seal_block(then_bb)
|
||||
self.value_constructor.seal_block(then_bb)
|
||||
if else_bb:
|
||||
self.module.seal_block(else_bb)
|
||||
self.value_constructor.seal_block(else_bb)
|
||||
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||
else:
|
||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
@@ -271,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: last statement is a terminator?
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
self.module.seal_block(endif_bb)
|
||||
self.value_constructor.seal_block(endif_bb)
|
||||
self.builder.set_insert_block(endif_bb)
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
@@ -350,9 +410,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self.module.seal_block(stop_bb)
|
||||
self.module.seal_block(loop_bb)
|
||||
self.module.seal_block(next_bb)
|
||||
self.value_constructor.seal_block(stop_bb)
|
||||
self.value_constructor.seal_block(loop_bb)
|
||||
self.value_constructor.seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -421,9 +481,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: handle case where body breaks control flow
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self.module.seal_block(stop_bb)
|
||||
self.module.seal_block(loop_bb)
|
||||
self.module.seal_block(next_bb)
|
||||
self.value_constructor.seal_block(stop_bb)
|
||||
self.value_constructor.seal_block(loop_bb)
|
||||
self.value_constructor.seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -449,15 +509,62 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for keyword in node.keywords:
|
||||
kws.update(self.visit(keyword))
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if isinstance(fn, JITFunction):
|
||||
return fn(*args, generator=self, **kws)
|
||||
from inspect import getcallargs
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.block)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||
constants = {i: args[i] for i in constexprs}
|
||||
# generate call
|
||||
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
|
||||
arg_vals = [arg.handle for arg in args if arg is not None]
|
||||
arg_types = [arg.type for arg in arg_vals]
|
||||
fn_name = mangle_fn(fn.__name__, arg_types, constants)
|
||||
# generate function def if necessary
|
||||
if not self.module.has_function(fn_name):
|
||||
ret_type = _triton.ir.type.get_void(self.builder.context)
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module)
|
||||
generator.visit(fn.parse())
|
||||
symbol = self.module.get_function(fn_name)
|
||||
ret = self.builder.call(symbol, arg_vals)
|
||||
if not ret.type.is_void() and not ret.type.is_struct():
|
||||
ret = triton.language.block(ret)
|
||||
return ret
|
||||
# built-in function
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
ret = fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
|
||||
for arg in args]
|
||||
return fn(*args, **kws)
|
||||
ret = fn(*args, **kws)
|
||||
# special case: dynamic parallelism
|
||||
# in this case the core primitive returns a proxy
|
||||
# if isinstance(ret, triton.language.core.LaunchProxy):
|
||||
# ret_type = _triton.ir.type.get_void(self.builder.context)
|
||||
# arg_tys = [x.type for x in ret.args]
|
||||
# prototype = _triton.ir.type.make_function(ret_type, arg_tys)
|
||||
# gscope = sys.modules[ret.fn.fn.__module__].__dict__
|
||||
# constants = ret.constants
|
||||
# fn_name = mangle_fn(ret.fn.__name__, arg_tys, ret.constants)
|
||||
# # TODO: clean-up attributes handling in function
|
||||
# if not self.module.has_function(fn_name):
|
||||
# attributes = {i: list(arg.parent.get_attrs(arg))[0].value for i, arg in enumerate(ret.args) \
|
||||
# if isinstance(arg, _triton.ir.argument) and arg.parent.has_attr(i + 1) }
|
||||
# generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, is_kernel=True)
|
||||
# generator.visit(ret.fn.parse())
|
||||
# symbol = self.module.get_function(fn_name)
|
||||
# # TODO: should ret.args not include any constants ?
|
||||
# ret = self.builder.launch(symbol, ret.args, ret.grid, ret.num_warps)
|
||||
return ret
|
||||
# return fn(*args, **kws)
|
||||
|
||||
def visit_Constant(self, node):
|
||||
return triton.language.constexpr(node.value)
|
||||
@@ -669,6 +776,7 @@ class Kernel:
|
||||
|
||||
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
|
||||
# attributes
|
||||
attributes = dict()
|
||||
for i, arg in enumerate(wargs):
|
||||
@@ -881,7 +989,7 @@ class JITFunction:
|
||||
|
||||
cache_hook = None
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
|
||||
# information of wrapped function
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
@@ -890,6 +998,7 @@ class JITFunction:
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
self.version = version
|
||||
self.inline = inline
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
@@ -904,6 +1013,8 @@ class JITFunction:
|
||||
# annotations
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
# forward docs
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
@@ -930,31 +1041,8 @@ class JITFunction:
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
||||
try:
|
||||
from inspect import getcallargs
|
||||
arg_values = getcallargs(self.fn, *args, **kwargs)
|
||||
arg_values = [arg_values[name] for name in self.arg_names]
|
||||
arg_values = [arg if isinstance(arg, triton.language.block)
|
||||
else triton.language.constexpr(arg) for arg in arg_values]
|
||||
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
types = generator.module.get_types().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
generator.module.set_values(values)
|
||||
generator.module.set_types(types)
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
raise e
|
||||
raise CompilationError(self.src, node) from e
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.")
|
||||
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
@@ -1039,7 +1127,7 @@ class JITFunction:
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
|
||||
try:
|
||||
generator.visit(self.parse())
|
||||
except Exception as e:
|
||||
@@ -1199,9 +1287,21 @@ def jit(*args, **kwargs):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
|
||||
######
|
||||
|
||||
# class ForwardDeclaration:
|
||||
|
||||
# def __init__(self, name, ret_ty, arg_tys) -> None:
|
||||
# self.name = name
|
||||
# self.ret_ty = ret_ty
|
||||
# self.arg_tys = arg_tys
|
||||
|
||||
# def forward_declare(name, ret_ty, arg_tys):
|
||||
# return ForwardDeclaration(name, ret_ty, arg_tys)
|
||||
|
||||
######
|
||||
|
||||
|
||||
def cdiv(x, y):
    """Return ``(x + y - 1) // y``; for positive integers this is x / y rounded up."""
    padded = x + y - 1
    return padded // y
|
||||
|
||||
|
@@ -888,7 +888,7 @@ def sigmoid(x):
|
||||
|
||||
@triton.jit
|
||||
@_add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding=False):
|
||||
def softmax(x, ieee_rounding: constexpr = False):
|
||||
z = x - triton.language.max(x, 0)
|
||||
num = triton.language.exp(z)
|
||||
den = triton.language.sum(num, 0)
|
||||
@@ -942,3 +942,26 @@ def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
@triton.jit
def zeros_like(input):
    # Delegates to `zeros`, reusing the shape and dtype of `input`.
    return zeros(input.shape, input.dtype)
|
||||
# -----------------------
|
||||
# Dynamic Parallelism
|
||||
# -----------------------
|
||||
|
||||
|
||||
class LaunchProxy:
    """Plain record of the pieces of a requested launch.

    Stores the callee (`fn`), its arguments, the constants mapping, the grid
    and the warp count exactly as given; no processing happens here.
    """

    def __init__(self, fn, args, constants, grid, num_warps) -> None:
        self.fn = fn
        self.args = args
        self.constants = constants
        self.grid = grid
        self.num_warps = num_warps
|
||||
|
||||
|
||||
@builtin
def launch(fn, args, grid, num_warps=None, _builder=None):
    """Describe a launch of `fn` and return it as a `LaunchProxy`.

    :param fn: the function to launch; stored on the proxy unchanged.
    :param args: launch arguments. `constexpr` entries are collected into a
        ``{position: value}`` mapping; all other entries are lowered with `_to_ir`.
    :param grid: iterable of grid entries, each lowered with `_to_ir`.
    :param num_warps: warp count; defaults to 4 when None.
    :param _builder: IR builder forwarded to `_to_ir`.
    :return: `LaunchProxy(fn, args, constants, grid, num_warps)`.
    """
    constants = {i: x for i, x in enumerate(args) if isinstance(x, constexpr)}
    args = [_to_ir(x, builder=_builder) for x in args if not isinstance(x, constexpr)]
    grid = [_to_ir(x, builder=_builder) for x in grid]
    # Lower num_warps the same way whether it is defaulted or supplied by the
    # caller; previously only the default went through `_to_ir`, so an explicit
    # value reached the proxy in a different form than args and grid.
    # NOTE(review): relies on `_to_ir` accepting the same value kinds it
    # already receives for args/grid -- confirm.
    if num_warps is None:
        num_warps = 4
    num_warps = _to_ir(num_warps, builder=_builder)
    return LaunchProxy(fn, args, constants, grid, num_warps)
|
||||
|
Reference in New Issue
Block a user