diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8622333bf..90018028d 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -305,9 +305,6 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) - def visit_Str(self, node): - return ast.literal_eval(node) - def visit_Subscript(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) @@ -374,9 +371,6 @@ class CodeGenerator(ast.NodeVisitor): def visit_Index(self, node): return self.visit(node.value) - def visit_NameConstant(self, node): - return node.value - def visit_keyword(self, node): return {node.arg: self.visit(node.value)} @@ -393,8 +387,18 @@ class CodeGenerator(ast.NodeVisitor): return fn(*args, _builder=self.builder, **kws) return fn(*args, **kws) - def visit_Num(self, node): - return node.n + def visit_Constant(self, node): + return node.value + + if sys.version_info < (3, 8): + def visit_NameConstant(self, node): + return node.value + + def visit_Num(self, node): + return node.n + + def visit_Str(self, node): + return ast.literal_eval(node) def visit_Attribute(self, node): lhs = self.visit(node.value)