[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)
This commit is contained in:
@@ -93,6 +93,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
# initialize defaults
|
||||
for i, default_value in enumerate(node.args.defaults):
|
||||
arg_node = node.args.args[-i-1]
|
||||
annotation = arg_node.annotation
|
||||
name = arg_node.arg
|
||||
st_target = ast.Name(id=name, ctx=ast.Store())
|
||||
if annotation is None:
|
||||
init_node = ast.Assign(targets=[st_target], value=default_value)
|
||||
else:
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# store keyword arguments in local scope
|
||||
self.lscope[kwarg_names] = self.kwargs
|
||||
# initialize function
|
||||
@@ -353,6 +364,20 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iterator = self.visit(node.iter.func)
|
||||
if iterator != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||
if is_static:
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
iter_args = [arg.value for arg in iter_args]
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
for i in iterator(*iter_args):
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
# create nodes
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||
@@ -483,6 +508,7 @@ class CompilationError(Exception):
|
||||
self.message += '\n' + ' ' * node.col_offset + '^'
|
||||
self.message += '\n Error: ' + str(err)
|
||||
super().__init__(self.message)
|
||||
self.args = (src, node, err)
|
||||
|
||||
|
||||
class OutOfResources(Exception):
|
||||
@@ -491,6 +517,7 @@ class OutOfResources(Exception):
|
||||
f'Required: {required}'\
|
||||
f'Hardware limit: {limit}'
|
||||
super().__init__(self.message)
|
||||
self.args = (required, limit, name)
|
||||
|
||||
|
||||
class Kernel:
|
||||
@@ -805,7 +832,10 @@ class JITFunction:
|
||||
# information of wrapped function
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.arg_names = inspect.getfullargspec(fn).args
|
||||
signature = inspect.signature(fn)
|
||||
self.arg_names = [v.name for v in signature.parameters.values()]
|
||||
self.arg_defaults = [v.default for v in signature.parameters.values()]
|
||||
|
||||
self.version = version
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.do_not_specialize = [] if do_not_specialize is None else\
|
||||
@@ -829,7 +859,7 @@ class JITFunction:
|
||||
if not hasattr(self.fn, 'hash'):
|
||||
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.fn.hash = dependencies_finder.ret
|
||||
self.fn.hash = dependencies_finder.ret + version_key()
|
||||
return self.fn.hash
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
@@ -848,6 +878,7 @@ class JITFunction:
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
|
Reference in New Issue
Block a user