[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)

This commit is contained in:
Philippe Tillet
2021-11-29 19:11:26 -08:00
committed by GitHub
parent 1296eb877b
commit c86ad9c9ab
4 changed files with 149 additions and 122 deletions

View File

@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
def visit_FunctionDef(self, node, inline=False, arg_values=None):
arg_names, kwarg_names = self.visit(node.args)
# initialize defaults
for i, default_value in enumerate(node.args.defaults):
arg_node = node.args.args[-i-1]
annotation = arg_node.annotation
name = arg_node.arg
st_target = ast.Name(id=name, ctx=ast.Store())
if annotation is None:
init_node = ast.Assign(targets=[st_target], value=default_value)
else:
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node)
# store keyword arguments in local scope
self.lscope[kwarg_names] = self.kwargs
# initialize function
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
iterator = self.visit(node.iter.func)
if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# static for loops: all iterator arguments are constexpr
iter_args = [self.visit(arg) for arg in node.iter.args]
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
for i in iterator(*iter_args):
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
# create nodes
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
@@ -483,6 +508,7 @@ class CompilationError(Exception):
self.message += '\n' + ' ' * node.col_offset + '^'
self.message += '\n Error: ' + str(err)
super().__init__(self.message)
self.args = (src, node, err)
class OutOfResources(Exception):
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
f'Required: {required}'\
f'Hardware limit: {limit}'
super().__init__(self.message)
self.args = (required, limit, name)
class Kernel:
@@ -805,7 +832,10 @@ class JITFunction:
# information of wrapped function
self.fn = fn
self.module = fn.__module__
self.arg_names = inspect.getfullargspec(fn).args
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.arg_defaults = [v.default for v in signature.parameters.values()]
self.version = version
self.src = textwrap.dedent(inspect.getsource(fn))
self.do_not_specialize = [] if do_not_specialize is None else\
@@ -829,7 +859,7 @@ class JITFunction:
if not hasattr(self.fn, 'hash'):
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
dependencies_finder.visit(self.parse())
self.fn.hash = dependencies_finder.ret
self.fn.hash = dependencies_finder.ret + version_key()
return self.fn.hash
# we do not parse `src` in the constructor because
@@ -848,6 +878,7 @@ class JITFunction:
lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
generator.gscope = gscope
generator.lscope = lscope