[CODEGEN] Add a pass to prefetch operands of dot if applicable. (#105)
* update membar pass when data is double buffered * Add instruction prefetch_s * prefetch tests pass (except the 1 warp case) * Fix the 1-warp bug * Add back prefetch files * Disable prefetch on a100 * Always add war barrier on sm>=80
This commit is contained in:
committed by
Philippe Tillet
parent
147675923e
commit
967e629c0c
@@ -12,6 +12,7 @@
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/driver/module.h"
|
||||
@@ -44,11 +45,12 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::swizzle swizzle(&layouts, target.get());
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get());
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole(target.get(), &layouts);
|
||||
// codegen::transform::reassociate reassociate;
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::prefetch prefetch_s(target.get());
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
||||
// run passes
|
||||
dce.run(ir);
|
||||
@@ -90,8 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
swizzle.run(ir);
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
barriers.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
prefetch_s.run(ir);
|
||||
barriers.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
isel.visit(ir, *llvm);
|
||||
mod = driver::module::create(dev, std::move(llvm));
|
||||
ker = driver::kernel::create(&*mod, name.c_str());
|
||||
|
Reference in New Issue
Block a user