[FRONTEND] Bunch of fixes here and there (#436)
This commit is contained in:
@@ -651,7 +651,7 @@ class Kernel:
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self.fn into code-generator object
|
||||
gscope = self.fn.fn.__globals__
|
||||
gscope = self.fn.__globals__
|
||||
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
|
||||
try:
|
||||
generator.visit(self.fn.parse())
|
||||
@@ -723,7 +723,7 @@ class Kernel:
|
||||
pickle.dump({"binary": binary, "key": key}, f)
|
||||
os.rename(bin_cache_path + ".tmp", bin_cache_path)
|
||||
if JITFunction.cache_hook is not None:
|
||||
name = self.fn.fn.__name__
|
||||
name = self.fn.__name__
|
||||
info = key.split('-')[-3:]
|
||||
num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:]
|
||||
# make signature human-readable
|
||||
@@ -885,8 +885,6 @@ def version_key():
|
||||
ptxas_version = ''
|
||||
return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents)
|
||||
|
||||
# 3
|
||||
|
||||
|
||||
class DependenciesFinder(ast.NodeVisitor):
|
||||
|
||||
@@ -910,17 +908,14 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
func = self.visit(node.func)
|
||||
if func is None:
|
||||
return
|
||||
if isinstance(func, triton.JITFunction):
|
||||
func = func.fn
|
||||
module = inspect.getmodule(func)
|
||||
if module and module.__name__.startswith('triton.'):
|
||||
return
|
||||
if inspect.isbuiltin(func):
|
||||
return
|
||||
if not hasattr(func, 'hash'):
|
||||
src = textwrap.dedent(inspect.getsource(func))
|
||||
tree = ast.parse(src)
|
||||
finder = DependenciesFinder(func.__globals__, src)
|
||||
if func.__module__ and func.__module__.startswith('triton.'):
|
||||
return
|
||||
assert isinstance(func, triton.JITFunction)
|
||||
if func.hash is None:
|
||||
tree = ast.parse(func.src)
|
||||
finder = DependenciesFinder(func.__globals__, func.src)
|
||||
finder.visit(tree)
|
||||
func.hash = finder.ret
|
||||
self.ret = (self.ret + func.hash).encode("utf-8")
|
||||
@@ -941,10 +936,12 @@ class JITFunction:
|
||||
|
||||
self.version = version
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.do_not_specialize = [] if do_not_specialize is None else\
|
||||
[self.arg_names.index(arg) for arg in do_not_specialize]
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
|
||||
self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize]
|
||||
# cache for callable driver objects (e.g. CUkernel)
|
||||
self.bin_cache = dict()
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
@@ -954,15 +951,19 @@ class JITFunction:
|
||||
self.__annotations__ = fn.__annotations__
|
||||
# forward docs
|
||||
self.__doc__ = fn.__doc__
|
||||
self.__name__ = fn.__name__
|
||||
self.__globals__ = fn.__globals__
|
||||
self.__module__ = fn.__module__
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def cache_key(self):
|
||||
if not hasattr(self.fn, 'hash'):
|
||||
dependencies_finder = DependenciesFinder(globals=self.fn.__globals__, src=self.src)
|
||||
# TODO : hash should be attribute of `self`
|
||||
if self.hash is None:
|
||||
dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src)
|
||||
dependencies_finder.visit(self.parse())
|
||||
self.fn.hash = dependencies_finder.ret + version_key()
|
||||
return self.fn.hash
|
||||
self.hash = dependencies_finder.ret + version_key()
|
||||
return self.hash
|
||||
|
||||
# we do not parse `src` in the constructor because
|
||||
# the user might want to monkey-patch self.src dynamically.
|
||||
@@ -974,14 +975,20 @@ class JITFunction:
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
def __call__(self, *args, generator: CodeGenerator):
|
||||
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
||||
try:
|
||||
from inspect import getcallargs
|
||||
arg_values = getcallargs(self.fn, *args, **kwargs)
|
||||
arg_values = [arg_values[name] for name in self.arg_names]
|
||||
arg_values = [arg if isinstance(arg, triton.language.block)
|
||||
else triton.language.constexpr(arg) for arg in arg_values]
|
||||
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
values = generator.module.get_values().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
generator.module.set_values(values)
|
||||
@@ -1001,8 +1008,7 @@ class JITFunction:
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
if name == 'src':
|
||||
if hasattr(self.fn, 'hash'):
|
||||
delattr(self.fn, 'hash')
|
||||
self.hash = None
|
||||
JITFunction.cache_key.fget.cache_clear()
|
||||
|
||||
def _init_kernel(self):
|
||||
|
Reference in New Issue
Block a user