[codegen][coalesce] some bugfix for phi-nodes
This commit is contained in:
@@ -19,7 +19,7 @@ namespace ir{
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
class reorder;
|
||||
class coalesce;
|
||||
}
|
||||
|
||||
namespace analysis{
|
||||
@@ -46,7 +46,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
grids(size_t num_warps, transform::reorder* reorder);
|
||||
grids(size_t num_warps, transform::coalesce* reorder);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||
@@ -66,7 +66,7 @@ private:
|
||||
std::vector<ir::value*> grids_;
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
size_t num_warps_;
|
||||
transform::reorder* reorder_;
|
||||
transform::coalesce* reorder_;
|
||||
|
||||
};
|
||||
|
||||
|
@@ -50,7 +50,7 @@ class meminfo;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
class reorder;
|
||||
class coalesce;
|
||||
}
|
||||
|
||||
class target;
|
||||
@@ -195,7 +195,7 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::reorder* reorder, target *tgt)
|
||||
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
@@ -207,7 +207,7 @@ private:
|
||||
analysis::grids *params_;
|
||||
analysis::meminfo *buffer_info_;
|
||||
analysis::align *alignment_;
|
||||
transform::reorder *reorder_;
|
||||
transform::coalesce *reorder_;
|
||||
target *tgt_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
Value *sh_mem_ptr_;
|
||||
|
@@ -20,9 +20,9 @@ namespace analysis{
|
||||
|
||||
namespace transform{
|
||||
|
||||
class reorder {
|
||||
class coalesce {
|
||||
public:
|
||||
reorder(analysis::align* algin, analysis::meminfo* mem);
|
||||
coalesce(analysis::align* algin, analysis::meminfo* mem);
|
||||
std::vector<unsigned> get_order(ir::value* v);
|
||||
void run(ir::module &mod);
|
||||
|
@@ -1,6 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
@@ -16,7 +16,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
|
||||
grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), reorder_(reorder)
|
||||
{ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
@@ -298,7 +298,7 @@ void grids::run(ir::module &mod) {
|
||||
unsigned current = num_threads;
|
||||
std::string nts = "nts.d" + s_ld;
|
||||
std::string mts = "mts.d" + s_ld;
|
||||
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 1));
|
||||
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 8));
|
||||
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
|
||||
current = current / params_.at(i).at(mts)->get_value();
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
|
@@ -3,7 +3,7 @@
|
||||
#include "triton/codegen/analysis/grid.h"
|
||||
#include "triton/codegen/analysis/memalloc.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
110
lib/codegen/transform/coalesce.cc
Normal file
110
lib/codegen/transform/coalesce.cc
Normal file
@@ -0,0 +1,110 @@
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
coalesce::coalesce(analysis::align* align, analysis::meminfo *mem)
|
||||
: align_(align), mem_(mem) { }
|
||||
|
||||
// Returns a copy of the order recorded for `v`.
// The lookup uses at(), so no entry is created for an unknown value; the
// container reports the missing key instead (std::out_of_range for the
// standard map types -- the declaration of order_ is not visible here).
std::vector<unsigned> coalesce::get_order(ir::value* v) {
  const auto &recorded = order_.at(v);
  return recorded;
}
|
||||
|
||||
// Records an axis order for every instruction of the module and, recursively,
// for every operand reachable from it.
//   - values of tile type get the identity order {0, 1, ..., rank-1}
//   - all other visited values keep an empty order
// The contiguity-driven re-ordering that used to follow is currently disabled
// and is preserved as a comment at the end of the function.
void coalesce::run(ir::module &mod) {
  // I/O instructions whose pointer operand is a tile; only the disabled code
  // at the end of this function consumes this set
  std::set<ir::io_inst*> tile_io;

  // depth-first visitor over a value and its operands
  std::function<void(ir::value*)> assign_order = [&](ir::value *val) -> void {
    if(order_.count(val) > 0)
      return;
    // register the value before descending, so that operand chains leading
    // back to it (e.g. through phi-nodes) terminate on the check above
    order_[val] = {};
    ir::user* as_user = dynamic_cast<ir::user*>(val);
    if(as_user != nullptr)
      for(ir::value* operand: as_user->ops())
        assign_order(operand);
    ir::type* val_ty = val->get_type();
    if(val_ty->is_tile_ty()) {
      // identity permutation over the tile's dimensions
      std::vector<unsigned> identity(val_ty->get_tile_shapes().size());
      std::iota(identity.begin(), identity.end(), 0);
      order_[val] = identity;
    }
  };

  // walk every instruction, blocks taken in reverse post-order
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: ir::cfg::reverse_post_order(fn))
  for(ir::instruction *inst: block->get_inst_list()){
    ir::io_inst *mem_op = dynamic_cast<ir::io_inst*>(inst);
    if(mem_op != nullptr && mem_op->get_pointer_operand()->get_type()->is_tile_ty())
      tile_io.insert(mem_op);
    assign_order(inst);
  }

  // Disabled: re-order tile I/O by decreasing contiguity and rematerialize
  // the operands that feed it. Kept for reference.
  //
  // ir::builder &builder = mod.get_builder();
  // std::set<ir::value*> seen;
  // for(ir::io_inst *i: tile_io) {
  //   ir::value *ptr = i->get_pointer_operand();
  //   auto max_contiguous = align_->get_max_contiguous_vec(ptr);
  //   std::vector<unsigned> order(max_contiguous.size());
  //   std::iota(order.begin(), order.end(), 0);
  //   std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
  //   std::list<ir::instruction*> work_list;
  //   if(order != order_[i])
  //     work_list.push_back(i);
  //   // rematerialize recursively
  //   while(!work_list.empty()) {
  //     ir::instruction* current = work_list.back();
  //     order_[current] = order;
  //     work_list.pop_back();
  //     for(ir::value *op: current->ops()) {
  //       ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
  //       if(!seen.insert(op).second)
  //         continue;
  //       if(!i_op)
  //         continue;
  //       ir::type *ty = i_op->get_type();
  //       if(!ty->is_tile_ty())
  //         continue;
  //       auto& inst_list = i_op->get_parent()->get_inst_list();
  //       auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
  //       it++;
  //       builder.set_insert_point(it);
  //       // found a load; write to shared memory and stop recursion
  //       ir::instruction *n_op = nullptr;
  //       if(mem_->is_shared(i_op)){
  //         continue;
  //       }
  //       if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
  //         n_op = ir::copy_to_shared_inst::create(ld);
  //       }
  //       // not a load; rematerialize and recurse
  //       else {
  //         n_op = i_op->clone();
  //         work_list.push_back(n_op);
  //       }
  //       n_op = builder.insert(n_op);
  //       order_[n_op] = order;
  //       align_->copy(n_op, i_op);
  //       current->replace_uses_of_with(i_op, n_op);
  //     }
  //   }
  // }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,106 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/analysis/meminfo.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
reorder::reorder(analysis::align* align, analysis::meminfo *mem)
|
||||
: align_(align), mem_(mem) { }
|
||||
|
||||
// Returns a copy of the order recorded for `v`.
// The lookup uses at(), so no entry is created for an unknown value; the
// container reports the missing key instead (std::out_of_range for the
// standard map types -- the declaration of order_ is not visible here).
std::vector<unsigned> reorder::get_order(ir::value* v) {
  const auto &recorded = order_.at(v);
  return recorded;
}
|
||||
|
||||
// Computes an axis order for the values of `mod`, then re-orders tile I/O.
//
// Phase 1: every instruction and, recursively, every operand reachable from
//          it is visited. Tile-typed values receive the identity order
//          {0, 1, ..., rank-1}; other visited values keep an empty order.
// Phase 2: for each I/O instruction with a tile pointer operand, the axes are
//          sorted by decreasing max-contiguity of the pointer. When this
//          differs from the recorded order, the instruction takes the new
//          order and its tile-typed instruction operands are replaced:
//          loads by a copy-to-shared of the load, anything else by a clone
//          that is processed in turn.
//
// Note: a value is entered in order_ *before* its operands are visited. If
// it were only recorded afterwards, a use-def cycle (a phi-node in a loop
// whose incoming value depends on the phi itself) would recurse without end.
// As a consequence, non-tile values that were visited have an (empty) entry
// in order_.
void reorder::run(ir::module &mod) {
  // I/O instructions whose pointer operand is a tile
  std::set<ir::io_inst*> io;

  // depth-first visitor over a value and its operands
  std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
    if(order_.find(v) != order_.end())
      return;
    // placeholder entry: marks v as visited so cyclic operand chains stop
    order_[v] = {};
    if(ir::user* u = dynamic_cast<ir::user*>(v))
      for(ir::value* op: u->ops())
        set_order(op);
    ir::type* ty = v->get_type();
    if(!ty->is_tile_ty())
      return;
    // identity permutation over the tile's dimensions
    std::vector<unsigned> order(ty->get_tile_shapes().size());
    std::iota(order.begin(), order.end(), 0);
    order_[v] = order;
  };

  // initialize work-list
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: ir::cfg::reverse_post_order(fn))
  for(ir::instruction *i: block->get_inst_list()){
    if(auto *x = dynamic_cast<ir::io_inst*>(i)) {
      ir::type* ptr_ty = x->get_pointer_operand()->get_type();
      if(ptr_ty->is_tile_ty())
        io.insert(x);
    }
    set_order(i);
  }

  ir::builder &builder = mod.get_builder();
  for(ir::io_inst *i: io) {
    ir::value *ptr = i->get_pointer_operand();
    auto max_contiguous = align_->get_max_contiguous_vec(ptr);
    // axes sorted by decreasing contiguity of the pointer operand
    std::vector<unsigned> order(max_contiguous.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
    std::list<ir::instruction*> work_list;
    if(order != order_[i])
      work_list.push_back(i);
    // rematerialize recursively
    while(!work_list.empty()) {
      ir::instruction* current = work_list.back();
      order_[current] = order;
      work_list.pop_back();
      for(ir::value *op: current->ops()) {
        ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
        if(!i_op)
          continue;
        ir::type *ty = i_op->get_type();
        if(!ty->is_tile_ty())
          continue;
        // new instructions are inserted right after the operand they replace
        auto& inst_list = i_op->get_parent()->get_inst_list();
        auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
        ++it;
        builder.set_insert_point(it);
        ir::instruction *n_op = nullptr;
        // operands flagged as shared by the memory-info analysis are left alone
        if(mem_->is_shared(i_op)){
          continue;
        }
        // found a load; write to shared memory and stop recursion
        if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
          n_op = ir::copy_to_shared_inst::create(ld);
        }
        // not a load; rematerialize and recurse
        else {
          n_op = i_op->clone();
          work_list.push_back(n_op);
        }
        n_op = builder.insert(n_op);
        order_[n_op] = order;
        align_->copy(n_op, i_op);
        current->replace_uses_of_with(i_op, n_op);
      }
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -5,7 +5,7 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
@@ -197,7 +197,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::analysis::meminfo shmem_info;
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::align alignment_info;
|
||||
codegen::transform::reorder reorder(&alignment_info, &shmem_info);
|
||||
codegen::transform::coalesce reorder(&alignment_info, &shmem_info);
|
||||
codegen::analysis::grids grids(opt.num_warps, &reorder);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
@@ -215,7 +215,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
// ir::print(module, std::cout);
|
||||
reorder.run(module);
|
||||
dce.run(module);
|
||||
ir::print(module, std::cout);
|
||||
// ir::print(module, std::cout);
|
||||
grids.run(module);
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
@@ -231,7 +231,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
dce.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
ir::print(module, std::cout);
|
||||
// ir::print(module, std::cout);
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
|
Reference in New Issue
Block a user