Init commit
This commit is contained in:
@@ -906,12 +906,7 @@ void init_triton_ir(py::module &&m) {
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
.def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
|
||||
.def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
|
||||
.def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference)
|
||||
.def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
|
||||
.def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference);
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
|
@@ -819,6 +819,49 @@ class Kernel:
|
||||
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
|
||||
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
|
||||
|
||||
# Compile to ttir, for the purpose of testing MLIR rewriting
|
||||
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
    """Run the code generator over this kernel and return the generated module.

    The positional call arguments ``wargs`` are inspected to derive:

    * ``attributes`` -- for every index not listed in
      ``self.fn.do_not_specialize``: ``Kernel.pow2_divisor(arg)`` for ints,
      and for arguments exposing ``data_ptr`` the smaller of the
      ``pow2_divisor`` of the address and of the pointer range size.
    * ``constants`` -- ints equal to 1 (unless excluded from
      specialization), the ``.value`` of ``triton.language.constexpr``
      arguments, and ``None`` arguments.

    Arguments that did not become constants make up the prototype handed to
    ``CodeGenerator``.

    ``grid``, ``num_warps``, ``num_stages`` and ``**kwargs`` are accepted but
    not read anywhere in this method.

    Returns:
        ``generator.module`` once the visit has completed.

    Raises:
        CompilationError: wraps any other exception raised during the visit
            when the generator recorded a ``last_node``.
        Exception: ``NotImplementedError``, ``CompilationError`` and any
            error raised while ``last_node`` is ``None`` are re-raised as is.
    """
    # TODO: share code with _compile & __call__

    # Indices of the arguments that expose a ``data_ptr`` attribute.
    pointer_positions = {pos for pos, value in enumerate(wargs) if hasattr(value, 'data_ptr')}

    # Per-argument attributes, skipping arguments excluded from specialization.
    attributes = {}
    for pos, value in enumerate(wargs):
        if pos in self.fn.do_not_specialize:
            continue
        if isinstance(value, int):
            attributes[pos] = Kernel.pow2_divisor(value)
            continue
        if pos not in pointer_positions:
            continue
        address = value.data_ptr()
        span = _triton.runtime.get_pointer_range_size(address)
        attributes[pos] = min(Kernel.pow2_divisor(address),
                              Kernel.pow2_divisor(span))

    # Constants are collected in three separate passes so that the insertion
    # order of the dictionary stays: ints equal to one, constexprs, Nones.
    constants = {}
    for pos, value in enumerate(wargs):
        if isinstance(value, int) and value == 1 and pos not in self.fn.do_not_specialize:
            constants[pos] = value
    for pos, value in enumerate(wargs):
        if isinstance(value, triton.language.constexpr):
            constants[pos] = value.value
    for pos, value in enumerate(wargs):
        if value is None:
            constants[pos] = None

    # Only the arguments that were not turned into constants stay in the prototype.
    python_types = [Kernel._to_python_ir(value) for pos, value in enumerate(wargs) if pos not in constants]

    # The IR context is created before the argument types are lowered.
    context = _triton.ir.context()
    ir_types = [Kernel._to_triton_ir(entry) for entry in python_types]
    prototype = triton.language.function_type(triton.language.void, ir_types)

    # NOTE(review): the rest of this method reaches the JIT function through
    # ``self.fn`` (e.g. ``self.fn.do_not_specialize``), while the lines below
    # read ``__globals__``, ``parse()`` and ``src`` from ``self`` -- confirm
    # that these are actually available on this object.
    generator = CodeGenerator(
        context,
        prototype,
        gscope=self.__globals__,
        attributes=attributes,
        constants=constants,
        kwargs={},
    )
    try:
        generator.visit(self.parse())
    except Exception as err:
        failing_node = generator.last_node
        # Wrap only when a node is known and the error is not already one of
        # the pass-through kinds.
        if failing_node is not None and not isinstance(err, (NotImplementedError, CompilationError)):
            raise CompilationError(self.src, failing_node) from err
        raise err
    return generator.module
|
||||
|
||||
|
||||
class Launcher:
|
||||
def __init__(self, kernel, grid):
|
||||
|
Reference in New Issue
Block a user