Device function & PassManager

This commit is contained in:
Yan Da
2022-04-15 14:41:57 +08:00
parent 44d75cf9bb
commit 1c52bd587d
5 changed files with 464 additions and 92 deletions

View File

@@ -8,6 +8,10 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
@@ -717,6 +721,9 @@ void init_triton_ir(py::module &&m) {
.def("set_attr", [](mlir::OpState &self, std::string &name, mlir::Attribute &attr) -> void {
self->setAttr(name, attr);
})
.def("get_num_results", [](mlir::OpState &self) -> unsigned {
return self->getNumResults();
})
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
return self->getResult(idx);
})
@@ -755,12 +762,18 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module")
.def("dump", [](mlir::ModuleOp &self) -> void {
self.dump();
})
.def("dump", &mlir::ModuleOp::dump)
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
})
;
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
@@ -772,6 +785,7 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
.def("reset_type", &mlir::FuncOp::setType)
;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -784,11 +798,14 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::ModuleOp>(loc);
})
// control flow
.def("ret_void", [](mlir::OpBuilder &self) {
.def("ret", [](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::ReturnOp>(loc);
}, ret::reference)
self.create<mlir::ReturnOp>(loc, vals);
})
.def("call", [](mlir::OpBuilder &self, mlir::FuncOp &func, std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
return self.create<mlir::CallOp>(loc, func, args);
})
// insertion block/point
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
self.setInsertionPointToStart(&block);
@@ -900,6 +917,16 @@ void init_triton_ir(py::module &&m) {
}
throw std::runtime_error("invalid function type");
})
.def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType) -> mlir::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
return self.create<mlir::FuncOp>(loc, funcName, funcTy);
}
throw std::runtime_error("invalid function type");
})
.def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* {
mlir::Region *parent = self.getBlock()->getParent();
return self.createBlock(parent);
@@ -1293,6 +1320,16 @@ void init_triton_ir(py::module &&m) {
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
// .def("create_barrier", &ir::builder::create_barrier, ret::reference);
;
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
self.run(mod.getOperation());
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})
;
}
void init_triton(py::module &m) {

View File

@@ -14,6 +14,7 @@ import textwrap
import time
import warnings
from typing import Dict, Optional, Set, Tuple, Union
from numpy import isin
import torch
from filelock import FileLock
@@ -22,6 +23,41 @@ import triton
import triton._C.libtriton.triton as _triton
from .tools.disasm import extract
def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
return 'fp16'
if ty.is_bf16():
return 'bf16'
if ty.is_fp32():
return 'fp32'
if ty.is_fp64():
return 'fp64'
if ty.is_void():
return 'V'
if ty.is_block():
elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S'
assert False, "Unsupport type"
def mangle_fn(name, arg_tys, constants):
# doesn't mangle ret type, which must be a function of arg tys
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
mangled_constants = mangled_constants.replace('.', '_d_')
mangled_constants = mangled_constants.replace("'", '_sq_')
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret
class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
@@ -40,15 +76,16 @@ class enter_sub_region:
self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module()
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.kwargs = kwargs
self.is_kernel = is_kernel
self.last_node = None
self.builtins = {
'range': range,
@@ -104,7 +141,7 @@ class CodeGenerator(ast.NodeVisitor):
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret = self.visit(stmt)
self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
@@ -120,12 +157,18 @@ class CodeGenerator(ast.NodeVisitor):
# By design, only non-kernel functions can return
def visit_Return(self, node):
ret = self.visit(node.value)
if ret is None:
return self.builder.ret_void()
return ret
ret_value = self.visit(node.value)
if ret_value is None:
return self.builder.ret([])
if isinstance(ret_value, list):
assert False, "returing a tuple is not supported"
else:
ret = triton.language.core._to_tensor(ret_value, self.builder)
ret_type = ret.type
self.builder.ret([ret_value.handle])
return ret_type
def visit_FunctionDef(self, node, inline=False, arg_values=None):
def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
@@ -138,46 +181,51 @@ class CodeGenerator(ast.NodeVisitor):
else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
if inline:
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.visit_compound_statement(node.body)
return self.last_ret
else:
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
pass
# TODO: ...
# if i in self.attributes:
# is_ptr = fn.args[idx].type.is_ptr()
# attr = 'aligned' if is_ptr else 'multiple_of'
# attr = getattr(_triton.ir.attribute_kind, attr)
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
if not isinstance(cst, triton.language.constexpr):
cst = triton.language.constexpr(self.constants[i])
arg_values.append(cst)
else:
pass
# TODO: ...
# if i in self.attributes:
# is_ptr = fn.args[idx].type.is_ptr()
# attr = 'aligned' if is_ptr else 'multiple_of'
# attr = getattr(_triton.ir.attribute_kind, attr)
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
self.visit_compound_statement(node.body)
# finalize function
self.builder.ret_void()
insert_pt = self.builder.get_insertion_block()
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
has_ret = self.visit_compound_statement(node.body)
# finalize function
if not has_ret:
self.builder.ret([])
else:
# update return type
if isinstance(self.last_ret_type, tuple):
self.prototype.ret_types = [ret_tensor.type for ret_tensor in self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
else:
self.prototype.ret_types = [self.last_ret_type]
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
def visit_arguments(self, node):
arg_names = []
@@ -219,7 +267,6 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(values, tuple):
values = [values]
for name, value in zip(names, values):
# TODO: can we store constexpr here to support constant folding?
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
@@ -581,7 +628,40 @@ class CodeGenerator(ast.NodeVisitor):
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
constants = {i: args[i] for i in constexprs}
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
else:
callee_ret_type = self.function_ret_types[fn_name]
symbol = self.module.get_function(fn_name)
call_op = self.builder.call(symbol, arg_vals)
if call_op.get_num_results() == 0:
return None
elif call_op.get_num_results() == 1:
return triton.language.tensor(call_op.get_result(0), callee_ret_type)
else:
# should return a tuple of tl.tensor
raise RuntimeError("Not implemented")
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws)
@@ -1012,7 +1092,7 @@ class JITFunction:
cache_hook = None
def __init__(self, fn, version=None, do_not_specialize=None):
def __init__(self, fn, version=None, inline=True, do_not_specialize=None):
# information of wrapped function
self.fn = fn
self.module = fn.__module__
@@ -1021,6 +1101,7 @@ class JITFunction:
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.inline = inline
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
@@ -1035,6 +1116,8 @@ class JITFunction:
# annotations
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
# forward docs
self.__doc__ = fn.__doc__
self.__name__ = fn.__name__
@@ -1061,32 +1144,8 @@ class JITFunction:
assert isinstance(tree.body[0], ast.FunctionDef)
return tree
# Called by CodeGenerator.visit_Call()
def __call__(self, *args, generator: CodeGenerator, **kwargs):
try:
from inspect import getcallargs
arg_values = getcallargs(self.fn, *args, **kwargs)
arg_values = [arg_values[name] for name in self.arg_names]
arg_values = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in arg_values]
# Record values in the caller (parent scope)
gscope = generator.gscope.copy()
lscope = generator.lscope.copy()
# TODO: clear values other than args
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope
generator.lscope = lscope
return ret
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
def __call__(self, *args, **kwargs):
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
# - when `.src` attribute is set, cache path needs
# to be reinitialized
@@ -1172,7 +1231,7 @@ class JITFunction:
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try:
generator.visit(self.parse())
except Exception as e:
@@ -1232,7 +1291,7 @@ class JITFunction:
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True)
try:
generator.visit(self.parse())
except Exception as e: