Device function & PassManager
This commit is contained in:
@@ -8,6 +8,10 @@
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
||||
#include "triton/ir/Dialect.h"
|
||||
#include "triton/ir/Types.h"
|
||||
|
||||
@@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
|
||||
self->setAttr(name, attr);
|
||||
})
|
||||
.def("get_num_results", [](mlir::OpState &self) -> unsigned {
|
||||
return self->getNumResults();
|
||||
})
|
||||
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
|
||||
return self->getResult(idx);
|
||||
})
|
||||
@@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
|
||||
|
||||
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
|
||||
.def("dump", [](mlir::ModuleOp &self) -> void {
|
||||
self.dump();
|
||||
})
|
||||
.def("dump", &mlir::ModuleOp::dump)
|
||||
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
||||
self.push_back(funcOp);
|
||||
})
|
||||
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
|
||||
if (self.lookupSymbol(funcName))
|
||||
return true;
|
||||
return false;
|
||||
})
|
||||
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
|
||||
return self.lookupSymbol<mlir::FuncOp>(funcName);
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
|
||||
@@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
||||
return self.addEntryBlock();
|
||||
}, ret::reference)
|
||||
.def("reset_type", &mlir::FuncOp::setType)
|
||||
;
|
||||
|
||||
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
@@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::ModuleOp>(loc);
|
||||
})
|
||||
// control flow
|
||||
.def("ret_void", [](mlir::OpBuilder &self) {
|
||||
.def("ret", [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::ReturnOp>(loc);
|
||||
}, ret::reference)
|
||||
self.create<mlir::ReturnOp>(loc, vals);
|
||||
})
|
||||
.def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector<mlir::Value> &args) -> mlir::OpState {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::CallOp>(loc, func, args);
|
||||
})
|
||||
// insertion block/point
|
||||
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
|
||||
self.setInsertionPointToStart(&block);
|
||||
@@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) {
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
.def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
||||
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
|
||||
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
||||
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
})
|
||||
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
||||
mlir::Region *parent = self.getBlock()->getParent();
|
||||
return self.createBlock(parent);
|
||||
@@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
// .def("create_barrier", &ir::builder::create_barrier, ret::reference);
|
||||
;
|
||||
|
||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||
.def(py::init<mlir::MLIRContext *>())
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
|
||||
self.run(mod.getOperation());
|
||||
})
|
||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createInlinerPass());
|
||||
})
|
||||
;
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
|
@@ -14,6 +14,7 @@ import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from typing import Dict, Optional, Set, Tuple, Union
|
||||
from numpy import isin
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -22,6 +23,41 @@ import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from .tools.disasm import extract
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
if ty.is_int():
|
||||
return 'i' + str(ty.int_bitwidth)
|
||||
if ty.is_fp8():
|
||||
return 'fp8'
|
||||
if ty.is_fp16():
|
||||
return 'fp16'
|
||||
if ty.is_bf16():
|
||||
return 'bf16'
|
||||
if ty.is_fp32():
|
||||
return 'fp32'
|
||||
if ty.is_fp64():
|
||||
return 'fp64'
|
||||
if ty.is_void():
|
||||
return 'V'
|
||||
if ty.is_block():
|
||||
elt = mangle_ty(ty.scalar)
|
||||
shape = '_'.join(map(str, ty.shape))
|
||||
return f'{elt}S{shape}S'
|
||||
assert False, "Unsupport type"
|
||||
|
||||
|
||||
def mangle_fn(name, arg_tys, constants):
|
||||
# doesn't mangle ret type, which must be a function of arg tys
|
||||
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
|
||||
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
|
||||
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
|
||||
mangled_constants = mangled_constants.replace('.', '_d_')
|
||||
mangled_constants = mangled_constants.replace("'", '_sq_')
|
||||
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||
return ret
|
||||
|
||||
class enter_sub_region:
|
||||
def __init__(self, generator: CodeGenerator):
|
||||
self.generator = generator
|
||||
@@ -40,15 +76,16 @@ class enter_sub_region:
|
||||
self.generator.local_defs = self.prev_defs
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = self.builder.create_module()
|
||||
self.module = self.builder.create_module() if module is None else module
|
||||
self.function_ret_types = function_types
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.kwargs = kwargs
|
||||
self.is_kernel = is_kernel
|
||||
self.last_node = None
|
||||
self.builtins = {
|
||||
'range': range,
|
||||
@@ -104,7 +141,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
#
|
||||
def visit_compound_statement(self, stmts):
|
||||
for stmt in stmts:
|
||||
self.last_ret = self.visit(stmt)
|
||||
self.last_ret_type = self.visit(stmt)
|
||||
if isinstance(stmt, ast.Return):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
@@ -120,12 +157,18 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
# By design, only non-kernel functions can return
|
||||
def visit_Return(self, node):
|
||||
ret = self.visit(node.value)
|
||||
if ret is None:
|
||||
return self.builder.ret_void()
|
||||
return ret
|
||||
ret_value = self.visit(node.value)
|
||||
if ret_value is None:
|
||||
return self.builder.ret([])
|
||||
if isinstance(ret_value, list):
|
||||
assert False, "returing a tuple is not supported"
|
||||
else:
|
||||
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
||||
ret_type = ret.type
|
||||
self.builder.ret([ret_value.handle])
|
||||
return ret_type
|
||||
|
||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
||||
def visit_FunctionDef(self, node):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
@@ -138,46 +181,51 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# store keyword arguments in local scope
|
||||
self.lscope[kwarg_names] = self.kwargs
|
||||
# initialize function
|
||||
if inline:
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.visit_compound_statement(node.body)
|
||||
return self.last_ret
|
||||
else:
|
||||
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
pass
|
||||
# TODO: ...
|
||||
# if i in self.attributes:
|
||||
# is_ptr = fn.args[idx].type.is_ptr()
|
||||
# attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
# attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
# fn.add_attr(idx + 1, attr)
|
||||
# fn.args[idx].name = arg_name
|
||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
|
||||
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
if not isinstance(cst, triton.language.constexpr):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
pass
|
||||
# TODO: ...
|
||||
# if i in self.attributes:
|
||||
# is_ptr = fn.args[idx].type.is_ptr()
|
||||
# attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
# attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
# fn.add_attr(idx + 1, attr)
|
||||
# fn.args[idx].name = arg_name
|
||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.builder.set_insertion_point_to_start(entry)
|
||||
# visit function body
|
||||
self.visit_compound_statement(node.body)
|
||||
# finalize function
|
||||
self.builder.ret_void()
|
||||
insert_pt = self.builder.get_insertion_block()
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.builder.set_insertion_point_to_start(entry)
|
||||
# visit function body
|
||||
has_ret = self.visit_compound_statement(node.body)
|
||||
# finalize function
|
||||
if not has_ret:
|
||||
self.builder.ret([])
|
||||
else:
|
||||
# update return type
|
||||
if isinstance(self.last_ret_type, tuple):
|
||||
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
else:
|
||||
self.prototype.ret_types = [self.last_ret_type]
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
if insert_pt:
|
||||
self.builder.set_insertion_point_to_end(insert_pt)
|
||||
|
||||
def visit_arguments(self, node):
|
||||
arg_names = []
|
||||
@@ -219,7 +267,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
for name, value in zip(names, values):
|
||||
# TODO: can we store constexpr here to support constant folding?
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
@@ -581,7 +628,40 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
kws.update(self.visit(keyword))
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if isinstance(fn, JITFunction):
|
||||
return fn(*args, generator=self, **kws)
|
||||
from inspect import getcallargs
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||
constants = {i: args[i] for i in constexprs}
|
||||
# generate call
|
||||
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
|
||||
arg_vals = [arg.handle for arg in args if arg is not None]
|
||||
arg_types = [arg.type for arg in args if arg is not None]
|
||||
fn_name = mangle_fn(fn.__name__, arg_types, constants)
|
||||
# generate function def if necessary
|
||||
if not self.module.has_function(fn_name):
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type([ret_type], arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
else:
|
||||
callee_ret_type = self.function_ret_types[fn_name]
|
||||
symbol = self.module.get_function(fn_name)
|
||||
call_op = self.builder.call(symbol, arg_vals)
|
||||
if call_op.get_num_results() == 0:
|
||||
return None
|
||||
elif call_op.get_num_results() == 1:
|
||||
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
|
||||
else:
|
||||
# should return a tuple of tl.tensor
|
||||
raise RuntimeError("Not implemented")
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
@@ -1012,7 +1092,7 @@ class JITFunction:
|
||||
|
||||
cache_hook = None
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
|
||||
# information of wrapped function
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
@@ -1021,6 +1101,7 @@ class JITFunction:
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
self.version = version
|
||||
self.inline = inline
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
@@ -1035,6 +1116,8 @@ class JITFunction:
|
||||
# annotations
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# constexprs
|
||||
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
|
||||
# forward docs
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
@@ -1061,32 +1144,8 @@ class JITFunction:
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
# Called by CodeGenerator.visit_Call()
|
||||
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
||||
try:
|
||||
from inspect import getcallargs
|
||||
arg_values = getcallargs(self.fn, *args, **kwargs)
|
||||
arg_values = [arg_values[name] for name in self.arg_names]
|
||||
arg_values = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in arg_values]
|
||||
|
||||
# Record values in the caller (parent scope)
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
|
||||
# TODO: clear values other than args
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||
raise e
|
||||
raise CompilationError(self.src, node) from e
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
|
||||
|
||||
# - when `.src` attribute is set, cache path needs
|
||||
# to be reinitialized
|
||||
@@ -1172,7 +1231,7 @@ class JITFunction:
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
|
||||
try:
|
||||
generator.visit(self.parse())
|
||||
except Exception as e:
|
||||
@@ -1232,7 +1291,7 @@ class JITFunction:
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
|
||||
try:
|
||||
generator.visit(self.parse())
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user