Now vecadd works

This commit is contained in:
Yan Da
2022-03-30 20:21:47 +08:00
parent e381dc72c5
commit 2041b67fbf
5 changed files with 285 additions and 386 deletions

View File

@@ -67,9 +67,10 @@ class CodeGenerator(ast.NodeVisitor):
elif name in self.builtins:
ret = self.builtins[name]
else:
print(self.lscope)
raise ValueError(f'{name} is not defined')
if self.is_triton_tensor(ret):
return self._get_tensor(name)
return self._get_tensor(name, self.builder.get_insertion_block())
return ret
def set_value(self, name: str,
@@ -86,12 +87,15 @@ class CodeGenerator(ast.NodeVisitor):
#
# SSA-construction
#
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
if not bb:
bb = self.builder.get_insertion_block()
# local value numbering
if (name, bb) in self.lvalues:
return self.lvalues[(name, bb)]
# param. FIXME: should delete this
if (name, None) in self.lvalues:
return self.lvalues[(name, None)]
print(self.lvalues)
assert False, f'Cannot find {name} in {bb}'
# global value numbering
@@ -217,10 +221,15 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope[kwarg_names] = self.kwargs
# initialize function
if inline:
pass
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.visit_compound_statement(node.body)
return self.last_ret
else:
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
self.module.push_back(fn)
entry = fn.add_entry_block()
self._seal_block(entry)
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
@@ -239,17 +248,11 @@ class CodeGenerator(ast.NodeVisitor):
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
# arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
# idx += 1
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
self.visit_compound_statement(node.body)
return self.last_ret
else:
entry = fn.add_entry_block()
self._seal_block(entry)
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
self.visit_compound_statement(node.body)
@@ -821,50 +824,6 @@ class Kernel:
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
# Compile to ttir, for the propose of testing MLIR rewriting
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# TODO: share code with _compile & __call__
# preparing args
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
attributes = dict()
for i, arg in enumerate(wargs):
if i in self.fn.do_not_specialize:
continue
if isinstance(arg, int):
attributes[i] = Kernel.pow2_divisor(arg)
elif i in tensor_idxs:
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
attributes[i] = min(Kernel.pow2_divisor(addr),
Kernel.pow2_divisor(range_size))
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
# create IR module
context = _triton.ir.context()
context.load_triton()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
return generator.module
class Launcher:
def __init__(self, kernel, grid):
@@ -1209,6 +1168,53 @@ class JITFunction:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps)
# Compile to ttir, for the propose of testing MLIR rewriting
def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# TODO: share code with _compile & __call__
# preparing args
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
attributes = dict()
for i, arg in enumerate(wargs):
if isinstance(arg, int):
attributes[i] = Kernel.pow2_divisor(arg)
elif i in tensor_idxs:
addr = arg.data_ptr()
range_size = _triton.runtime.get_pointer_range_size(addr)
attributes[i] = min(Kernel.pow2_divisor(addr),
Kernel.pow2_divisor(range_size))
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants]
print(f'wargs: {wargs}')
print(f'constants: {constants}')
print(f'arg_types: {arg_types}')
# create IR module
context = _triton.ir.context()
context.load_triton()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
ret_type = triton.language.void
prototype = triton.language.function_type([ret_type], arg_types)
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
try:
generator.visit(self.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
raise e
raise CompilationError(self.src, node) from e
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
def __getitem__(self, grid):
return Launcher(self._init_kernel(), grid)