diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 6a301fd1a..92dc5c8cc 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -146,7 +146,7 @@ class CodeGenerator(ast.NodeVisitor): values = [values] for name, value in zip(names, values): if not isinstance(value, triton.language.block): - value = triton.language._to_ir(value, self.builder) + value = triton.language.core._to_ir(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -383,7 +383,7 @@ class CodeGenerator(ast.NodeVisitor): if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language: + sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) return fn(*args, **kws) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py new file mode 100644 index 000000000..ab1daeb41 --- /dev/null +++ b/python/triton/language/__init__.py @@ -0,0 +1,2 @@ +from . import core +from .core import * \ No newline at end of file diff --git a/python/triton/language.py b/python/triton/language/core.py similarity index 99% rename from python/triton/language.py rename to python/triton/language/core.py index 6b0100274..ff243bc5e 100644 --- a/python/triton/language.py +++ b/python/triton/language/core.py @@ -25,7 +25,7 @@ def _patch(fn): if x.type.is_void(): return None return block(x) - return tl + return x def wrapper(*args, **kwargs): builder = args[-1]