diff --git a/include/triton/codegen/analysis/grid.h b/include/triton/codegen/analysis/grid.h index 1eb352b00..25c5d24a4 100644 --- a/include/triton/codegen/analysis/grid.h +++ b/include/triton/codegen/analysis/grid.h @@ -49,7 +49,7 @@ public: grids(size_t num_warps, transform::coalesce* reorder); ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; } unsigned get_param_group(ir::value *value, unsigned ax); - fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); } + fragment_t get_fragment(ir::value *value, unsigned ax); void copy(ir::value *dst, ir::value *src); void run(ir::module &mod); unsigned get_num_threads(); diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc index 6a5169d13..aa87c9480 100644 --- a/lib/codegen/analysis/grid.cc +++ b/lib/codegen/analysis/grid.cc @@ -182,6 +182,11 @@ unsigned grids::get_param_group(ir::value *value, unsigned ax) { return result; } +grids::fragment_t grids::get_fragment(ir::value *value, unsigned ax) { + return fragments_.at({value, ax}); +} + + //TODO: This shouldn't exist! void grids::copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 87a2adf58..5b137c148 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -723,8 +723,9 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, if(!v->get_type()->is_tile_ty() || !seen.insert(v).second) return; if(auto *user = dynamic_cast(v)) - for(ir::value *op: user->ops()) + for(ir::value *op: user->ops()){ create_tile(op, builder, seen, sh_mem_ptr); + } LLVMContext &ctx = builder.getContext(); auto shapes = v->get_type()->get_tile_shapes(); unsigned pad = alloc_->is_ld_padded(v); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index b9fbbb534..0ba534531 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -61,18 +61,20 @@ void coalesce::run(ir::module &mod) { std::vector order(max_contiguous.size()); std::iota(order.begin(), order.end(), 0); std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); - std::list work_list; + std::list> work_list; if(order != order_[i]) - work_list.push_back(i); + work_list.push_back({i, nullptr}); // rematerialize recursively while(!work_list.empty()) { - ir::instruction* current = work_list.back(); - order_[current] = order; + auto pair = work_list.back(); + ir::instruction* cloned = pair.first; + ir::instruction* original = pair.second; + order_[cloned] = order; work_list.pop_back(); - for(ir::value *op: current->ops()) { + for(ir::value *op: cloned->ops()) { ir::instruction* i_op = dynamic_cast(op); if(replaced.find(i_op) != replaced.end()){ - current->replace_uses_of_with(i_op, replaced.at(i_op)); + cloned->replace_uses_of_with(i_op, replaced.at(i_op)); continue; } if(!i_op) @@ -90,17 +92,19 @@ void coalesce::run(ir::module &mod) { continue; if(auto* ld = dynamic_cast(i_op)) n_op = ir::copy_to_shared_inst::create(ld); - // not a load; rematerialize and recurse + // not a load; rematerialize and add to worklist else { n_op = i_op->clone(); - work_list.push_back(n_op); + work_list.push_back({n_op, i_op}); } n_op = builder.insert(n_op); replaced.insert({i_op, n_op}); order_[n_op] = order; align_->copy(n_op, i_op); -// mem_->copy(n_op, i_op); - current->replace_uses_of_with(i_op, n_op); + mem_->copy(n_op, i_op); + if(original) + n_op->erase_use(original); + cloned->replace_uses_of_with(i_op, n_op); } } diff --git a/lib/ir/value.cc b/lib/ir/value.cc index 3ab64b97a..5dfb0460c 100644 --- a/lib/ir/value.cc +++ b/lib/ir/value.cc @@ -66,7 +66,7 @@ void user::replace_uses_of_with(value *before, value *after) { if(ops_[i] == before) ops_[i] = after; after->add_use(this); - erase_use(this); + before->erase_use(this); } }