From 298aead378d525c361a12b54d36452f32a7c37d1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 21 Jul 2021 15:58:26 -0700 Subject: [PATCH] [FRONTEND] Fixed bugs in global symbols resolution of @triton.jit'd functions (#136) --- python/triton/code_gen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index c0743c268..ae69f876c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -659,9 +659,12 @@ class JITFunction: def __call__(self, *args, generator: CodeGenerator, **meta): try: + gscope = generator.gscope.copy() lscope = generator.lscope.copy() values = generator.module.get_values().copy() + generator.gscope = sys.modules[self.fn.__module__].__dict__ ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + generator.gscope = gscope generator.lscope = lscope generator.module.set_values(values) return ret