[codegen][auto-coalesce] more debugging
This commit is contained in:
@@ -119,7 +119,7 @@ private:
|
|||||||
Type *make_vector_ty(Type *ty, size_t vector_size);
|
Type *make_vector_ty(Type *ty, size_t vector_size);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
distributed_tile(Type *ty, const shapes_t& shapes, const axes_t &axes, Builder &builder, bool vectorize);
|
distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder, bool vectorize);
|
||||||
void set_value(indices_t idx, Value *v);
|
void set_value(indices_t idx, Value *v);
|
||||||
Value* get_value(indices_t idx);
|
Value* get_value(indices_t idx);
|
||||||
unsigned get_linear_index(indices_t idx);
|
unsigned get_linear_index(indices_t idx);
|
||||||
@@ -129,6 +129,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
axes_t axes_;
|
axes_t axes_;
|
||||||
|
std::vector<int> order_;
|
||||||
indices_map_t indices_;
|
indices_map_t indices_;
|
||||||
values_map_t values_;
|
values_map_t values_;
|
||||||
ordered_indices_vec_t ordered_indices_;
|
ordered_indices_vec_t ordered_indices_;
|
||||||
|
@@ -11,6 +11,8 @@ namespace ir {
|
|||||||
class module;
|
class module;
|
||||||
class value;
|
class value;
|
||||||
class io_inst;
|
class io_inst;
|
||||||
|
class instruction;
|
||||||
|
class builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace codegen{
|
namespace codegen{
|
||||||
@@ -27,6 +29,7 @@ class coalesce {
|
|||||||
private:
|
private:
|
||||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
||||||
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
||||||
|
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem);
|
coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem);
|
||||||
|
@@ -158,6 +158,7 @@ void axes::run(ir::module &mod) {
|
|||||||
unsigned group_id = 0;
|
unsigned group_id = 0;
|
||||||
while(!nodes_.empty())
|
while(!nodes_.empty())
|
||||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||||
|
std::cout << "Number of axes: " << group_id << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -190,6 +190,8 @@ void tiles::run(ir::module &) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
order_[i] = order;
|
order_[i] = order;
|
||||||
|
std::cout << "order: " << order[0] << " " << order[1] << std::endl;
|
||||||
|
|
||||||
}
|
}
|
||||||
// tiling parameters
|
// tiling parameters
|
||||||
for(auto x: largest_){
|
for(auto x: largest_){
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#include "triton/codegen/selection.h"
|
#include <numeric>
|
||||||
|
#include "triton/codegen/selection.h"
|
||||||
#include "triton/codegen/target.h"
|
#include "triton/codegen/target.h"
|
||||||
#include "triton/codegen/analysis/layout.h"
|
#include "triton/codegen/analysis/layout.h"
|
||||||
#include "triton/codegen/analysis/axes.h"
|
#include "triton/codegen/analysis/axes.h"
|
||||||
@@ -28,6 +29,14 @@ using namespace llvm;
|
|||||||
/* Distributed Tile */
|
/* Distributed Tile */
|
||||||
void distributed_tile::init_indices() {
|
void distributed_tile::init_indices() {
|
||||||
std::vector<size_t> id(axes_.size(), 0);
|
std::vector<size_t> id(axes_.size(), 0);
|
||||||
|
// create iteration order
|
||||||
|
std::vector<size_t> order(id.size());
|
||||||
|
std::iota(order.begin(), order.end(), 0);
|
||||||
|
auto cmp = [&](int x, int y) {
|
||||||
|
return axes_[x].contiguous > axes_[y].contiguous;
|
||||||
|
};
|
||||||
|
std::sort(order.begin(), order.end(), cmp);
|
||||||
|
// build
|
||||||
size_t k = 0;
|
size_t k = 0;
|
||||||
while(true) {
|
while(true) {
|
||||||
indices_t current;
|
indices_t current;
|
||||||
@@ -37,12 +46,12 @@ void distributed_tile::init_indices() {
|
|||||||
indices_[current] = sz;
|
indices_[current] = sz;
|
||||||
values_[current] = nullptr;
|
values_[current] = nullptr;
|
||||||
ordered_indices_.push_back(current);
|
ordered_indices_.push_back(current);
|
||||||
id[0]++;
|
id[order[0]]++;
|
||||||
while(id[k] == axes_[k].values.size()){
|
while(id[order[k]] == axes_[order[k]].values.size()){
|
||||||
if(k == id.size() - 1)
|
if(k == id.size() - 1)
|
||||||
return;
|
return;
|
||||||
id[k++] = 0;
|
id[order[k++]] = 0;
|
||||||
id[k]++;
|
id[order[k]]++;
|
||||||
}
|
}
|
||||||
k = 0;
|
k = 0;
|
||||||
}
|
}
|
||||||
@@ -54,8 +63,8 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size)
|
|||||||
return VectorType::get(ty, vector_size);
|
return VectorType::get(ty, vector_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize)
|
||||||
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) {
|
: tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) {
|
||||||
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
vector_size_ = vectorize?ty_->getVectorNumElements():1;
|
||||||
init_indices();
|
init_indices();
|
||||||
}
|
}
|
||||||
@@ -767,7 +776,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh
|
|||||||
for(ir::user *usr: v->get_users())
|
for(ir::user *usr: v->get_users())
|
||||||
if(dynamic_cast<ir::phi_node*>(usr))
|
if(dynamic_cast<ir::phi_node*>(usr))
|
||||||
has_phi_user = true;
|
has_phi_user = true;
|
||||||
if(has_phi_user){
|
if(!has_phi_user){
|
||||||
size_t offset = alloc_->offset(v);
|
size_t offset = alloc_->offset(v);
|
||||||
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset));
|
||||||
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
ptr = builder.CreateBitCast(ptr, ptr_ty);
|
||||||
@@ -791,7 +800,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
|
bool vectorize = dynamic_cast<ir::vectorize_inst*>(v);
|
||||||
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
|
distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, vectorize);
|
||||||
bool is_inserted = tmap_.insert({v, T}).second;
|
bool is_inserted = tmap_.insert({v, T}).second;
|
||||||
// constant range
|
// constant range
|
||||||
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
|
if(is_inserted && dynamic_cast<ir::make_range*>(v)){
|
||||||
@@ -1260,8 +1269,9 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
|
|||||||
// find vector size
|
// find vector size
|
||||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
unsigned alignment = alignment_->get(ptr, 0);
|
size_t ld = tiles_->order(ptr)[0];
|
||||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
unsigned alignment = alignment_->get(ptr, ld);
|
||||||
|
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||||
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand());
|
||||||
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
|
distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand());
|
||||||
@@ -1331,8 +1341,9 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
|
|||||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||||
// find vector size
|
// find vector size
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
unsigned alignment = alignment_->get(ptr, 0);
|
size_t ld = tiles_->order(ptr)[0];
|
||||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
unsigned alignment = alignment_->get(ptr, ld);
|
||||||
|
unsigned vector_size = std::min<unsigned>(result->axis(ld).contiguous, alignment);
|
||||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||||
// vector loads
|
// vector loads
|
||||||
std::map<unsigned, Value*> packets;
|
std::map<unsigned, Value*> packets;
|
||||||
|
@@ -35,6 +35,31 @@ void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*
|
|||||||
result[axis].push_back(i);
|
result[axis].push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
|
||||||
|
std::map<ir::value*, ir::value*>& seen) {
|
||||||
|
if(seen.find(x) != seen.end())
|
||||||
|
return seen.at(x);
|
||||||
|
auto i = dynamic_cast<ir::instruction*>(x);
|
||||||
|
// not an instruction -- forward value
|
||||||
|
if(!i)
|
||||||
|
return x;
|
||||||
|
// already in shared memory -- forward value
|
||||||
|
if(dynamic_cast<ir::copy_to_shared_inst*>(x)){
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
// set insert point
|
||||||
|
auto& inst_list = i->get_parent()->get_inst_list();
|
||||||
|
auto pos = ++std::find(inst_list.begin(), inst_list.end(), i);
|
||||||
|
builder.set_insert_point(pos);
|
||||||
|
// default -- recursive clone
|
||||||
|
ir::instruction *cloned = builder.insert(i->clone());
|
||||||
|
seen[i] = cloned;
|
||||||
|
// rematerialize operands
|
||||||
|
for(ir::value *op: cloned->ops())
|
||||||
|
cloned->replace_uses_of_with(op, rematerialize(op, builder, seen));
|
||||||
|
return cloned;
|
||||||
|
}
|
||||||
|
|
||||||
void coalesce::run(ir::module &mod) {
|
void coalesce::run(ir::module &mod) {
|
||||||
// find values to rematerialize
|
// find values to rematerialize
|
||||||
size_t num_groups = layout_->get_num_groups();
|
size_t num_groups = layout_->get_num_groups();
|
||||||
@@ -56,54 +81,21 @@ void coalesce::run(ir::module &mod) {
|
|||||||
remat.insert(remat.begin(),
|
remat.insert(remat.begin(),
|
||||||
it->second.begin(), it->second.end());
|
it->second.begin(), it->second.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// rematerialize values
|
// rematerialize values
|
||||||
ir::builder &builder = mod.get_builder();
|
|
||||||
for(ir::io_inst *r: remat) {
|
for(ir::io_inst *r: remat) {
|
||||||
std::list<std::pair<ir::instruction*, ir::instruction*>> work_list;
|
ir::builder& builder = mod.get_builder();
|
||||||
std::map<ir::value*, ir::value*> replaced;
|
// rematerialize operands
|
||||||
work_list.push_back({r, nullptr});
|
std::map<ir::value*, ir::value*> seen;
|
||||||
// rematerialize recursively
|
for(ir::value *op: r->ops())
|
||||||
while(!work_list.empty()) {
|
rematerialize(op, mod.get_builder(), seen);
|
||||||
auto pair = work_list.back();
|
// copy to shared if load
|
||||||
ir::instruction* cloned = pair.first;
|
auto& inst_list = r->get_parent()->get_inst_list();
|
||||||
ir::instruction* original = pair.second;
|
auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
|
||||||
work_list.pop_back();
|
builder.set_insert_point(pos);
|
||||||
for(ir::value *op: cloned->ops()) {
|
if(dynamic_cast<ir::load_inst*>(r)){
|
||||||
ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
|
ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
|
||||||
if(replaced.find(i_op) != replaced.end()){
|
r->replace_all_uses_with(cts);
|
||||||
cloned->replace_uses_of_with(i_op, replaced.at(i_op));
|
cts->replace_uses_of_with(cts, r);
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if(!i_op)
|
|
||||||
continue;
|
|
||||||
ir::type *ty = i_op->get_type();
|
|
||||||
if(!ty->is_tile_ty())
|
|
||||||
continue;
|
|
||||||
auto& inst_list = i_op->get_parent()->get_inst_list();
|
|
||||||
auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
|
|
||||||
it++;
|
|
||||||
builder.set_insert_point(it);
|
|
||||||
// found a load; write to shared memory and stop recursion
|
|
||||||
ir::instruction *n_op = nullptr;
|
|
||||||
if(mem_->is_shared(i_op)){
|
|
||||||
i_op->add_use(cloned);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if(auto* ld = dynamic_cast<ir::load_inst*>(i_op))
|
|
||||||
n_op = ir::copy_to_shared_inst::create(ld);
|
|
||||||
// not a load; rematerialize and add to worklist
|
|
||||||
else {
|
|
||||||
n_op = i_op->clone();
|
|
||||||
work_list.push_back({n_op, i_op});
|
|
||||||
}
|
|
||||||
n_op = builder.insert(n_op);
|
|
||||||
replaced.insert({i_op, n_op});
|
|
||||||
mem_->copy(n_op, i_op);
|
|
||||||
if(original)
|
|
||||||
n_op->erase_use(original);
|
|
||||||
cloned->replace_uses_of_with(i_op, n_op);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -92,10 +92,10 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
|
|||||||
file_type_t ft) {
|
file_type_t ft) {
|
||||||
init_llvm();
|
init_llvm();
|
||||||
// debug
|
// debug
|
||||||
// llvm::legacy::PassManager pm;
|
llvm::legacy::PassManager pm;
|
||||||
// pm.add(llvm::createPrintModulePass(llvm::outs()));
|
pm.add(llvm::createPrintModulePass(llvm::outs()));
|
||||||
// pm.add(llvm::createVerifierPass());
|
// pm.add(llvm::createVerifierPass());
|
||||||
// pm.run(*module);
|
pm.run(*module);
|
||||||
// create machine
|
// create machine
|
||||||
module->setTargetTriple(triple);
|
module->setTargetTriple(triple);
|
||||||
std::string error;
|
std::string error;
|
||||||
@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
// std::cout << source_ << std::endl;
|
std::cout << source_ << std::endl;
|
||||||
cu_context::context_switcher ctx_switch(*context);
|
cu_context::context_switcher ctx_switch(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -48,8 +48,10 @@ void print(module &mod, std::ostream& os) {
|
|||||||
os << std::endl;
|
os << std::endl;
|
||||||
for(ir::instruction *inst: block->get_inst_list()){
|
for(ir::instruction *inst: block->get_inst_list()){
|
||||||
os << " ";
|
os << " ";
|
||||||
|
if(!inst->get_type()->is_void_ty()){
|
||||||
os << get_name(inst, cnt++);
|
os << get_name(inst, cnt++);
|
||||||
os << " = ";
|
os << " = ";
|
||||||
|
}
|
||||||
ir::type* type = inst->get_type();
|
ir::type* type = inst->get_type();
|
||||||
os << inst->repr() << " " << type->repr();
|
os << inst->repr() << " " << type->repr();
|
||||||
ir::instruction::ops_t ops = inst->ops();
|
ir::instruction::ops_t ops = inst->ops();
|
||||||
@@ -65,7 +67,6 @@ void print(module &mod, std::ostream& os) {
|
|||||||
}
|
}
|
||||||
os << ";" << std::endl;
|
os << ";" << std::endl;
|
||||||
}
|
}
|
||||||
os << std::endl;
|
|
||||||
}
|
}
|
||||||
os << "}" << std::endl;
|
os << "}" << std::endl;
|
||||||
}
|
}
|
||||||
|
@@ -221,6 +221,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
axes.run(module);
|
axes.run(module);
|
||||||
layouts.run(module);
|
layouts.run(module);
|
||||||
coalesce.run(module);
|
coalesce.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
align.run(module);
|
align.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
tiles.run(module);
|
tiles.run(module);
|
||||||
|
@@ -48,7 +48,7 @@ int main() {
|
|||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(auto x: std::vector<order_t>{COLMAJOR}){
|
for(auto x: std::vector<order_t>{COLMAJOR}){
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
config_t{2048, 2048, x}
|
config_t{4096, 4096, x}
|
||||||
};
|
};
|
||||||
configs.insert(configs.end(), tmp.begin(), tmp.end());
|
configs.insert(configs.end(), tmp.begin(), tmp.end());
|
||||||
}
|
}
|
||||||
|
@@ -38,7 +38,7 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
|||||||
int rm[TM] = ridm * TM + 0 ... TM;
|
int rm[TM] = ridm * TM + 0 ... TM;
|
||||||
int rn[TN] = ridn * TN + 0 ... TN;
|
int rn[TN] = ridn * TN + 0 ... TN;
|
||||||
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
|
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
|
||||||
TYPE* py[TM, TN] = Y + rm[:, newaxis] * ldy + rn[newaxis, :];
|
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy;
|
||||||
*py = *px;
|
*py = *px;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
Reference in New Issue
Block a user