[code generation] some more bugfixing with nested control flow

This commit is contained in:
Philippe Tillet
2019-02-18 22:54:08 -05:00
parent f3094a512b
commit 90ec0ae2c0
10 changed files with 102 additions and 71 deletions

View File

@@ -63,6 +63,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
@checka a = *pa;\
@checkb b = *pb;\
if(k <= 8){\
@checka a = *pa;\
}\
}\
@checkc *pc = C;\
@@ -170,11 +171,10 @@ int main() {
llvm::LLVMContext llvm_context;
llvm::Module llvm_module("test", llvm_context);
tdl::ir::print(module, std::cout);
// create passes
tdl::codegen::place_shared_copy shared;
tdl::codegen::buffer_info_pass buffer_info;
tdl::codegen::place_shared_copy shared(&buffer_info);
tdl::codegen::tune tune;
tdl::codegen::liveness liveness(&buffer_info);
tdl::codegen::allocation allocation(&liveness, &buffer_info);
@@ -211,10 +211,14 @@ int main() {
if(errors.size())
exit(EXIT_FAILURE);
// print
// run passes
shared.run(module);
tdl::ir::print(module, std::cout);
buffer_info.run(module);
shared.run(module);
liveness.run(module);
tdl::ir::print(module, std::cout);
allocation.run();
barriers.run(module);
vectorize.run(module);

View File

@@ -15,6 +15,8 @@ namespace ir {
namespace codegen{
class buffer_info_pass;
class place_shared_copy {
private:
typedef std::pair<unsigned, unsigned> interval_t;
@@ -22,10 +24,14 @@ private:
private:
bool intersect(const interval_vec_t &I, interval_t i);
void add_copies(ir::value *x, ir::builder &builder);
void add_copy(ir::value *x, ir::builder &builder);
public:
place_shared_copy(buffer_info_pass *info): info_(info) { }
void run(ir::module &mod);
private:
buffer_info_pass *info_;
};

View File

@@ -30,7 +30,7 @@ public:
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_user);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }

View File

@@ -292,6 +292,10 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn);
ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr;
ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn);
mod->seal_block(then_bb);
if(else_value_)
mod->seal_block(else_bb);
// Branch
if(else_value_)
builder.create_cond_br(cond, then_bb, else_bb);
@@ -301,13 +305,11 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
builder.set_insert_point(then_bb);
then_value_->codegen(mod);
builder.create_br(endif_bb);
mod->seal_block(then_bb);
// Else
if(else_value_){
builder.set_insert_point(else_bb);
else_value_->codegen(mod);
builder.create_br(endif_bb);
mod->seal_block(else_bb);
}
// Endif
builder.set_insert_point(endif_bb);

View File

@@ -102,10 +102,14 @@ void allocation::run(){
for(ir::value *y: interferences[x])
Adj = std::max(Adj, starts[y] + get_num_bytes(y));
offsets_[x] = starts[x] + colors[x] * Adj;
if(auto *phi = dynamic_cast<ir::phi_node*>(x))
for(ir::value *px: phi->ops()){
if(offsets_.find(px) == offsets_.end())
offsets_[px] = offsets_[x];
if(buffer_info_->is_double(x)){
ir::phi_node *phi = (ir::phi_node*)x;
for(unsigned i = 0; i < phi->get_num_incoming(); i++){
ir::value *inc_val = phi->get_incoming_value(i);
assert(offsets_.find(inc_val) == offsets_.end());
offsets_[inc_val] = offsets_[phi];
std::cout << x->get_name() << " " << inc_val->get_name() << " " << inc_val << std::endl;
}
}
}

View File

@@ -26,7 +26,7 @@ bool barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
}
void barriers::add_reference(ir::value *v, interval_vec_t &res){
if(buffer_info_->is_shared(v)){
if(dynamic_cast<ir::copy_to_shared_inst*>(v)){
unsigned offset = alloc_->get_offset(v);
unsigned num_bytes = alloc_->get_num_bytes(v);
res.push_back(interval_t(offset, offset + num_bytes));
@@ -51,7 +51,7 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
builder.create_barrier();
}
}
else{
else {
builder.set_insert_point(instr);
builder.create_barrier();
}

View File

@@ -12,25 +12,37 @@ namespace codegen{
// run pass on module
void buffer_info_pass::run(ir::module &mod) {
// Find which buffers are shared
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(dynamic_cast<ir::matmul_inst*>(i)){
shared_.insert(i->get_operand(0));
shared_.insert(i->get_operand(1));
}
// Handles phi nodes
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()) {
if(!i->get_type()->is_tile_ty())
continue;
// handle phi
if(auto *phi = dynamic_cast<ir::phi_node*>(i)){
if(auto *phi = dynamic_cast<ir::phi_node*>(i))
if(is_shared(phi)){
// determine if the value is in shared memory
bool is_shared = true;
bool is_double = false;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::value *inc_block = phi->get_incoming_block(n);
is_shared = is_shared && dynamic_cast<ir::copy_to_shared_inst*>(inc_val);
is_double = is_double || inc_block == phi->get_parent();
ir::basic_block *inc_block = phi->get_incoming_block(n);
ir::value *terminator = inc_block->get_inst_list().back();
if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
is_double = is_double || br->get_true_dest() == phi->get_parent()
|| br->get_false_dest() == phi->get_parent();
else if(auto *br = dynamic_cast<ir::uncond_branch_inst*>(terminator))
is_double = is_double || br->get_dest() == phi->get_parent();
else
throw std::runtime_error("unreachable");
}
// add to shared
if(is_shared)
shared_.insert(phi);
// add to double-buffered
if(is_double)
double_.insert(phi);
@@ -41,10 +53,10 @@ void buffer_info_pass::run(ir::module &mod) {
refs_[inc_val] = phi;
}
}
// handle shared copy
if(auto *copy = dynamic_cast<ir::copy_to_shared_inst*>(i))
shared_.insert(copy);
}
for(auto &ref: refs_)
shared_.insert(ref.first);
}
// query double-buffered status

View File

@@ -299,6 +299,7 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
}
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
std::cout << "name: " << v->get_name() << std::endl;
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim);
@@ -354,7 +355,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
bind_references(op);
// bind
const auto& shapes = v->get_type()->get_tile_shapes();
if(buffer_info_->is_shared(v))
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || buffer_info_->is_double(v))
return;
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] == 1)
@@ -388,7 +389,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
const auto& shapes = v->get_type()->get_tile_shapes();
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
// create shared tile
if(buffer_info_->is_shared(v)){
if(dynamic_cast<ir::copy_to_shared_inst*>(v) || (buffer_info_->is_double(v))){
// shared copy
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
if(dynamic_cast<ir::copy_to_shared_inst*>(v)) {
@@ -478,6 +479,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
std::cout << "lowering " << ins->get_name() << std::endl;
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
Function *function = block->getParent();
@@ -602,7 +604,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ti->set_value(idx, in->get_value(idx));
});
}
else if(buffer_info_->is_shared(ins))
else if(dynamic_cast<ir::copy_to_shared_inst*>(ins) || (buffer_info_->is_double(ins)))
return;
// matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(ins)) {
@@ -694,13 +696,15 @@ void selection::run(ir::module &src, Module &dst){
std::map<ir::basic_block*, BasicBlock*> last_block;
// iterate through block
for(ir::basic_block *block: fn->blocks()) {
std::cout << "block: " << block->get_name() << std::endl;
BasicBlock *parent = (BasicBlock*)vmap_[block];
dst_builder.SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::phi_node*>(i))
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty()){
dst_builder.SetInsertPoint(&*parent->getFirstInsertionPt());
}
lower_instruction(i, dst_builder);
if(dynamic_cast<ir::phi_node*>(i))
if(dynamic_cast<ir::phi_node*>(i) && !parent->empty())
dst_builder.SetInsertPoint(parent);
last_block[block] = dst_builder.GetInsertBlock();
}
@@ -709,7 +713,7 @@ void selection::run(ir::module &src, Module &dst){
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){
if(buffer_info_->is_shared(phi)) {
if(buffer_info_->is_double(phi)) {
PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer();
PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset();
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -728,25 +732,28 @@ void selection::run(ir::module &src, Module &dst){
}
ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block);
}
continue;
}
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n);
BasicBlock *llvm_inc_block = last_block.at(inc_block);
if(phi->get_type()->is_tile_ty()) {
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
phi_tile->for_each([&](indices_t idx){
PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx);
Value *llvm_inc_val = inc_tile->get_value(idx);
else {
std::cout << "phi: " << phi->get_name() << std::endl;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::value *inc_val = phi->get_incoming_value(n);
ir::basic_block *inc_block = phi->get_incoming_block(n);
BasicBlock *llvm_inc_block = last_block.at(inc_block);
std::cout << "incoming block: " << inc_block->get_name() << " " << llvm_inc_block->getName().str() << std::endl;
if(phi->get_type()->is_tile_ty()) {
distributed_tile *phi_tile = (distributed_tile*)tmap_.at(phi);
distributed_tile *inc_tile = (distributed_tile*)tmap_.at(inc_val);
phi_tile->for_each([&](indices_t idx){
PHINode *llvm_phi = (PHINode*)phi_tile->get_value(idx);
Value *llvm_inc_val = inc_tile->get_value(idx);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
});
}
else {
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
});
}
else {
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
}
}
}
}

View File

@@ -1,5 +1,6 @@
#include <algorithm>
#include "codegen/shared_copy.h"
#include "codegen/buffer_info.h"
#include "ir/module.h"
#include "ir/function.h"
#include "ir/basic_block.h"
@@ -9,21 +10,16 @@ namespace tdl {
namespace codegen{
void place_shared_copy::add_copies(ir::value *x, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(x)) {
for(auto *op: phi->ops())
add_copies(op, builder);
}
else {
if(auto *i = dynamic_cast<ir::instruction*>(x)){
ir::basic_block* block = i->get_parent();
auto it = std::find(block->begin(), block->end(), i);
builder.set_insert_point(++it);
}
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) {
if(auto *i = dynamic_cast<ir::instruction*>(x)){
ir::basic_block* block = i->get_parent();
std::cout << "adding copy: " << x->get_name() << " " << block->get_name() << std::endl;
auto it = std::find(block->begin(), block->end(), i);
builder.set_insert_point(++it);
}
ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
}
void place_shared_copy::run(ir::module &mod) {
@@ -31,10 +27,8 @@ void place_shared_copy::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
if(dynamic_cast<ir::matmul_inst*>(i)){
add_copies(i->get_operand(0), builder);
add_copies(i->get_operand(1), builder);
}
if(info_->is_shared(i) && !info_->is_double(i))
add_copy(i, builder);
}
}

View File

@@ -48,7 +48,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
return res;
}
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi, ir::value** pre_user){
// find non-self references
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
@@ -61,12 +61,12 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
std::set<ir::user*> users = phi->get_users();
phi->replace_all_uses_with(same);
phi->erase_from_parent();
if(pre_user)
*pre_user = same;
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
try_remove_trivial_phis(uphi);
if(auto *new_phi = dynamic_cast<ir::phi_node*>(same))
return try_remove_trivial_phis(new_phi);
try_remove_trivial_phis(uphi, &same);
return same;
}
@@ -80,10 +80,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
ir::value *value = get_value(name, pred);
phi->add_incoming(value, pred);
}
return try_remove_trivial_phis(phi);
return try_remove_trivial_phis(phi, nullptr);
}
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
std::cout << "getting value " << name << std::endl;
ir::value *result;
auto &preds = block->get_predecessors();
if(block)
@@ -141,6 +142,7 @@ void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block])
add_phi_operands(x.first, x.second);
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();
}
/* functions */