[code generation] fixed bug in on-the-fly AST to IR lowering

This commit is contained in:
Philippe Tillet
2019-01-23 00:11:42 -05:00
parent a0ecdba5a2
commit 7eebdceb6a
10 changed files with 344 additions and 102 deletions

View File

@@ -24,21 +24,16 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
int32 rx[16] = get_global_range[16](0);\
int32 ry[16] = get_global_range[16](1);\
int32 rx[32] = get_global_range[32](0);\
int32 ry[32] = get_global_range[32](1);\
int32 rka[8] = 0 ... 8;\
int32 rkb[8] = 0 ... 8;\
fp32 C[16, 16] = 0;\
fp32 C[32, 32] = 0;\
int32 k;\
fp32* pa[16, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
fp32* pb[16, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
fp32* pc[16, 16];\
fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
fp32* pc[32, 32];\
for(k = 0; k < K; k = k + 8){\
fp32 A[16, 8] = *pa;\
fp32 B[16, 8] = *pb;\
C = dot(A, B, C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
}\
pc = c + rx[:, newaxis] + ry[newaxis, :];\
*pc = C;\
@@ -60,13 +55,37 @@ int main() {
tdl::codegen::tune tune;
tdl::codegen::liveness liveness;
tdl::codegen::allocation allocation(&liveness);
tdl::codegen::selection selection(&allocation, &tune);
tune.run(module);
std::vector<unsigned> params = {
// asm
2, 16, 1,
// bsn
2, 16, 1,
// pa
1, 2, 4,
// pb
1, 2, 4,
// c
2, 16, 1, 1, 2, 4
};
std::map<tdl::ir::value*, std::vector<std::string>> errors;
unsigned i = 0;
std::cout << tune.get_params(module).size() << std::endl;
for(unsigned *x: tune.get_params(module))
*x = params[i++];
tune.check_constraints(module, errors);
// std::cout << "errors: " << errors.size() << std::endl;
// for(auto &x: errors){
// for(auto &e: x.second)
// std::cout << e << std::endl;
// }
shared.run(module);
liveness.run(module);
allocation.run();
std::vector<unsigned*> params;
tune.get_params(module, params);
std::cout << params.size() << std::endl;
selection.run(module, llvm_module);
// std::vector<unsigned*> params = tune.get_params(module);
// std::cout << params.size() << std::endl;
// selection.run(module, llvm_module);
// // print LLVM program
// llvm::PrintModulePass print(llvm::outs());

View File

@@ -65,6 +65,7 @@ class node {
protected:
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
static void implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty);
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
public:

View File

@@ -21,32 +21,75 @@ namespace tdl{
namespace codegen{
class allocation;
class tune;
// One axis of a distributed tile.
struct distributed_axis {
// LLVM index values along this axis; selection::init_axes stores one value
// per index it computes for the axis.
std::vector<llvm::Value*> values;
};
// Base class of the tiles that `selection` associates with IR values.
// It only records the tile's shape.
class tile {
protected:
  typedef std::vector<unsigned> shapes_t;

public:
  // shapes: extent of the tile along each dimension.
  tile(const shapes_t &shapes): shapes_(shapes){ }
  // Derived tiles are heap-allocated and stored as `tile*` (see
  // selection::tmap_t), so destroying one through the base pointer must be
  // well-defined.
  virtual ~tile() = default;

private:
  shapes_t shapes_;
};
// Tile created by selection::init_grids for ir::copy_to_shared_inst
// instructions. It adds no state of its own and inherits tile's constructors.
class shared_tile: public tile {
public:
using tile::tile;
};
// Tile that carries one distributed_axis per dimension in addition to its shape.
class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t;
public:
// shapes: extent along each dimension; axes: the axis description for each
// dimension (selection::init_grids passes one entry per shape).
distributed_tile(const shapes_t& shapes, const axes_t &axes)
: tile(shapes), axes_(axes) {}
private:
axes_t axes_;
};
// Lowers a tdl IR module to an LLVM module.
// Keeps the mapping from IR values / basic blocks / tiles to their lowered
// counterparts, plus the per-instruction axes built during grid construction.
class selection{
  typedef std::map<ir::value *, llvm::Value *> vmap_t;
  typedef std::map<ir::basic_block *, llvm::BasicBlock *> bmap_t;
  typedef std::map<ir::value *, tile *> tmap_t;

private:
  // LLVM conversions
  llvm::Type*        llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
  llvm::Value*       llvm_value(ir::value *v, llvm::LLVMContext &ctx);
  llvm::Instruction* llvm_inst(ir::instruction *inst, llvm::LLVMContext &ctx);
  llvm::Constant*    llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);

  // grid construction
  void create_grids(std::vector<ir::instruction*> &grids,
                    std::map<unsigned*, ir::instruction*> &references,
                    ir::function *fn);
  void init_axes(ir::instruction *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
  void init_grids(ir::function *fn, llvm::IRBuilder<> &builder);

  // lowering
  void lower_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
  void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);

public:
  // Without tuning parameters: params_ stays null (see the member
  // initializers below) instead of being left indeterminate.
  selection(allocation *alloc): alloc_(alloc){ }
  selection(allocation *alloc, tune *params): alloc_(alloc), params_(params){ }
  void run(ir::module &src, llvm::Module &dst);

private:
  vmap_t vmap_;
  bmap_t bmap_;
  tmap_t tmap_;
  // Non-owning collaborators. Default to null so that a constructor which
  // does not set one of them leaves a well-defined value behind.
  allocation *alloc_ = nullptr;
  tune *params_ = nullptr;
  std::map<ir::instruction*, std::vector<distributed_axis>> axes_;
};
}

View File

@@ -11,6 +11,7 @@ namespace ir{
class value;
class module;
class instruction;
class function;
}
namespace codegen{
@@ -24,11 +25,13 @@ private:
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
void connected_components(node_t x, const std::vector<unsigned*> vals, std::set<node_t> &nodes, graph_t &graph);
void create_grids(std::vector<ir::instruction*> &grids, std::map<unsigned*, ir::instruction*> &references, ir::function *fn);
public:
void get_params(ir::module& mod, std::vector<unsigned*> &result);
unsigned *get_param(ir::value *value);
std::vector<unsigned *> get_params(ir::module& mod);
std::map<std::string, unsigned *> get_params(ir::instruction* i);
unsigned *get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);

View File

@@ -149,12 +149,15 @@ protected:
//===----------------------------------------------------------------------===//
class cast_inst: public unary_inst{
using unary_inst::unary_inst;
using ic = llvm::Instruction::CastOps;
public:
typedef llvm::CastInst::CastOps op_t;
protected:
cast_inst(type *ty, value *v, const std::string &name, instruction *next, op_t op)
: unary_inst(ty, v, name, next), op_(op) { }
private:
static bool is_valid(op_t op, value *arg, type *ty);
@@ -172,25 +175,26 @@ private:
op_t op_;
};
#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name) \
class name : public cast_inst{ \
friend class cast_inst; \
using cast_inst::cast_inst; \
};
#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name, op) \
class name : public cast_inst{ \
friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, v, name, next, op){ } \
};
TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst)
TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst, llvm::Instruction::CastOps::Trunc)
TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst, llvm::Instruction::CastOps::ZExt)
TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst, llvm::Instruction::CastOps::SExt)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst, llvm::Instruction::CastOps::FPExt)
TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP)
TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI)
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI)
TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt)
TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr)
TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst, llvm::Instruction::CastOps::BitCast)
TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast)
//===----------------------------------------------------------------------===//
// terminator_inst classes

View File

@@ -90,6 +90,11 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
throw std::runtime_error("unreachable");
}
// Broadcasts `arg` against a target type rather than against another value.
// An undef value of type `ty` stands in for the second operand and the call
// is forwarded to the two-value overload; the placeholder is discarded.
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
  ir::value *placeholder = ir::undef_value::get(ty);
  implicit_broadcast(mod, arg, placeholder);
}
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
ir::builder &builder = mod->get_builder();
ir::type *lhs_ty = lhs->get_type();
@@ -320,7 +325,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
if(expr_){
value = expr_->codegen(mod);
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
implicit_broadcast(mod, value, value);
implicit_broadcast(mod, value, ty);
}
value->set_name(name);
mod->set_value(name, value);
@@ -331,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
/* Expression */
/*------------------*/
/* Binary operator */
// Emits the IR instruction for this binary operator applied to (arg, rhs).
//
// The operands are first passed through implicit_cast, which also reports
// whether the operation is floating-point, pointer or integer and whether it
// is signed, and then through implicit_broadcast. The instruction is chosen
// from op_ and those flags.
//
// Throws std::runtime_error("unreachable") when no case matches.
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const
{
  bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
  implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed);
  implicit_broadcast(mod, arg, rhs);

  // multiplicative operators
  if(op_==MUL && is_float)
    return builder.create_fmul(arg, rhs, name);
  if(op_==MUL && is_int)
    return builder.create_mul(arg, rhs, name);
  if(op_==DIV && is_float)
    return builder.create_fdiv(arg, rhs, name);
  if(op_==DIV && is_int && is_signed)
    return builder.create_sdiv(arg, rhs, name);
  if(op_==DIV && is_int && !is_signed)
    return builder.create_udiv(arg, rhs, name);
  if(op_==MOD && is_float)
    return builder.create_frem(arg, rhs, name);
  if(op_==MOD && is_int && is_signed)
    return builder.create_srem(arg, rhs, name);
  if(op_==MOD && is_int && !is_signed)
    return builder.create_urem(arg, rhs, name);

  // additive operators; pointer arithmetic is lowered to a GEP
  // NOTE(review): the integer add and both GEPs do not forward `name`,
  // unlike every other case -- confirm whether that is intentional.
  if(op_==ADD && is_float)
    return builder.create_fadd(arg, rhs, name);
  if(op_==ADD && is_int)
    return builder.create_add(arg, rhs);
  if(op_==ADD && is_ptr)
    return builder.create_gep(arg, {rhs});
  if(op_==SUB && is_float)
    return builder.create_fsub(arg, rhs, name);
  if(op_==SUB && is_int)
    return builder.create_sub(arg, rhs, name);
  if(op_==SUB && is_ptr)
    return builder.create_gep(arg, {builder.create_neg(rhs)});

  // shifts
  // NOTE(review): RIGHT_SHIFT always uses create_ashr, regardless of
  // is_signed -- confirm this is the desired behaviour for unsigned operands.
  if(op_==LEFT_SHIFT)
    return builder.create_shl(arg, rhs, name);
  if(op_==RIGHT_SHIFT)
    return builder.create_ashr(arg, rhs, name);

  // relational operators
  if(op_ == LT && is_float)
    return builder.create_fcmpOLT(arg, rhs, name);
  if(op_ == LT && is_int && is_signed)
    return builder.create_icmpSLT(arg, rhs, name);
  if(op_ == LT && is_int && !is_signed)
    return builder.create_icmpULT(arg, rhs, name);
  if(op_ == GT && is_float)
    return builder.create_fcmpOGT(arg, rhs, name);
  if(op_ == GT && is_int && is_signed)
    return builder.create_icmpSGT(arg, rhs, name);
  if(op_ == GT && is_int && !is_signed)
    return builder.create_icmpUGT(arg, rhs, name);
  if(op_ == LE && is_float)
    return builder.create_fcmpOLE(arg, rhs, name);
  if(op_ == LE && is_int && is_signed)
    return builder.create_icmpSLE(arg, rhs, name);
  if(op_ == LE && is_int && !is_signed)
    return builder.create_icmpULE(arg, rhs, name);
  if(op_ == GE && is_float)
    return builder.create_fcmpOGE(arg, rhs, name);
  if(op_ == GE && is_int && is_signed)
    return builder.create_icmpSGE(arg, rhs, name);
  if(op_ == GE && is_int && !is_signed)
    return builder.create_icmpUGE(arg, rhs, name);

  // equality operators
  if(op_ == EQ && is_float)
    return builder.create_fcmpOEQ(arg, rhs, name);
  if(op_ == EQ && is_int)
    return builder.create_icmpEQ(arg, rhs, name);
  if(op_ == NE && is_float)
    return builder.create_fcmpONE(arg, rhs, name);
  if(op_ == NE && is_int)
    return builder.create_icmpNE(arg, rhs, name);

  // bitwise and logical operators (logical and/or share the bitwise lowering)
  if(op_ == AND)
    return builder.create_and(arg, rhs, name);
  if(op_ == XOR)
    return builder.create_xor(arg, rhs, name);
  if(op_ == OR)
    return builder.create_or(arg, rhs, name);
  if(op_ == LAND)
    return builder.create_and(arg, rhs, name);
  if(op_ == LOR)
    return builder.create_or(arg, rhs, name);

  throw std::runtime_error("unreachable");
}
@@ -433,6 +438,12 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value *A = A_->codegen(mod);
ir::value *B = B_->codegen(mod);
ir::value *C = C_->codegen(mod);
// unsigned M = A->get_type()->get_tile_shapes()[0];
// unsigned N = B->get_type()->get_tile_shapes()[1];
// ir::type *scalar_ty = A->get_type()->get_scalar_ty();
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
// ir::value *tmp = ir::undef_value::get(tile_ty);
// implicit_broadcast(mod, tmp, C);
return mod->get_builder().create_matmul(A, B, C);
}

View File

@@ -1,4 +1,5 @@
#include "codegen/selection.h"
#include "codegen/tune.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "ir/context.h"
@@ -143,10 +144,148 @@ Value* selection::llvm_value(ir::value *v, LLVMContext &ctx) {
throw std::runtime_error("unknown conversion from ir::value to Value");
}
/* lower tile to a set of llvm::Value's */
//void selection::lower_tile(ir::value *v) {
// Grid construction
// Splits the linear index `trailing` into one index per dimension of `shapes`.
//
// For every dimension except the last, the index is `trailing % shapes[k]`
// and `trailing` is then divided by `shapes[k]`; the last dimension receives
// whatever remains. The URem/UDiv instructions are emitted through `builder`.
// Returns an empty vector when `shapes` is empty.
std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes, IRBuilder<> &builder){
  const size_t dim = shapes.size();
  std::vector<Value*> result(dim);
  // `dim - 1` would wrap around for an empty shape list (size_t is unsigned),
  // turning the loop bound into SIZE_MAX and writing out of bounds.
  if(dim == 0)
    return result;
  for(size_t k = 0; k + 1 < dim; k++){
    Constant *dim_k = builder.getInt32(shapes[k]);
    Value *rem = builder.CreateURem(trailing, dim_k);
    trailing = builder.CreateUDiv(trailing, dim_k);
    result[k] = rem;
  }
  result[dim - 1] = trailing;
  return result;
}
//}
// Builds the distributed axes of the tile produced by `instr` and stores them
// in axes_[instr].
//
// For each dimension d, three tuning parameters are read from params_ under
// the keys "p0.d<d>", "p1.d<d>" and "p2.d<d>"; they are used below as the
// contiguous element count, the warp extent and the number of warps.
// u_thread_id and u_warp_id are delinearized over the warp extents and the
// warp counts respectively to obtain a per-dimension position.
void selection::init_axes(ir::instruction *instr, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
  const auto& shapes = instr->get_type()->get_tile_shapes();
  size_t rank = shapes.size();

  // Fetch the tuning parameters of every dimension
  std::vector<unsigned> contiguous(rank);
  std::vector<unsigned> warp_size(rank);
  std::vector<unsigned> n_warps(rank);
  for(unsigned d = 0; d < shapes.size(); d++){
    std::string idx = std::to_string(d);
    contiguous[d] = *params_->get_param(instr, "p0.d" + idx);
    warp_size[d] = *params_->get_param(instr, "p1.d" + idx);
    n_warps[d] = *params_->get_param(instr, "p2.d" + idx);
  }

  // Per-dimension position of the thread inside its warp, and of the warp
  std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
  std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);

  // Create axes
  std::vector<distributed_axis> axes(rank);
  for(unsigned k = 0; k < rank; k++) {
    Value *warp_size_k = builder.getInt32(warp_size[k]);
    Value *contiguous_k = builder.getInt32(contiguous[k]);
    // base index = (thread_in_warp + warp * warp_size) * contiguous
    Value *warp_offset = builder.CreateMul(warp_id[k], warp_size_k);
    Value *base = builder.CreateAdd(thread_id_in_warp[k], warp_offset);
    base = builder.CreateMul(base, contiguous_k);
    // number of indices covered per block, and held by each thread
    unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
    unsigned per_thread = contiguous[k] * shapes[k] / per_block;
    std::vector<Value*> indices;
    indices.reserve(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      // runs of `contiguous` consecutive indices, repeated every `per_block`
      unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
      indices.push_back(builder.CreateAdd(base, builder.getInt32(offset)));
    }
    axes[k].values = indices;
  }

  // Store axes
  axes_[instr] = axes;
}
// Chooses, for every "p0.d<k>" tuning parameter, a reference instruction and
// collects the distinct reference instructions into `grids`.
//
// references: filled with parameter pointer -> chosen instruction.
// grids:      receives each chosen instruction once, in the iteration order
//             of `references`; existing entries are kept.
// Instructions that are not tiles, and ir::copy_to_shared_inst instructions,
// are ignored, as are dimensions whose extent is 1.
void selection::create_grids(std::vector<ir::instruction*> &grids,
                             std::map<unsigned*, ir::instruction*> &references,
                             ir::function *fn) {
  // Ranking used to pick a reference: the sum of all tile extents that are
  // greater than 1.
  auto tile_weight = [&](ir::value *v){
    unsigned total = 0;
    for(unsigned extent: v->get_type()->get_tile_shapes()) {
      if(extent > 1)
        total += extent;
    }
    return total;
  };

  // bind references
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *inst: block->get_inst_list()){
    if(!inst->get_type()->is_tile_ty())
      continue;
    const auto& shapes = inst->get_type()->get_tile_shapes();
    if(dynamic_cast<ir::copy_to_shared_inst*>(inst) != nullptr)
      continue;
    for(size_t d = 0; d < shapes.size(); d++){
      if(shapes[d] == 1)
        continue;
      unsigned *key = params_->get_param(inst, "p0.d" + std::to_string(d));
      ir::instruction *&chosen = references[key];
      // first instruction seen for this parameter, or a heavier one
      if(chosen == nullptr || tile_weight(inst) > tile_weight(chosen))
        chosen = inst;
    }
  }

  // create grid
  for(auto &entry: references){
    ir::instruction *candidate = entry.second;
    bool already_listed = std::find(grids.begin(), grids.end(), candidate) != grids.end();
    if(!already_listed)
      grids.push_back(candidate);
  }
}
// Emits the thread / warp index computation at the builder's current
// position, initializes the axes of every grid instruction of `fn`, and
// registers a tile in tmap_ for every tile-typed instruction.
//
// Tiles are allocated with `new` and stored as raw pointers in tmap_; nothing
// in this function releases them.
void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
  // fetch linear ID
  Module *mod = builder.GetInsertBlock()->getParent()->getParent();
  Function *get_thread_id = Intrinsic::getDeclaration(mod, Intrinsic::nvvm_read_ptx_sreg_tid_x);
  Value *warp_size = builder.getInt32(32);
  Value *u_thread_id = builder.CreateCall(get_thread_id, {});
  Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
  Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);

  // create grid
  std::vector<ir::instruction*> grids;
  std::map<unsigned*, ir::instruction*> references;
  create_grids(grids, references, fn);
  for(ir::instruction* grid: grids)
    init_axes(grid, builder, u_thread_warp_id, u_warp_id);

  // create tile
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction *inst: block->get_inst_list()){
    if(!inst->get_type()->is_tile_ty())
      continue;
    bool is_shared = dynamic_cast<ir::copy_to_shared_inst*>(inst);
    const auto& shapes = inst->get_type()->get_tile_shapes();

    // create shared tile
    if(is_shared){
      tmap_.insert({inst, new shared_tile(shapes)});
      continue;
    }

    // create distributed tile: dimensions of extent > 1 reuse the axis of the
    // reference instruction bound to their "p0" parameter; the others get a
    // single constant-zero index
    std::vector<distributed_axis> axes(shapes.size());
    for(size_t d = 0; d < shapes.size(); d++){
      if(shapes[d] > 1){
        unsigned *key = params_->get_param(inst, "p0.d" + std::to_string(d));
        axes[d] = axes_.at(references.at(key))[d];
      }
      else
        axes[d].values = {builder.getInt32(0)};
    }
    tmap_.insert({inst, new distributed_tile(shapes, axes)});
  }
}
// Lowering of tile-typed instructions. The body is currently empty, so
// lower_instruction emits nothing for tile instructions.
// TODO: implement.
void selection::lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder) {
}
// Lowers a single IR instruction at the builder's current insertion point.
//
// Tile-typed instructions are handed to lower_tile_instruction. Any other
// instruction is converted with llvm_inst, recorded in vmap_ and inserted
// through the builder. The dynamic type and type id of every instruction are
// printed to std::cout first.
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
  LLVMContext &ctx = builder.getContext();
  std::cout << typeid(*src).name() << " " << src->get_type()->get_type_id() << std::endl;

  const bool produces_tile = src->get_type()->is_tile_ty();
  if(produces_tile) {
    std::cout << "tile instruction" << std::endl;
    lower_tile_instruction(src, builder);
    return;
  }

  Instruction *lowered = llvm_inst(src, ctx);
  vmap_[src] = lowered;
  builder.Insert(lowered);
}
void selection::run(ir::module &src, Module &dst){
vmap_.clear();
@@ -166,14 +305,14 @@ void selection::run(ir::module &src, Module &dst){
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn);
bmap_[block] = dst_block;
}
// create grids
dst_builder.SetInsertPoint(bmap_[fn->blocks()[0]]);
init_grids(fn, dst_builder);
// iterate through block
for(ir::basic_block *block: fn->blocks()) {
dst_builder.SetInsertPoint(bmap_[block]);
for(ir::instruction *inst: block->get_inst_list()) {
Instruction *dst_inst = llvm_inst(inst, dst_ctx);
vmap_[inst] = dst_inst;
dst_builder.Insert(dst_inst);
}
for(ir::instruction *i: block->get_inst_list())
lower_instruction(i, dst_builder);
}
// add phi operands
for(ir::basic_block *block: fn->blocks())

View File

@@ -23,6 +23,7 @@ void tune::init_c_phi(ir::instruction *v) {
for(unsigned k = 0; k < phi->get_type()->get_tile_shapes().size(); k++)
if(dependencies_.find({op, k}) != dependencies_.end()
|| dependencies_.find({phi, k}) != dependencies_.end()){
std::cout << typeid(*op).name() << std::endl;
add_constraint({phi, k}, {op, k});
}
}
@@ -32,11 +33,12 @@ void tune::init_c_graph(ir::instruction *v) {
if(dynamic_cast<ir::reshape_inst*>(v)){
ir::value *op = v->get_operand(0);
unsigned current = 0;
for(unsigned i = 0; i < shapes.size(); i ++)
for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1)
static_params_.insert({{v, i}, 1});
else
add_constraint({v, i}, {op, current++});
}
}
else if(dynamic_cast<ir::splat_inst*>(v)){
@@ -58,8 +60,9 @@ void tune::init_c_graph(ir::instruction *v) {
}
else if(dynamic_cast<ir::user*>(v)){
for(unsigned i = 0; i < shapes.size(); i ++)
for(ir::value* op: v->ops())
for(ir::value* op: v->ops()){
add_constraint({v, i}, {op, i});
}
}
}
@@ -82,8 +85,8 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
}
}
void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
result.clear();
std::vector<unsigned*> tune::get_params(ir::module &mod) {
std::vector<unsigned *> result;
std::set<unsigned*> seen;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
@@ -92,6 +95,11 @@ void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
if(seen.insert(x.second).second && *x.second == 0){
result.push_back(x.second);
}
return result;
}
// Returns a copy of the key -> parameter-pointer map stored in params_ for
// instruction `i`.
// NOTE(review): relies on params_.at(i); presumably params_ is a std::map, in
// which case this throws std::out_of_range when `i` has no entry -- confirm.
std::map<std::string, unsigned*> tune::get_params(ir::instruction* i) {
return params_.at(i);
}
void tune::run(ir::module &mod) {
@@ -117,9 +125,10 @@ void tune::run(ir::module &mod) {
}
}
bool tune::check_constraints(ir::module &mod, std::map<ir::value *, std::vector<std::string>> &errors) {
for(ir::function *fn: mod.get_function_list()){
/* grids */
void tune::create_grids(std::vector<ir::instruction*> &grids,
std::map<unsigned*, ir::instruction*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
unsigned result = 0;
for(unsigned shape: v->get_type()->get_tile_shapes()) {
@@ -127,8 +136,7 @@ for(ir::function *fn: mod.get_function_list()){
}
return result;
};
using std::to_string;
std::map<unsigned*, ir::instruction*> references;
// bind references
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
@@ -137,16 +145,25 @@ for(ir::function *fn: mod.get_function_list()){
if(*param.second == 1)
continue;
ir::instruction *&r = references[param.second];
if(!r && get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
r = i;
}
}
// extract unique instructions in order
std::vector<ir::instruction*> grids;
// create grid
for(auto &ref: references)
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
grids.push_back(ref.second);
}
bool tune::check_constraints(ir::module &mod, std::map<ir::value *, std::vector<std::string>> &errors) {
for(ir::function *fn: mod.get_function_list()){
using std::to_string;
// initialize grids
std::map<unsigned*, ir::instruction*> references;
std::vector<ir::instruction*> grids;
create_grids(grids, references, fn);
// number of warps
int num_warps = 1;

View File

@@ -16,7 +16,6 @@ builder::builder(context &ctx):
//===----------------------------------------------------------------------===//
// utilities
//===----------------------------------------------------------------------===//
void builder::set_insert_point(basic_block::iterator it){
block_ = (*it)->get_parent();
insert_point_ = it;

View File

@@ -85,18 +85,24 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
ir::value *pred = get_value(name, preds.front());
result = make_phi(pred->get_type(), 1, block);
set_value(name, block, result);
add_phi_operands(name, (ir::phi_node*&)result);
result = add_phi_operands(name, (ir::phi_node*&)result);
}
set_value(name, block, result);
return result;
}
// Returns the value bound to `name` in `block`.
//
// A binding already recorded in values_ is returned directly. Otherwise the
// value is resolved through get_value_recursive, and the builder's insertion
// block / point captured on entry are restored afterwards (presumably because
// the recursive resolution can move the builder, e.g. when it creates phi
// nodes -- confirm against get_value_recursive).
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
  ir::basic_block* save_block = builder_.get_insert_block();
  ir::basic_block::iterator save_pt = builder_.get_insert_point();

  // already known for this block
  val_key_t key(name, block);
  auto known = values_.find(key);
  if(known != values_.end())
    return known->second;

  // The unconditional `return get_value_recursive(name, block);` that used to
  // sit here made the restore logic below unreachable.
  ir::value *result = get_value_recursive(name, block);

  // restore the insertion point; an iterator equal to end() is covered by
  // re-selecting the block alone
  builder_.set_insert_point(save_block);
  if(save_pt != save_block->end())
    builder_.set_insert_point(save_pt);
  return result;
}
ir::value *module::get_value(const std::string& name) {