[CODEGEN] Add a pass to prefetch operands of dot if applicable. (#105)

* update membar pass when data is double buffered

* Add instruction prefetch_s

* prefetch tests pass (except the 1 warp case)

* Fix the 1-warp bug

* Add back prefetch files

* Disable prefetch on a100

* Always add war barrier on sm>=80
This commit is contained in:
daadaada
2021-05-13 10:42:18 +08:00
committed by Philippe Tillet
parent 147675923e
commit 967e629c0c
14 changed files with 408 additions and 42 deletions

View File

@@ -11,6 +11,7 @@
namespace llvm{
class Type;
class Value;
class PHINode;
class BasicBlock;
class Attribute;
class Instruction;
@@ -169,6 +170,7 @@ public:
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
@@ -209,16 +211,19 @@ private:
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
/// layout -> base ptr
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
/// offset for double-buffered layout
std::map<analysis::data_layout*, Value*> shared_off_;
/// Base shmem pointer of ir value
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
@@ -227,6 +232,11 @@ private:
multiplier mul;
geper gep;
/// PHI nodes
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
/// Records prefetch instructions that need to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
};
}

View File

@@ -5,6 +5,7 @@
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"
namespace triton {
@@ -44,14 +45,16 @@ private:
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
target* tgt_;
};

View File

@@ -0,0 +1,22 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H

// Forward declarations: this header only uses pointers/references to these
// types, so their full definitions are not needed here.
namespace triton {

namespace ir {
class module;
} // namespace ir

namespace codegen {

class target;

namespace transform {

/// Prefetch transformation pass, configured with a codegen target.
/// `run` is only declared in this header; it is defined elsewhere.
class prefetch {
public:
  /// Stores `tgt` as-is; the pointer is neither checked nor copied here.
  prefetch(target *tgt) : tgt_(tgt) {}

  /// Applies the pass to `module`.
  void run(ir::module &module);

private:
  target* tgt_;
};

} // namespace transform
} // namespace codegen
} // namespace triton

#endif

View File

@@ -151,6 +151,7 @@ public:
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
value *create_prefetch_s(value *arg, int inc);
private:
context &ctx_;

View File

@@ -143,7 +143,8 @@ enum value_id_t: unsigned {
INST_ASYNC_WAIT,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE
INST_MAKE_RANGE,
INST_PREFETCH_S,
};

View File

@@ -666,6 +666,11 @@ private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
/// Prefetch marker for this dot instruction; starts out false.
/// NOTE(review): presumably set once the operands of this dot have been
/// handled by the prefetch pass — confirm against the pass implementation.
bool is_prefetched_ = false;
public:
/// Returns the current value of the prefetch marker.
bool is_prefetched() const { return is_prefetched_; }
/// Overwrites the prefetch marker with `is_prefetched`.
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
@@ -821,6 +826,23 @@ private:
int N_;
};
/// IR instruction printed as "prefetch_s". It has a void result type and a
/// single operand, plus an integer `inc` tag (see `inc_`).
/// NOTE(review): presumably requests a prefetch of a shared-memory operand of
/// a dot — confirm against the prefetch pass and the code generator.
class prefetch_s_inst : public instruction {
/// Textual name of this instruction.
std::string repr_impl() const { return "prefetch_s"; }
// Clone/visitor hooks generated by project macros (expansion not visible here).
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
/// Builds the instruction with a void type and the INST_PREFETCH_S value id,
/// records `inc`, and binds `arg` as operand 0.
/// NOTE(review): the literal `1` passed to the base constructor is presumably
/// the operand count — confirm against `instruction`'s constructor.
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}
/// Returns the `inc` tag given at construction.
int get_inc() const { return inc_; }
/// Factory declared here and defined elsewhere; `name` and `next` are optional.
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
instruction *next=nullptr);
};
//// On NVIDIA, implementation is such that
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
//// so as to enable re-association on nv_static_program_idx which is constant

View File

@@ -70,6 +70,7 @@ class barrier_inst;
class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class make_range_sta;
class undef_value;
@@ -146,6 +147,7 @@ public:
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;