add cse pass to the pipeline & pass num-warps as an argument
This commit is contained in:
@@ -1316,9 +1316,11 @@ class JITFunction:
|
||||
pm.add_inliner_pass()
|
||||
pm.add_triton_combine_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_convert_triton_to_tritongpu_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_verifier_pass()
|
||||
return pm.run(mod)
|
||||
|
Reference in New Issue
Block a user