[PYTHON] Renamed triton.core -> triton.language (#92)
This commit is contained in:
committed by
Philippe Tillet
parent
41410012e8
commit
bfc0a7587d
@@ -26,21 +26,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if isinstance(ret, triton.block):
|
||||
if isinstance(ret, triton.language.block):
|
||||
handle = self.module.get_value(name)
|
||||
return triton.block(handle)
|
||||
return triton.language.block(handle)
|
||||
return ret
|
||||
|
||||
def set_value(self, name, value):
|
||||
if isinstance(value, _triton.ir.value):
|
||||
value = triton.block(value)
|
||||
if isinstance(value, triton.block):
|
||||
value = triton.language.block(value)
|
||||
if isinstance(value, triton.language.block):
|
||||
self.module.set_value(name, value.handle)
|
||||
self.module.scope.set_type(name, value.handle.type)
|
||||
self.lscope[name] = value
|
||||
|
||||
def is_triton_object(self, value):
|
||||
return isinstance(value, triton.block)
|
||||
return isinstance(value, triton.language.block)
|
||||
|
||||
def visit_compound_statement(self, stmts, add_scope=False):
|
||||
if add_scope:
|
||||
@@ -63,7 +63,14 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.constants = constants
|
||||
self.kwargs = kwargs
|
||||
self.last_node = None
|
||||
self.builtins = {'range': range, 'min': triton.minimum, 'float': float, 'int': int, 'print': print, 'getattr': getattr}
|
||||
self.builtins = {
|
||||
'range': range,
|
||||
'min': triton.language.minimum,
|
||||
'float': float,
|
||||
'int': int,
|
||||
'print': print,
|
||||
'getattr': getattr,
|
||||
}
|
||||
|
||||
def visit_Module(self, node):
|
||||
self.module.add_new_scope()
|
||||
@@ -303,7 +310,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
|
||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
|
||||
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
|
||||
build_cond = lambda: triton.where(self.visit(pos_step_node),\
|
||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
||||
self.visit(pos_cond_node),\
|
||||
self.visit(neg_cond_node),\
|
||||
builder=self.builder)
|
||||
@@ -359,7 +366,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if isinstance(fn, JITFunction):
|
||||
return fn(*args, generator=self, **kws)
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.core:
|
||||
sys.modules[fn.__module__] is triton.language:
|
||||
return fn(*args, builder=self.builder, **kws)
|
||||
return fn(*args, **kws)
|
||||
|
||||
@@ -613,6 +620,11 @@ class JITFunction:
|
||||
raise e
|
||||
raise CompilationError(self.src, node, e)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == 'kernel_decorators':
|
||||
self.kernel = None
|
||||
super(JITFunction, self).__setattr__(name, value)
|
||||
|
||||
def _init_kernel(self):
|
||||
if self.kernel is None:
|
||||
self.kernel = Kernel(self)
|
||||
@@ -659,4 +671,23 @@ def heuristics(values):
|
||||
|
||||
|
||||
def jit(fn):
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
|
||||
:note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method.
|
||||
|
||||
:note: This function will be compiled and run on the GPU. It will only have access to:
|
||||
|
||||
* python primitives,
|
||||
* objects within the triton.language package,
|
||||
* arguments to this function,
|
||||
* other jit'd functions
|
||||
|
||||
:param fn: the function to be jit-compiled
|
||||
:type fn: Callable
|
||||
"""
|
||||
return JITFunction(fn)
|
||||
|
||||
|
||||
def cdiv(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
Reference in New Issue
Block a user