Now vecadd works
@@ -67,9 +67,10 @@ class CodeGenerator(ast.NodeVisitor):
        elif name in self.builtins:
            ret = self.builtins[name]
        else:
            print(self.lscope)
            raise ValueError(f'{name} is not defined')
        if self.is_triton_tensor(ret):
            return self._get_tensor(name)
            return self._get_tensor(name, self.builder.get_insertion_block())
        return ret

    def set_value(self, name: str,
@@ -86,12 +87,15 @@ class CodeGenerator(ast.NodeVisitor):
    #
    # SSA-construction
    #
    def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
    def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
        if not bb:
            bb = self.builder.get_insertion_block()
        # local value numbering
        if (name, bb) in self.lvalues:
            return self.lvalues[(name, bb)]
        # param. FIXME: should delete this
        if (name, None) in self.lvalues:
            return self.lvalues[(name, None)]
        print(self.lvalues)
        assert False, f'Cannot find {name} in {bb}'
    # global value numbering
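
The (name, basic_block) keying above is the local-value-numbering half of SSA
construction: each variable is looked up per block, with a (name, None) entry
as the fallback for function parameters (the "FIXME" case). A minimal,
self-contained sketch of that idea; the names Block and LocalValueNumbering
are stand-ins for illustration, not part of the Triton codebase:

    from typing import Dict, Optional, Tuple

    class Block:
        """Stand-in for _triton.ir.basic_block."""
        def __init__(self, label: str):
            self.label = label

    class LocalValueNumbering:
        def __init__(self):
            # maps (variable name, defining block) -> SSA value,
            # mirroring self.lvalues in CodeGenerator above
            self.lvalues: Dict[Tuple[str, Optional[Block]], object] = {}

        def set_value(self, name: str, bb: Optional[Block], value: object) -> None:
            self.lvalues[(name, bb)] = value

        def get_value(self, name: str, bb: Block) -> object:
            # 1) value defined in this very block
            if (name, bb) in self.lvalues:
                return self.lvalues[(name, bb)]
            # 2) block-independent fallback used for parameters
            if (name, None) in self.lvalues:
                return self.lvalues[(name, None)]
            raise KeyError(f'Cannot find {name} in {bb.label}')

    entry = Block('entry')
    lvn = LocalValueNumbering()
    lvn.set_value('x', None, 'arg0')    # parameter, block-independent
    lvn.set_value('y', entry, '%0')     # value defined in entry
    assert lvn.get_value('y', entry) == '%0'
    assert lvn.get_value('x', entry) == 'arg0'  # falls back to (name, None)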
@@ -217,10 +221,15 @@ class CodeGenerator(ast.NodeVisitor):
        self.lscope[kwarg_names] = self.kwargs
        # initialize function
        if inline:
            pass
            for arg_name, arg_value in zip(arg_names, arg_values):
                self.set_value(arg_name, arg_value)
            self.visit_compound_statement(node.body)
            return self.last_ret
        else:
            fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
            self.module.push_back(fn)
            entry = fn.add_entry_block()
            self._seal_block(entry)
            arg_values = []
            idx = 0
            for i, arg_name in enumerate(arg_names):
@@ -239,17 +248,11 @@ class CodeGenerator(ast.NodeVisitor):
                # attr = _triton.ir.attribute(attr, self.attributes[i])
                # fn.add_attr(idx + 1, attr)
                # fn.args[idx].name = arg_name
                # arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
                # idx += 1
                arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
                idx += 1

        for arg_name, arg_value in zip(arg_names, arg_values):
            self.set_value(arg_name, arg_value)
        if inline:
            self.visit_compound_statement(node.body)
            return self.last_ret
        else:
            entry = fn.add_entry_block()
            self._seal_block(entry)
            for arg_name, arg_value in zip(arg_names, arg_values):
                self.set_value(arg_name, arg_value)
            self.builder.set_insertion_point_to_start(entry)
            # visit function body
            self.visit_compound_statement(node.body)
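
Across this hunk and the previous one, the surviving code path creates the
function, seals its entry block, wraps the arguments, binds them into scope,
and only then visits the body. A toy mock of that ordering, runnable on its
own; every class here is an illustrative stand-in, not the real _triton.ir
bindings:

    class MockBlock:
        def __init__(self):
            self.sealed = False
            self.ops = []

    class MockFunction:
        def __init__(self, name, num_args):
            self.name = name
            self._args = [f'%arg{i}' for i in range(num_args)]

        # the new bindings expose arguments via a call, fn.args(i),
        # rather than subscripting fn.args[idx] -- mirrored here
        def args(self, i):
            return self._args[i]

        def add_entry_block(self):
            self.entry = MockBlock()
            return self.entry

    def emit_function(name, arg_names):
        fn = MockFunction(name, len(arg_names))
        entry = fn.add_entry_block()
        entry.sealed = True  # _seal_block: all predecessors are known
        scope = {}
        for i, arg_name in enumerate(arg_names):
            scope[arg_name] = fn.args(i)  # bind args before visiting the body
        entry.ops.append(f'body of {name} with {scope}')
        return fn

    fn = emit_function('vecadd', ['x_ptr', 'y_ptr', 'out_ptr', 'n'])
    assert fn.entry.sealed and fn.args(0) == '%arg0'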
@@ -821,50 +824,6 @@ class Kernel:
        return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
                                      self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)

    # Compile to ttir, for the purpose of testing MLIR rewriting
    def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
        # TODO: share code with _compile & __call__

        # preparing args
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        # attributes
        attributes = dict()
        for i, arg in enumerate(wargs):
            if i in self.fn.do_not_specialize:
                continue
            if isinstance(arg, int):
                attributes[i] = Kernel.pow2_divisor(arg)
            elif i in tensor_idxs:
                addr = arg.data_ptr()
                range_size = _triton.runtime.get_pointer_range_size(addr)
                attributes[i] = min(Kernel.pow2_divisor(addr),
                                    Kernel.pow2_divisor(range_size))
        # transforms ints whose value is one into constants for just-in-time compilation
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
        constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
        arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]

        # create IR module
        context = _triton.ir.context()
        context.load_triton()
        # get just-in-time prototype of kernel
        arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
        ret_type = triton.language.void
        prototype = triton.language.function_type([ret_type], arg_types)
        # generate Triton-IR
        # export symbols visible from self into code-generator object
        gscope = self.__globals__
        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
        try:
            generator.visit(self.parse())
        except Exception as e:
            node = generator.last_node
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.src, node) from e
        return generator.module


class Launcher:
    def __init__(self, kernel, grid):
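
For context on the specialization attributes computed in the method above:
Kernel.pow2_divisor returns the largest power of two (capped at 16) dividing
its argument, so pointer addresses and integer arguments can be specialized
on alignment at JIT time. A sketch consistent with the implementation in this
codebase at the time, written as a free function for illustration:

    def pow2_divisor(N: int) -> int:
        # largest power-of-two divisor of N, capped at 16
        for p in (16, 8, 4, 2):
            if N % p == 0:
                return p
        return 1

    assert pow2_divisor(4096) == 16  # e.g. a well-aligned pointer address
    assert pow2_divisor(6) == 2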
@@ -1209,6 +1168,53 @@ class JITFunction:
            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
        return Binary(backend, name, asm, shared_mem, num_warps)

    # Compile to ttir, for the purpose of testing MLIR rewriting
    def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
        # TODO: share code with _compile & __call__

        # preparing args
        tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
        # attributes
        attributes = dict()
        for i, arg in enumerate(wargs):
            if isinstance(arg, int):
                attributes[i] = Kernel.pow2_divisor(arg)
            elif i in tensor_idxs:
                addr = arg.data_ptr()
                range_size = _triton.runtime.get_pointer_range_size(addr)
                attributes[i] = min(Kernel.pow2_divisor(addr),
                                    Kernel.pow2_divisor(range_size))
        # transforms ints whose value is one into constants for just-in-time compilation
        constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
        constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
        arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]

        print(f'wargs: {wargs}')
        print(f'constants: {constants}')
        print(f'arg_types: {arg_types}')
        # create IR module
        context = _triton.ir.context()
        context.load_triton()
        # get just-in-time prototype of kernel
        arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
        ret_type = triton.language.void
        prototype = triton.language.function_type([ret_type], arg_types)
        # generate Triton-IR
        # export symbols visible from self into code-generator object
        gscope = self.__globals__
        generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
        try:
            generator.visit(self.parse())
        except Exception as e:
            node = generator.last_node
            if node is None or isinstance(e, (NotImplementedError, CompilationError)):
                raise e
            raise CompilationError(self.src, node) from e
        # FIXME: for now we also need to return the context, otherwise it will be deleted
        return generator.module, context


    def __getitem__(self, grid):
        return Launcher(self._init_kernel(), grid)
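
Assuming a standard vector-add kernel (which the commit title suggests was
the test case), the new JITFunction.compile_to_ttir can be exercised as
below. Kernel and tensor names are illustrative, not from the commit. Note
that constexpr arguments are picked up positionally from wargs, and that the
method returns the module together with the context that owns it (the FIXME
above):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    x = torch.rand(1024, device='cuda')
    y = torch.rand(1024, device='cuda')
    out = torch.empty_like(x)
    # returns (module, context); keeping ctx alive keeps the module valid
    mod, ctx = add_kernel.compile_to_ttir(
        x, y, out, x.numel(), tl.constexpr(1024), grid=(1,))
    print(mod)  # inspect the generated ttir (exact API depends on the bindings)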