[FRONTEND] Fixed bugs in global symbols resolution of @triton.jit'd functions (#136)

This commit is contained in:
Philippe Tillet
2021-07-21 15:58:26 -07:00
committed by Philippe Tillet
parent 94ce6aa80f
commit 298aead378

View File

@@ -659,9 +659,12 @@ class JITFunction:
def __call__(self, *args, generator: CodeGenerator, **meta): def __call__(self, *args, generator: CodeGenerator, **meta):
try: try:
gscope = generator.gscope.copy()
lscope = generator.lscope.copy() lscope = generator.lscope.copy()
values = generator.module.get_values().copy() values = generator.module.get_values().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
generator.gscope = gscope
generator.lscope = lscope generator.lscope = lscope
generator.module.set_values(values) generator.module.set_values(values)
return ret return ret