[OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790)
This commit is contained in:
@@ -64,12 +64,12 @@ def mangle_ty(ty):
|
||||
return 'fp32'
|
||||
if ty.is_fp64():
|
||||
return 'fp64'
|
||||
if ty.is_void():
|
||||
return 'V'
|
||||
if ty.is_block():
|
||||
elt = mangle_ty(ty.scalar)
|
||||
shape = '_'.join(map(str, ty.shape))
|
||||
return f'{elt}S{shape}S'
|
||||
if ty.is_void():
|
||||
return 'V'
|
||||
assert False, "Unsupported type"
|
||||
|
||||
|
||||
@@ -212,7 +212,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder))
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -585,6 +586,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
lb = self.builder.create_to_index(lb)
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
# Create placeholder for the loop induction variable
|
||||
# We can use any value because the variable isn't a constexpr
|
||||
# but use a distinctive value (of the right type) to ease debugging
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
init_node = ast.Assign(targets=[st_target], value=ast.Num(value=0xBADF00D))
|
||||
self.visit(init_node)
|
||||
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
@@ -609,13 +616,22 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
names.append(name)
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
|
||||
# create ForOp
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
|
||||
# update induction variable with actual value, and replace all uses
|
||||
self.builder.set_insertion_point_to_start(for_op.get_body(0))
|
||||
iv = self.builder.create_index_to_si(for_op.get_induction_var())
|
||||
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
|
||||
self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
|
||||
|
||||
# create YieldOp
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
if len(yields) > 0:
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||
# replace global uses with block arguments
|
||||
@@ -625,8 +641,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
self.set_value(name, triton.language.core.tensor(for_op.get_result(i), yields[i].type))
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False, "Don't know what to do with else after for"
|
||||
@@ -672,7 +687,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type([ret_type], arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types)
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -839,18 +854,16 @@ def optimize_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_licm_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.enable_debug()
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.run(mod)
|
||||
return mod
|
||||
@@ -864,6 +877,7 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm.add_cse_pass()
|
||||
pm.add_coalesce_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_licm_pass()
|
||||
pm.add_triton_gpu_swizzle_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_cse_pass()
|
||||
|
Reference in New Issue
Block a user