[CODEGEN] Pipeline fixup (#336)
@@ -101,6 +101,15 @@ void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir:
   }
 }
 
+struct pipeline_info_t {
+  ir::load_inst* load;
+  ir::phi_node* ptr;
+  ir::dot_inst* dot;
+
+  pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
+    : load(load), ptr(ptr), dot(dot) {}
+};
+
 void pipeline::run(ir::module &mod) {
   // *Very* conservative heuristics for pre-fetching.
   // A load instruction can be pipelined if:
@@ -108,14 +117,15 @@ void pipeline::run(ir::module &mod) {
   //   in its basic block (i.e., pointer induction variable)
   // - the load has only a single use in a dot instruction
   // As more use cases become apparent, this pass will be improved
-  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
+  std::vector<pipeline_info_t> to_pipeline;
   ir::for_each_instruction(mod, [&](ir::instruction *i){
     if(auto* load = dynamic_cast<ir::load_inst*>(i)){
       ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
       auto users = load->get_users();
+      auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
       if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
-         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
-        to_pipeline.push_back({load, ptr});
+         && users.size() == 1 && dot)
+        to_pipeline.push_back({load, ptr, dot});
     }});
   // do the pipelining
   std::vector<ir::phi_node*> new_loads;
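For context, the loop shape these heuristics target looks roughly like the Triton kernel below. This is a minimal sketch, not code from this commit; the kernel name, tile sizes, and strides are illustrative, written against the same 2021-era tl.* API as the test further down.

import triton
import triton.language as tl

@triton.jit
def pipelinable(a_ptr, b_ptr, c_ptr, K, **meta):
    # The loop-carried pointers lower to phi nodes whose incoming value
    # from the loop body is their own increment: the "pointer induction
    # variable" pattern the pass checks for.
    rm = tl.arange(0, 32)[:, None]
    rn = tl.arange(0, 32)[None, :]
    pa = a_ptr + rm * K + rn
    pb = b_ptr + rm * 32 + rn
    acc = tl.zeros((32, 32), tl.float32)
    for k in range(0, K, 32):
        a = tl.load(pa)        # each load has exactly one user ...
        b = tl.load(pb)
        acc += tl.dot(a, b)    # ... the dot, so both loads qualify
        pa += 32               # advance the induction variables
        pb += 32 * 32
    tl.store(c_ptr + rm * 32 + rn, acc)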
@@ -123,8 +133,8 @@ void pipeline::run(ir::module &mod) {
   const int num_stages = num_stages_;
   std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
   for(auto info: to_pipeline){
-    ir::load_inst* load = info.first;
-    ir::phi_node* ptr = info.second;
+    ir::load_inst* load = info.load;
+    ir::phi_node* ptr = info.ptr;
     ir::basic_block* block = load->get_parent();
     ir::basic_block* header = block->get_predecessors()[0];
     auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
@@ -288,22 +298,23 @@ void pipeline::run(ir::module &mod) {
     std::vector<ir::instruction*> insts;
     ir::load_inst* dst;
   };
-  std::map<ir::basic_block*, move_config_t> to_move;
+  std::vector<move_config_t> to_move(to_pipeline.size());
 
   if(has_copy_async_){
-    for(ir::function* fn: mod.get_function_list())
-    for(ir::basic_block* bb: fn->blocks())
-    for(ir::instruction* inst: bb->get_inst_list()){
-      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
-        recursive_deps(i, bb, to_move[bb].insts);
-      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
-        to_move[bb].dst = i;
+    for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
+      auto info = to_pipeline[idx];
+      ir::load_inst* load = info.load;
+      ir::phi_node* ptr = info.ptr;
+      ir::dot_inst* dot = info.dot;
+      ir::basic_block* bb = dot->get_parent();
+      recursive_deps(dot, bb, to_move[idx].insts);
+      to_move[idx].dst = load;
     }
 
-  for(auto& x: to_move){
-    builder.set_insert_point_after(x.second.dst);
-    for(ir::instruction* i: x.second.insts){
-      x.first->erase(i);
+  for(auto& move_config: to_move){
+    builder.set_insert_point_after(move_config.dst);
+    for(ir::instruction* i: move_config.insts){
+      i->get_parent()->erase(i);
       builder.insert(i);
     }
   }
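The last hunk is the behavioral fix: instead of scanning every block and pairing whichever dot and load happen to share it, the pass now reorders per pipelined load, pulling the dot's transitive dependency chain to just after that load. A loose Python model of the reordering, under the assumption that recursive_deps collects a dot's transitive in-block dependencies (the list-based helper below is hypothetical, not Triton's IR):

def hoist_after(block, load, dot, operands):
    # Toy model: move dot's transitive dependencies that live in
    # `block` to just after `load`, preserving their original order.
    # `block` is a plain list of instructions and `operands(inst)`
    # returns an instruction's operands; both stand in for the IR.
    deps, stack = set(), [dot]
    while stack:
        inst = stack.pop()
        for op in operands(inst):
            if op in block and op not in deps:
                deps.add(op)
                stack.append(op)
    deps.discard(load)                      # the load anchors the move
    moved = [i for i in block if i in deps]
    rest = [i for i in block if i not in deps]
    at = rest.index(load) + 1
    return rest[:at] + moved + rest[at:]

The switch from x.first->erase(i) to i->get_parent()->erase(i) serves the same end: a dependency is no longer assumed to live in the block the old map was keyed on, so each instruction is erased from whichever block actually contains it.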
@@ -515,6 +515,20 @@ def test_dot(epilogue, device='cuda'):
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
 
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out, **meta):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32,32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
+
 # ---------------
 # test arange
 # ---------------
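Assuming the new test lands next to test_dot in Triton's Python unit-test suite (the file path is not shown in this diff), the regression can be exercised on its own with pytest's keyword filter, e.g. pytest -k test_dot_without_load.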