[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -749,8 +749,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
# visit kernel AST
|
||||
@@ -769,6 +770,14 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
return ret
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
@@ -785,7 +794,7 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
# pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_verifier_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
@@ -815,6 +824,7 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
# triton-ir
|
||||
module = make_triton_ir(fn, signature, constants, attributes)
|
||||
module = optimize_triton_ir(module)
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
# tritongpu-ir
|
||||
|
Reference in New Issue
Block a user