[LANG] Added support for device functions (#484)

This commit is contained in:
Philippe Tillet
2022-04-03 20:58:16 -07:00
committed by GitHub
parent e85c7a7fc7
commit 2bed6fc850
39 changed files with 1213 additions and 379 deletions

View File

@@ -79,7 +79,7 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
llvm_include_dir, llvm_library_dir = get_llvm()
self.debug = True
# self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
build_suffix = 'debug' if self.debug else 'release'

View File

@@ -659,6 +659,8 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("get_int_width", &ir::type::get_integer_bitwidth)
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("make_ptr", &ir::pointer_type::get, ret::reference)
@@ -695,6 +697,7 @@ void init_triton_ir(py::module &&m) {
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def("is_struct", &ir::type::is_struct_ty)
.def("repr", &ir::type::repr)
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
@@ -704,23 +707,37 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type");
py::class_<ir::function_type, ir::type>(m, "function_type")
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
.def_property_readonly("arg_tys", [](ir::function_type* self){
return std::vector<ir::type*>(self->params_begin(), self->params_end());
});
py::class_<ir::integer_type, ir::type>(m, "integer_type");
py::class_<ir::block_type, ir::type>(m, "block_type")
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
py::class_<ir::struct_type, ir::type>(m, "struct_type")
.def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
py::class_<ir::value_constructor>(m, "value_constructor")
.def(py::init<ir::builder&>())
.def("seal_block", &ir::value_constructor::seal_block)
.def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value)
.def("set_type", &ir::value_constructor::set_type)
.def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference)
.def("get_values", &ir::value_constructor::get_values, ret::reference)
.def("set_values", &ir::value_constructor::set_values);
py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>())
.def("has_function", &ir::module::has_function)
.def("get_function", &ir::module::get_function, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("seal_block", &ir::module::seal_block)
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
.def("set_type", &ir::module::set_type)
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("get_values", &ir::module::get_values, ret::reference)
.def("set_values", &ir::module::set_values)
.def("get_types", &ir::module::get_types, ret::reference)
.def("set_types", &ir::module::set_types)
.def("reset_ret_ty", &ir::module::reset_ret_ty)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;
@@ -734,29 +751,45 @@ void init_triton_ir(py::module &&m) {
.value("not_implemented", eattr::not_implemented);
py::class_<ir::attribute>(m, "attribute")
.def(py::init<eattr, int>());
.def(py::init<eattr, int>())
.def_property_readonly("value", &ir::attribute::get_value);
py::class_<ir::function>(m, "function")
.def_property_readonly("args", &ir::function::args)
.def_property_readonly("attrs", &ir::function::attrs)
.def("add_attr", &ir::function::add_attr);
.def("set_is_kernel", &ir::function::set_is_kernel)
.def("add_attr", &ir::function::add_attr)
.def("has_attr", &ir::function::has_attr)
.def("get_attrs", &ir::function::get_attributes);
py::class_<ir::argument, ir::value>(m, "argument");
py::class_<ir::argument, ir::value>(m, "argument")
.def_property_readonly("parent", &ir::argument::get_parent, ret::reference)
.def_property_readonly("arg_no", &ir::argument::get_arg_no);
py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference)
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder::iterator>(m, "bb_iterator");
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
// control flow
.def("call", &ir::builder::create_call, ret::reference)
.def("launch", &ir::builder::create_launch, ret::reference)
.def("br", &ir::builder::create_br, ret::reference)
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
.def("ret", &ir::builder::create_ret, ret::reference)
.def("get_insert_point", &ir::builder::get_insert_point)
.def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
// struct
.def("insert_value", &ir::builder::create_insert_value, ret::reference)
.def("extract_value", &ir::builder::create_extract_value, ret::reference)
// constants
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)

View File

@@ -585,7 +585,6 @@ def test_f8_f16_roundtrip():
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
print(f16.dtype, f8_output.dtype)
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor)
@@ -1009,8 +1008,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
# Parse out the type of the 'VALUE' parameter from the Triton IR.
triton_ir = pgm.asm['ttir']
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
assert ir_value_type == value_type
@@ -1031,3 +1030,28 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)
# -------------------------
# test dynamic parallelism
# -------------------------
@triton.jit
def mult(x, alpha):
tl.store(x + tl.program_id(0), alpha)
@triton.jit
def stub(X, alpha, grid_0, grid_1, grid_2):
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
def test_dyn_par(cond=True, device='cuda'):
    # Smoke test for dynamic parallelism: runs `stub` on a one-program grid,
    # passing the buffer, the scalar 3.14 and a (n_pids, 1, 1) child grid.
    # NOTE: `cond` is currently unused.
    n_pids = 10
    # output buffer pre-filled with a sentinel value
    x_tri = torch.full((n_pids,), fill_value=-1., device=device)
    launch_grid = (1,)
    stub[launch_grid](x_tri, 3.14, n_pids, 1, 1)
    # No assertion yet; the result is only printed.
    # TODO: compare against a reference once available, e.g.
    #   pids = torch.arange(n_pids, device=device)
    #   x_ref = pids * alpha
    #   triton.testing.assert_almost_equal(x_ref, x_tri)
    print(x_tri)

View File

@@ -21,6 +21,41 @@ import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
def mangle_ty(type):
    """Return the string token used to encode an IR type in a mangled function name.

    The token is built from the type's predicates:
      * pointer        -> 'P' followed by the token of the pointee (`type.element`)
      * integer        -> 'i' followed by the bit width (`type.get_int_width()`)
      * floating point -> 'fp8', 'fp16', 'bf16', 'fp32' or 'fp64'
      * void           -> 'V'
      * block          -> '<elt>S<d0>_<d1>_...S', where <elt> is the token of
                          `type.scalar` and the d's come from `type.shape`

    :param type: an IR type object exposing the `is_*` predicates used below
    :return: the mangled token as a `str`
    :raises AssertionError: if none of the predicates matches
    """
    if type.is_ptr():
        return 'P' + mangle_ty(type.element)
    if type.is_int():
        return 'i' + str(type.get_int_width())
    if type.is_fp8():
        return 'fp8'
    if type.is_fp16():
        return 'fp16'
    if type.is_bf16():
        return 'bf16'
    if type.is_fp32():
        return 'fp32'
    if type.is_fp64():
        return 'fp64'
    if type.is_void():
        return 'V'
    if type.is_block():
        elt = mangle_ty(type.scalar)
        shape = '_'.join(map(str, type.shape))
        return f'{elt}S{shape}S'
    # Raised explicitly (instead of `assert False`) so the failure is not
    # compiled away when Python runs with -O.
    raise AssertionError("Unsupported type")
def mangle_fn(name, arg_tys, constants):
    """Build the mangled name of a function specialization.

    The result has the form ``<name>__<arg types>__<constants>`` where
      * ``<arg types>`` is the '_'-joined `mangle_ty` token of every entry of `arg_tys`;
      * ``<constants>`` is the '_'-joined list of ``<index>c<key>`` entries, in
        increasing index order, with every '.' replaced by '_d_' and every
        single quote replaced by '_sq_'.

    :param name: base name of the function
    :param arg_tys: iterable of IR types of the non-constant arguments
    :param constants: dict mapping argument index -> constant value
    :return: the mangled name as a `str`
    """
    def _constant_key(value):
        # JIT functions passed as constants are identified by their name,
        # any other value by its repr
        if isinstance(value, JITFunction):
            return value.__name__
        return repr(value)

    # doesn't mangle ret type, which must be a function of arg tys
    mangled_arg_names = '_'.join(mangle_ty(ty) for ty in arg_tys)
    mangled_constants = '_'.join(f'{i}c{_constant_key(constants[i])}' for i in sorted(constants))
    # '.' and "'" can appear in reprs (e.g. floats, strings); substitute them
    # with explicit tokens -- presumably to keep the name usable as a symbol
    mangled_constants = mangled_constants.replace('.', '_d_')
    mangled_constants = mangled_constants.replace("'", '_sq_')
    return f'{name}__{mangled_arg_names}__{mangled_constants}'
class CodeGenerator(ast.NodeVisitor):
def get_value(self, name):
# search node.id in local scope
@@ -36,7 +71,7 @@ class CodeGenerator(ast.NodeVisitor):
else:
raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.language.block):
handle = self.module.get_value(name)
handle = self.value_constructor.get_value(name)
return triton.language.block(handle)
return ret
@@ -44,8 +79,8 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(value, _triton.ir.value):
value = triton.language.block(value)
if isinstance(value, triton.language.block):
self.module.set_value(name, value.handle)
self.module.set_type(name, value.handle.type)
self.value_constructor.set_value(name, value.handle)
self.value_constructor.set_type(name, value.handle.type)
self.lscope[name] = value
def is_triton_object(self, value):
@@ -58,16 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
break
return stmts and isinstance(stmt, ast.Return)
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False):
self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder)
self.value_constructor = _triton.ir.value_constructor(self.builder)
self.module = _triton.ir.module('', self.builder) if module is None else module
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.kwargs = kwargs
self.last_node = None
self.is_kernel = is_kernel
self.builtins = {
'range': range,
'min': triton.language.minimum,
@@ -92,9 +128,17 @@ class CodeGenerator(ast.NodeVisitor):
ret = self.visit(node.value)
if ret is None:
return self.builder.ret_void()
return ret
if isinstance(ret, _triton.ir.value):
ret = self.builder.ret(ret)
return ret
if isinstance(ret, triton.language.block):
ret = ret.handle
if isinstance(ret, triton.language.constexpr):
ret = triton.language.core._to_ir(ret, self.builder)
# TODO: should return tl.block
return self.builder.ret(ret)
def visit_FunctionDef(self, node, inline=False, arg_values=None):
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
@@ -107,45 +151,44 @@ class CodeGenerator(ast.NodeVisitor):
else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
if inline:
pass
else:
fn = self.module.get_or_insert_function(node.name, self.prototype)
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
if i in self.attributes:
is_ptr = fn.args[idx].type.is_ptr()
attr = 'aligned' if is_ptr else 'multiple_of'
attr = getattr(_triton.ir.attribute_kind, attr)
attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(idx + 1, attr)
fn.args[idx].name = arg_name
arg_values.append(fn.args[idx])
idx += 1
fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants)
fn = self.module.get_or_insert_function(fn_name, self.prototype)
fn.set_is_kernel(self.is_kernel)
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
if i in self.attributes:
is_ptr = fn.args[idx].type.is_ptr()
attr = 'aligned' if is_ptr else 'multiple_of'
attr = getattr(_triton.ir.attribute_kind, attr)
attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(idx + 1, attr)
fn.args[idx].name = arg_name
arg_values.append(fn.args[idx])
idx += 1
insert_pt = self.builder.get_insert_block()
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.builder.set_insert_block(entry)
self.value_constructor.seal_block(entry)
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
self.visit_compound_statement(node.body)
return self.last_ret
else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.module.seal_block(entry)
self.builder.set_insert_block(entry)
# visit function body
self.visit_compound_statement(node.body)
# finalize function
# visit function body
has_ret = self.visit_compound_statement(node.body)
# finalize
if not has_ret:
self.builder.ret_void()
else:
self.module.reset_ret_ty(fn_name, self.last_ret.type)
# self.module.reset_ret_type(node.name)
self.builder.set_insert_block(insert_pt)
def visit_arguments(self, node):
arg_names = []
@@ -186,6 +229,12 @@ class CodeGenerator(ast.NodeVisitor):
names = [names]
if not isinstance(values, tuple):
values = [values]
if isinstance(values[0], _triton.ir.value):
struct = values[0]
ty = struct.type
if ty.is_struct():
values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)]
assert len(values) == len(names)
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
@@ -215,6 +264,17 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Tuple(self, node):
args = [self.visit(x) for x in node.elts]
mode = type(args[0])
# tuple of values -- create a struct
if len(args) > 1 and mode == triton.language.block\
and all([type(arg) == mode for arg in args]):
args = [arg.handle for arg in args]
tys = [arg.type for arg in args]
struct_ty = _triton.ir.struct_type.get(tys, True)
ret = _triton.ir.undef.get(struct_ty)
for i, arg in enumerate(args):
ret = self.builder.insert_value(ret, arg, i)
return ret
return tuple(args)
def visit_BinOp(self, node):
@@ -254,9 +314,9 @@ class CodeGenerator(ast.NodeVisitor):
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
self.module.seal_block(then_bb)
self.value_constructor.seal_block(then_bb)
if else_bb:
self.module.seal_block(else_bb)
self.value_constructor.seal_block(else_bb)
self.builder.cond_br(cond.handle, then_bb, else_bb)
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
@@ -271,7 +331,7 @@ class CodeGenerator(ast.NodeVisitor):
# TODO: last statement is a terminator?
if not is_terminator:
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
self.value_constructor.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb)
else:
if isinstance(cond, triton.language.constexpr):
@@ -350,9 +410,9 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
self.module.seal_block(loop_bb)
self.module.seal_block(next_bb)
self.value_constructor.seal_block(stop_bb)
self.value_constructor.seal_block(loop_bb)
self.value_constructor.seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
@@ -421,9 +481,9 @@ class CodeGenerator(ast.NodeVisitor):
# TODO: handle case where body breaks control flow
continue_fn()
stop_bb = self.builder.get_insert_block()
self.module.seal_block(stop_bb)
self.module.seal_block(loop_bb)
self.module.seal_block(next_bb)
self.value_constructor.seal_block(stop_bb)
self.value_constructor.seal_block(loop_bb)
self.value_constructor.seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
@@ -449,15 +509,62 @@ class CodeGenerator(ast.NodeVisitor):
for keyword in node.keywords:
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.block)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in arg_vals]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
ret_type = _triton.ir.type.get_void(self.builder.context)
prototype = _triton.ir.type.make_function(ret_type, arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module)
generator.visit(fn.parse())
symbol = self.module.get_function(fn_name)
ret = self.builder.call(symbol, arg_vals)
if not ret.type.is_void() and not ret.type.is_struct():
ret = triton.language.block(ret)
return ret
# built-in function
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
ret = fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
for arg in args]
return fn(*args, **kws)
ret = fn(*args, **kws)
# special case: dynamic parallelism
# in this case the core primitive returns a proxy
# if isinstance(ret, triton.language.core.LaunchProxy):
# ret_type = _triton.ir.type.get_void(self.builder.context)
# arg_tys = [x.type for x in ret.args]
# prototype = _triton.ir.type.make_function(ret_type, arg_tys)
# gscope = sys.modules[ret.fn.fn.__module__].__dict__
# constants = ret.constants
# fn_name = mangle_fn(ret.fn.__name__, arg_tys, ret.constants)
# # TODO: clean-up attributes handling in function
# if not self.module.has_function(fn_name):
# attributes = {i: list(arg.parent.get_attrs(arg))[0].value for i, arg in enumerate(ret.args) \
# if isinstance(arg, _triton.ir.argument) and arg.parent.has_attr(i + 1) }
# generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, is_kernel=True)
# generator.visit(ret.fn.parse())
# symbol = self.module.get_function(fn_name)
# # TODO: should ret.args not include any constants ?
# ret = self.builder.launch(symbol, ret.args, ret.grid, ret.num_warps)
return ret
# return fn(*args, **kws)
def visit_Constant(self, node):
return triton.language.constexpr(node.value)
@@ -669,6 +776,7 @@ class Kernel:
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
attributes = dict()
for i, arg in enumerate(wargs):
@@ -881,7 +989,7 @@ class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None):
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
# information of wrapped function
self.fn = fn
self.module = fn.__module__
@@ -890,6 +998,7 @@ class JITFunction:
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.inline = inline
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
@@ -904,6 +1013,8 @@ class JITFunction:
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# forward docs
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -930,31 +1041,8 @@ class JITFunction:
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
def __call__(self, *args, generator: CodeGenerator, **kwargs):
try:
from inspect import getcallargs
arg_values = getcallargs(self.fn, *args, **kwargs)
arg_values = [arg_values[name] for name in self.arg_names]
arg_values = [arg if isinstance(arg, triton.language.block)
else triton.language.constexpr(arg) for arg in arg_values]
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
types = generator.module.get_types().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
generator.module.set_values(values)
generator.module.set_types(types)
return ret
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel.")
# - when `.src` attribute is set, cache path needs
# to be reinitialized
@@ -1039,7 +1127,7 @@ class JITFunction:
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try:
generator.visit(self.parse())
except Exception as e:
@@ -1199,9 +1287,21 @@ def jit(*args, **kwargs):
return JITFunction(fn, **kwargs)
return decorator
######
# class ForwardDeclaration:
# def __init__(self, name, ret_ty, arg_tys) -> None:
# self.name = name
# self.ret_ty = ret_ty
# self.arg_tys = arg_tys
# def forward_declare(name, ret_ty, arg_tys):
# return ForwardDeclaration(name, ret_ty, arg_tys)
######
def cdiv(x, y):
    """Return ``(x + y - 1) // y``.

    For positive integers this is the quotient of x by y rounded up.
    """
    rounded_up = x + y - 1
    return rounded_up // y

View File

@@ -888,7 +888,7 @@ def sigmoid(x):
@triton.jit
@_add_math_1arg_docstr("softmax")
def softmax(x, ieee_rounding=False):
def softmax(x, ieee_rounding: constexpr = False):
z = x - triton.language.max(x, 0)
num = triton.language.exp(z)
den = triton.language.sum(num, 0)
@@ -942,3 +942,26 @@ def swizzle2d(i, j, size_i, size_j, size_g):
@triton.jit
def zeros_like(input):
return zeros(input.shape, input.dtype)
# -----------------------
# Dynamic Parallelism
# -----------------------
class LaunchProxy:
    """Plain record holding the pieces of a launch request.

    It stores, unmodified, the function to launch, its arguments, its
    constants, the launch grid and the number of warps.
    """

    def __init__(self, fn, args, constants, grid, num_warps) -> None:
        # what to launch
        self.fn = fn
        # what to launch it with
        self.args = args
        self.constants = constants
        # how to launch it
        self.grid = grid
        self.num_warps = num_warps
@builtin
def launch(fn, args, grid, num_warps=None, _builder=None):
    """Package a launch request for `fn` into a `LaunchProxy`.

    :param fn: the function to launch (stored as-is in the proxy)
    :param args: sequence of arguments; `constexpr` entries are collected into a
        dict keyed by their position in `args`, all others are converted
        with `_to_ir`
    :param grid: sequence of grid dimensions, each converted with `_to_ir`
    :param num_warps: number of warps; when None, `_to_ir(4)` is used
    :param _builder: IR builder forwarded to every `_to_ir` call
    :return: a `LaunchProxy` carrying the converted arguments, the constants,
        the converted grid and the number of warps
    """
    constants = dict()
    ir_args = []
    # split compile-time constants from values that need an IR handle
    for position, arg in enumerate(args):
        if isinstance(arg, constexpr):
            constants[position] = arg
        else:
            ir_args.append(_to_ir(arg, builder=_builder))
    ir_grid = [_to_ir(dim, builder=_builder) for dim in grid]
    # fall back to 4 warps when the caller did not specify a count
    if num_warps is None:
        num_warps = _to_ir(4, builder=_builder)
    return LaunchProxy(fn, ir_args, constants, ir_grid, num_warps)