[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)

commit bb0f9235d1 (parent c4726333bf)
Author: Philippe Tillet (committed by GitHub)
Date:   2022-10-21 16:52:15 -07:00

26 changed files with 683 additions and 229 deletions

@@ -64,12 +64,12 @@ def mangle_ty(ty):
         return 'fp32'
     if ty.is_fp64():
         return 'fp64'
-    if ty.is_void():
-        return 'V'
     if ty.is_block():
         elt = mangle_ty(ty.scalar)
         shape = '_'.join(map(str, ty.shape))
         return f'{elt}S{shape}S'
+    if ty.is_void():
+        return 'V'
     assert False, "Unsupported type"
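
Note: this hunk only reorders mangle_ty so the block case is handled before the void fallback; no mangled names change, since a type is only ever one kind. For reference, a minimal sketch of the mangling scheme with plain strings standing in for Triton's type objects:

    # Toy stand-ins for Triton's type objects, just to show the scheme:
    # scalars mangle to their name; blocks wrap the element type between
    # 'S' markers, with the shape joined by underscores.
    def mangle(elt, shape=None):
        if shape is None:
            return elt
        return f"{elt}S{'_'.join(map(str, shape))}S"

    print(mangle('fp32'))            # fp32
    print(mangle('fp16', (64, 64)))  # fp16S64_64S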
@@ -212,7 +212,8 @@ class CodeGenerator(ast.NodeVisitor):
             init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
             self.visit(init_node)
         # initialize function
-        fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
+        visibility = "public" if self.is_kernel else "private"
+        fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
         self.module.push_back(fn)
         entry = fn.add_entry_block()
         arg_values = []
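
Note: kernel entry points now get public visibility while called device functions become private, so their symbols are eligible for removal once the inliner has run. A hypothetical pair of jitted functions to make the distinction concrete (names are illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def _scale(x):            # device function: emitted as a private symbol,
        return x * 2          # removable after inlining

    @triton.jit
    def scale_kernel(X, Y):   # kernel entry point: emitted as a public symbol
        x = tl.load(X)
        tl.store(Y, _scale(x))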
@@ -585,6 +586,12 @@ class CodeGenerator(ast.NodeVisitor):
         lb = self.builder.create_to_index(lb)
         ub = self.builder.create_to_index(ub)
         step = self.builder.create_to_index(step)
+        # Create placeholder for the loop induction variable
+        # We can use any value because the variable isn't a constexpr
+        # but use a distinctive value (of the right type) to ease debugging
+        st_target = ast.Name(id=node.target.id, ctx=ast.Store())
+        init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
+        self.visit(init_node)
         with enter_sub_region(self) as sr:
             liveins, insert_block = sr
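
Note: the loop variable is bound to a placeholder tensor before the body is generated; the next hunk rewires every use of it to the real induction variable. The bind-then-patch pattern in miniature (plain Python, not the compiler's real data structures):

    PLACEHOLDER = 0xBADF00D          # 195948557 decimal: easy to spot in IR dumps

    uses = []                        # stand-in for an SSA value's use list
    def emit_use(value):
        uses.append([value])         # one-element list ~ a mutable use

    emit_use(PLACEHOLDER)            # body code references the loop variable early
    real_iv = 7                      # later: the ForOp's actual induction variable
    for use in uses:                 # replace_all_uses_with, in miniature
        if use[0] == PLACEHOLDER:
            use[0] = real_iv
    print(uses)                      # [[7]]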
@@ -609,13 +616,22 @@ class CodeGenerator(ast.NodeVisitor):
                 names.append(name)
                 init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
                 yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
             # create ForOp
             self.builder.set_insertion_point_to_end(insert_block)
             for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
             block.merge_block_before(for_op.get_body(0))
+            # update induction variable with actual value, and replace all uses
+            self.builder.set_insertion_point_to_start(for_op.get_body(0))
+            iv = self.builder.create_index_to_si(for_op.get_induction_var())
+            self.lscope[node.target.id].handle.replace_all_uses_with(iv)
+            self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
             # create YieldOp
             self.builder.set_insertion_point_to_end(for_op.get_body(0))
-            self.builder.create_yield_op([y.handle for y in yields])
+            if len(yields) > 0:
+                self.builder.create_yield_op([y.handle for y in yields])
+            for_op_region = for_op.get_body(0).get_parent()
+            assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
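
Note: loop-carried Python variables become the ForOp's init args, flow through the yield, and are read back as op results (next hunk); the yield is now emitted only when something is actually carried. A plain-Python model of those scf.for semantics:

    # Loop-carried values enter as init args, flow through the yield, and
    # come back out as the op's results (`for_op.get_result(i)`).
    def scf_for(lb, ub, step, init_args, body):
        carried = list(init_args)
        for iv in range(lb, ub, step):
            carried = body(iv, *carried)   # body returns the yielded values
        return carried                     # the ForOp's results

    (acc,) = scf_for(0, 8, 1, [0], lambda iv, acc: [acc + iv])
    print(acc)  # 28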
# replace global uses with block arguments
@@ -625,8 +641,7 @@ class CodeGenerator(ast.NodeVisitor):
             # update lscope & local_defs (ForOp defines new values)
             for i, name in enumerate(names):
-                self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
-                self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
+                self.set_value(name, triton.language.core.tensor(for_op.get_result(i), yields[i].type))
         for stmt in node.orelse:
             assert False, "Don't know what to do with else after for"
@@ -672,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor):
             ret_type = triton.language.void
         prototype = triton.language.function_type([ret_type], arg_types)
         gscope = sys.modules[fn.fn.__module__].__dict__
-        generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
+        generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
         generator.visit(fn.parse())
         callee_ret_type = generator.last_ret_type
         self.function_ret_types[fn_name] = callee_ret_type
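
Note: passing function_name=fn_name lets the nested generator register the callee under its mangled symbol in the get_or_insert_function call from the earlier hunk. A hypothetical sketch of what such a symbol could look like, modeled on the mangle_ty scheme above (the real format is not shown in this diff):

    def mangle_fn(name, arg_tys):
        # e.g. a helper taking one 16x16 fp32 block argument
        return name + '__' + '_'.join(arg_tys)

    print(mangle_fn('_scale', ['fp32S16_16S']))  # _scale__fp32S16_16S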
@@ -839,18 +854,16 @@ def optimize_triton_ir(mod):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
     pm.add_inliner_pass()
     pm.add_triton_combine_pass()
     pm.add_canonicalizer_pass()
+    pm.add_cse_pass()
+    pm.add_licm_pass()
     pm.run(mod)
     return mod
 
 def make_tritongpu_ir(mod, num_warps):
     pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
-    pm.add_inliner_pass()
-    pm.add_triton_combine_pass()
-    pm.add_canonicalizer_pass()
-    pm.add_cse_pass()
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
     pm.run(mod)
     return mod
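
Note: the frontend cleanup passes now run once, in optimize_triton_ir, and make_tritongpu_ir shrinks to the dialect conversion alone. Paraphrasing the resulting division of labor (the wrapper name and signatures are assumptions based on the functions above and the next hunk's header):

    def compile_ttir_to_ttgir(mod, num_warps, num_stages):
        mod = optimize_triton_ir(mod)             # inline, combine, canonicalize, CSE, LICM
        mod = make_tritongpu_ir(mod, num_warps)   # now just the dialect conversion
        mod = optimize_tritongpu_ir(mod, num_stages)
        return mod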
@@ -864,6 +877,7 @@ def optimize_tritongpu_ir(mod, num_stages):
     pm.add_cse_pass()
     pm.add_coalesce_pass()
     pm.add_triton_gpu_combine_pass()
+    pm.add_licm_pass()
     pm.add_triton_gpu_swizzle_pass()
     pm.add_triton_gpu_combine_pass()
     pm.add_cse_pass()
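
Note: running LICM right after the combine pass can hoist loop-invariant computations, such as layout conversions, out of a fused attention kernel's inner loop, which is presumably what makes the layout simplification cheap here. What LICM does, in miniature:

    def convert_layout(w):              # stand-in for an expensive, pure op
        return [2 * v for v in w]

    def before(w, xs):
        out = []
        for x in xs:
            c = convert_layout(w)       # loop-invariant: recomputed every iteration
            out.append(c[0] * x)
        return out

    def after(w, xs):
        c = convert_layout(w)           # hoisted once; results are unchanged
        return [c[0] * x for x in xs]

    assert before([3], [1, 2]) == after([3], [1, 2])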