[CODEGEN] Pipeline fixup (#336)
This commit is contained in:
@@ -101,6 +101,15 @@ void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct pipeline_info_t {
|
||||||
|
ir::load_inst* load;
|
||||||
|
ir::phi_node* ptr;
|
||||||
|
ir::dot_inst* dot;
|
||||||
|
|
||||||
|
pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
|
||||||
|
: load(load), ptr(ptr), dot(dot) {}
|
||||||
|
};
|
||||||
|
|
||||||
void pipeline::run(ir::module &mod) {
|
void pipeline::run(ir::module &mod) {
|
||||||
// *Very* conservative heuristics for pre-fetching.
|
// *Very* conservative heuristics for pre-fetching.
|
||||||
// A load instruction can be pipelined if:
|
// A load instruction can be pipelined if:
|
||||||
@@ -108,14 +117,15 @@ void pipeline::run(ir::module &mod) {
|
|||||||
// in its basic block (i.e., pointer induction variable)
|
// in its basic block (i.e., pointer induction variable)
|
||||||
// - the load has only a single use in a dot instruction
|
// - the load has only a single use in a dot instruction
|
||||||
// As more use cases become apparent, this pass will be improved
|
// As more use cases become apparent, this pass will be improved
|
||||||
std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
|
std::vector<pipeline_info_t> to_pipeline;
|
||||||
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
ir::for_each_instruction(mod, [&](ir::instruction *i){
|
||||||
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
|
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
|
||||||
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
|
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
|
||||||
auto users = load->get_users();
|
auto users = load->get_users();
|
||||||
|
auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
|
||||||
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
|
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
|
||||||
&& users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
|
&& users.size() == 1 && dot)
|
||||||
to_pipeline.push_back({load, ptr});
|
to_pipeline.push_back({load, ptr, dot});
|
||||||
}});
|
}});
|
||||||
// do the pipelining
|
// do the pipelining
|
||||||
std::vector<ir::phi_node*> new_loads;
|
std::vector<ir::phi_node*> new_loads;
|
||||||
@@ -123,8 +133,8 @@ void pipeline::run(ir::module &mod) {
|
|||||||
const int num_stages = num_stages_;
|
const int num_stages = num_stages_;
|
||||||
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||||
for(auto info: to_pipeline){
|
for(auto info: to_pipeline){
|
||||||
ir::load_inst* load = info.first;
|
ir::load_inst* load = info.load;
|
||||||
ir::phi_node* ptr = info.second;
|
ir::phi_node* ptr = info.ptr;
|
||||||
ir::basic_block* block = load->get_parent();
|
ir::basic_block* block = load->get_parent();
|
||||||
ir::basic_block* header = block->get_predecessors()[0];
|
ir::basic_block* header = block->get_predecessors()[0];
|
||||||
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
|
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
|
||||||
@@ -288,22 +298,23 @@ void pipeline::run(ir::module &mod) {
|
|||||||
std::vector<ir::instruction*> insts;
|
std::vector<ir::instruction*> insts;
|
||||||
ir::load_inst* dst;
|
ir::load_inst* dst;
|
||||||
};
|
};
|
||||||
std::map<ir::basic_block*, move_config_t> to_move;
|
std::vector<move_config_t> to_move(to_pipeline.size());
|
||||||
|
|
||||||
if(has_copy_async_){
|
if(has_copy_async_){
|
||||||
for(ir::function* fn: mod.get_function_list())
|
for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
|
||||||
for(ir::basic_block* bb: fn->blocks())
|
auto info = to_pipeline[idx];
|
||||||
for(ir::instruction* inst: bb->get_inst_list()){
|
ir::load_inst* load = info.load;
|
||||||
if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
|
ir::phi_node* ptr = info.ptr;
|
||||||
recursive_deps(i, bb, to_move[bb].insts);
|
ir::dot_inst* dot = info.dot;
|
||||||
if(auto* i = dynamic_cast<ir::load_inst*>(inst))
|
ir::basic_block* bb = dot->get_parent();
|
||||||
to_move[bb].dst = i;
|
recursive_deps(dot, bb, to_move[idx].insts);
|
||||||
|
to_move[idx].dst = load;
|
||||||
}
|
}
|
||||||
|
|
||||||
for(auto& x: to_move){
|
for(auto& move_config: to_move){
|
||||||
builder.set_insert_point_after(x.second.dst);
|
builder.set_insert_point_after(move_config.dst);
|
||||||
for(ir::instruction* i: x.second.insts){
|
for(ir::instruction* i: move_config.insts){
|
||||||
x.first->erase(i);
|
i->get_parent()->erase(i);
|
||||||
builder.insert(i);
|
builder.insert(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -515,6 +515,20 @@ def test_dot(epilogue, device='cuda'):
|
|||||||
assert 'ld.global.v4' in ptx
|
assert 'ld.global.v4' in ptx
|
||||||
assert 'st.global.v4' in ptx
|
assert 'st.global.v4' in ptx
|
||||||
|
|
||||||
|
def test_dot_without_load():
|
||||||
|
@triton.jit
|
||||||
|
def kernel(out, **meta):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
a = tl.zeros((32, 32), tl.float32)
|
||||||
|
b = tl.zeros((32, 32), tl.float32)
|
||||||
|
c = tl.zeros((32, 32), tl.float32)
|
||||||
|
c = tl.dot(a, b)
|
||||||
|
pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
|
||||||
|
tl.store(pout, c)
|
||||||
|
|
||||||
|
out = torch.ones((32,32), dtype=torch.float32, device="cuda")
|
||||||
|
kernel[(1,)](out)
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# test arange
|
# test arange
|
||||||
# ---------------
|
# ---------------
|
||||||
|
Reference in New Issue
Block a user