[PYTHON] Renamed triton.core -> triton.language (#92)

This commit is contained in:
Philippe Tillet
2021-04-23 17:18:14 -04:00
committed by Philippe Tillet
parent 41410012e8
commit bfc0a7587d
19 changed files with 355 additions and 243 deletions

View File

@@ -26,21 +26,21 @@ class CodeGenerator(ast.NodeVisitor):
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.block):
if isinstance(ret, triton.language.block):
handle = self.module.get_value(name)
return triton.block(handle)
return triton.language.block(handle)
return ret
def set_value(self, name, value):
if isinstance(value, _triton.ir.value):
value = triton.block(value)
if isinstance(value, triton.block):
value = triton.language.block(value)
if isinstance(value, triton.language.block):
self.module.set_value(name, value.handle)
self.module.scope.set_type(name, value.handle.type)
self.lscope[name] = value
def is_triton_object(self, value):
return isinstance(value, triton.block)
return isinstance(value, triton.language.block)
def visit_compound_statement(self, stmts, add_scope=False):
if add_scope:
@@ -63,7 +63,14 @@ class CodeGenerator(ast.NodeVisitor):
self.constants = constants
self.kwargs = kwargs
self.last_node = None
self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'getattr': getattr,
}
def visit_Module(self, node):
self.module.add_new_scope()
@@ -303,7 +310,7 @@ class CodeGenerator(ast.NodeVisitor):
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.where(self.visit(pos_step_node),\
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
builder=self.builder)
@@ -359,7 +366,7 @@ class CodeGenerator(ast.NodeVisitor):
if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.core:
sys.modules[fn.__module__] is triton.language:
return fn(*args, builder=self.builder, **kws)
return fn(*args, **kws)
@@ -613,6 +620,11 @@ class JITFunction:
raise e
raise CompilationError(self.src, node, e)
def __setattr__(self, name, value):
if name == 'kernel_decorators':
self.kernel = None
super(JITFunction, self).__setattr__(name, value)
def _init_kernel(self):
if self.kernel is None:
self.kernel = Kernel(self)
@@ -659,4 +671,23 @@ def heuristics(values):
def jit(fn):
"""
Decorator for JIT-compiling a function using the Triton compiler.
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
:note: This function will be compiled and run on the GPU. It will only have access to:
* python primitives,
* objects within the triton.language package,
* arguments to this function,
* other jit'd functions
:param fn: the function to be jit-compiled
:type fn: Callable
"""
return JITFunction(fn)
def cdiv(x, y):
return (x + y - 1) // y