[FRONTEND] Fixed bugs in global symbols resolution of @triton.jit'd functions (#136)
This commit is contained in:
committed by
Philippe Tillet
parent
94ce6aa80f
commit
298aead378
@@ -659,9 +659,12 @@ class JITFunction:
|
|||||||
|
|
||||||
def __call__(self, *args, generator: CodeGenerator, **meta):
|
def __call__(self, *args, generator: CodeGenerator, **meta):
|
||||||
try:
|
try:
|
||||||
|
gscope = generator.gscope.copy()
|
||||||
lscope = generator.lscope.copy()
|
lscope = generator.lscope.copy()
|
||||||
values = generator.module.get_values().copy()
|
values = generator.module.get_values().copy()
|
||||||
|
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
||||||
|
generator.gscope = gscope
|
||||||
generator.lscope = lscope
|
generator.lscope = lscope
|
||||||
generator.module.set_values(values)
|
generator.module.set_values(values)
|
||||||
return ret
|
return ret
|
||||||
|
Reference in New Issue
Block a user