[code generation] fixed bug in on-the-fly AST to IR lowering
This commit is contained in:
@@ -24,21 +24,16 @@ extern translation_unit *ast_root;
|
||||
const char src[] =
|
||||
"\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
||||
int32 rx[16] = get_global_range[16](0);\
|
||||
int32 ry[16] = get_global_range[16](1);\
|
||||
int32 rx[32] = get_global_range[32](0);\
|
||||
int32 ry[32] = get_global_range[32](1);\
|
||||
int32 rka[8] = 0 ... 8;\
|
||||
int32 rkb[8] = 0 ... 8;\
|
||||
fp32 C[16, 16] = 0;\
|
||||
fp32 C[32, 32] = 0;\
|
||||
int32 k;\
|
||||
fp32* pa[16, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
|
||||
fp32* pb[16, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
|
||||
fp32* pc[16, 16];\
|
||||
fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
|
||||
fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
|
||||
fp32* pc[32, 32];\
|
||||
for(k = 0; k < K; k = k + 8){\
|
||||
fp32 A[16, 8] = *pa;\
|
||||
fp32 B[16, 8] = *pb;\
|
||||
C = dot(A, B, C);\
|
||||
pa = pa + 8*M;\
|
||||
pb = pb + 8*K;\
|
||||
}\
|
||||
pc = c + rx[:, newaxis] + ry[newaxis, :];\
|
||||
*pc = C;\
|
||||
@@ -60,13 +55,37 @@ int main() {
|
||||
tdl::codegen::tune tune;
|
||||
tdl::codegen::liveness liveness;
|
||||
tdl::codegen::allocation allocation(&liveness);
|
||||
tdl::codegen::selection selection(&allocation, &tune);
|
||||
tune.run(module);
|
||||
std::vector<unsigned> params = {
|
||||
// asm
|
||||
2, 16, 1,
|
||||
// bsn
|
||||
2, 16, 1,
|
||||
// pa
|
||||
1, 2, 4,
|
||||
// pb
|
||||
1, 2, 4,
|
||||
// c
|
||||
2, 16, 1, 1, 2, 4
|
||||
};
|
||||
std::map<tdl::ir::value*, std::vector<std::string>> errors;
|
||||
unsigned i = 0;
|
||||
std::cout << tune.get_params(module).size() << std::endl;
|
||||
for(unsigned *x: tune.get_params(module))
|
||||
*x = params[i++];
|
||||
tune.check_constraints(module, errors);
|
||||
// std::cout << "errors: " << errors.size() << std::endl;
|
||||
// for(auto &x: errors){
|
||||
// for(auto &e: x.second)
|
||||
// std::cout << e << std::endl;
|
||||
// }
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
std::vector<unsigned*> params;
|
||||
tune.get_params(module, params);
|
||||
std::cout << params.size() << std::endl;
|
||||
selection.run(module, llvm_module);
|
||||
// std::vector<unsigned*> params = tune.get_params(module);
|
||||
// std::cout << params.size() << std::endl;
|
||||
// selection.run(module, llvm_module);
|
||||
// // print LLVM program
|
||||
// llvm::PrintModulePass print(llvm::outs());
|
||||
|
@@ -65,6 +65,7 @@ class node {
|
||||
protected:
|
||||
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
|
||||
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
|
||||
static void implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty);
|
||||
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
|
||||
public:
|
||||
|
@@ -21,32 +21,75 @@ namespace tdl{
|
||||
namespace codegen{
|
||||
|
||||
class allocation;
|
||||
class tune;
|
||||
|
||||
struct distributed_axis {
|
||||
std::vector<llvm::Value*> values;
|
||||
};
|
||||
|
||||
class tile {
|
||||
protected:
|
||||
typedef std::vector<unsigned> shapes_t;
|
||||
|
||||
public:
|
||||
tile(const shapes_t &shapes): shapes_(shapes){ }
|
||||
|
||||
private:
|
||||
shapes_t shapes_;
|
||||
};
|
||||
|
||||
class shared_tile: public tile {
|
||||
public:
|
||||
using tile::tile;
|
||||
|
||||
};
|
||||
|
||||
class distributed_tile: public tile{
|
||||
typedef std::vector<distributed_axis> axes_t;
|
||||
|
||||
public:
|
||||
distributed_tile(const shapes_t& shapes, const axes_t &axes)
|
||||
: tile(shapes), axes_(axes) {}
|
||||
|
||||
private:
|
||||
axes_t axes_;
|
||||
};
|
||||
|
||||
|
||||
class selection{
|
||||
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
||||
typedef std::map<ir::basic_block *, llvm::BasicBlock *> bmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
|
||||
private:
|
||||
// LLVM conversions
|
||||
llvm::Type* llvm_type(ir::type *ty, llvm::LLVMContext &ctx);
|
||||
llvm::Value* llvm_value(ir::value *v,llvm:: LLVMContext &ctx);
|
||||
llvm::Instruction* llvm_inst(ir::instruction *inst, llvm::LLVMContext &ctx);
|
||||
llvm::Constant* llvm_constant(ir::constant *cst, llvm::LLVMContext &ctx);
|
||||
|
||||
// grid construction
|
||||
void create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<unsigned*, ir::instruction*> &references,
|
||||
ir::function *fn);
|
||||
void init_axes(ir::instruction *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
|
||||
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder);
|
||||
|
||||
// lowering
|
||||
void lower_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
|
||||
void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
|
||||
|
||||
public:
|
||||
selection(allocation *alloc): alloc_(alloc){ }
|
||||
selection(allocation *alloc, tune *params): alloc_(alloc), params_(params){ }
|
||||
void run(ir::module &src, llvm::Module &dst);
|
||||
|
||||
private:
|
||||
vmap_t vmap_;
|
||||
bmap_t bmap_;
|
||||
tmap_t tmap_;
|
||||
allocation *alloc_;
|
||||
|
||||
tune *params_;
|
||||
std::map<ir::instruction*, std::vector<distributed_axis>> axes_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -11,6 +11,7 @@ namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
class function;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -24,11 +25,13 @@ private:
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
void connected_components(node_t x, const std::vector<unsigned*> vals, std::set<node_t> &nodes, graph_t &graph);
|
||||
void create_grids(std::vector<ir::instruction*> &grids, std::map<unsigned*, ir::instruction*> &references, ir::function *fn);
|
||||
|
||||
|
||||
public:
|
||||
void get_params(ir::module& mod, std::vector<unsigned*> &result);
|
||||
unsigned *get_param(ir::value *value);
|
||||
std::vector<unsigned *> get_params(ir::module& mod);
|
||||
std::map<std::string, unsigned *> get_params(ir::instruction* i);
|
||||
unsigned *get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
|
||||
|
@@ -149,12 +149,15 @@ protected:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cast_inst: public unary_inst{
|
||||
using unary_inst::unary_inst;
|
||||
using ic = llvm::Instruction::CastOps;
|
||||
|
||||
public:
|
||||
typedef llvm::CastInst::CastOps op_t;
|
||||
|
||||
protected:
|
||||
cast_inst(type *ty, value *v, const std::string &name, instruction *next, op_t op)
|
||||
: unary_inst(ty, v, name, next), op_(op) { }
|
||||
|
||||
private:
|
||||
static bool is_valid(op_t op, value *arg, type *ty);
|
||||
|
||||
@@ -172,25 +175,26 @@ private:
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name) \
|
||||
class name : public cast_inst{ \
|
||||
friend class cast_inst; \
|
||||
using cast_inst::cast_inst; \
|
||||
};
|
||||
#define TDL_IR_DECLARE_CAST_INST_SIMPLE(name, op) \
|
||||
class name : public cast_inst{ \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, v, name, next, op){ } \
|
||||
};
|
||||
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(trunc_inst, llvm::Instruction::CastOps::Trunc)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(z_ext_inst, llvm::Instruction::CastOps::ZExt)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(s_ext_inst, llvm::Instruction::CastOps::SExt)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_trunc_inst, llvm::Instruction::CastOps::FPTrunc)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_ext_inst, llvm::Instruction::CastOps::FPExt)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(ui_to_fp_inst, llvm::Instruction::CastOps::UIToFP)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(si_to_fp_inst, llvm::Instruction::CastOps::SIToFP)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_ui_inst, llvm::Instruction::CastOps::FPToUI)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(fp_to_si_inst, llvm::Instruction::CastOps::FPToSI)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(ptr_to_int_inst, llvm::Instruction::CastOps::PtrToInt)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(int_to_ptr_inst, llvm::Instruction::CastOps::IntToPtr)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(bit_cast_inst, llvm::Instruction::CastOps::BitCast)
|
||||
TDL_IR_DECLARE_CAST_INST_SIMPLE(addr_space_cast_inst, llvm::Instruction::CastOps::AddrSpaceCast)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
|
@@ -90,6 +90,11 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
|
||||
ir::value *tmp = ir::undef_value::get(ty);
|
||||
implicit_broadcast(mod, arg, tmp);
|
||||
}
|
||||
|
||||
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::type *lhs_ty = lhs->get_type();
|
||||
@@ -320,7 +325,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
explicit_cast(mod->get_builder(), value, ty->get_scalar_ty());
|
||||
implicit_broadcast(mod, value, value);
|
||||
implicit_broadcast(mod, value, ty);
|
||||
}
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
@@ -331,85 +336,85 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
/* Expression */
|
||||
/*------------------*/
|
||||
/* Binary operator */
|
||||
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
||||
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *arg, ir::value *rhs, const std::string &name) const
|
||||
{
|
||||
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
||||
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
||||
implicit_broadcast(mod, lhs, rhs);
|
||||
implicit_cast(builder, arg, rhs, is_float, is_ptr, is_int, is_signed);
|
||||
implicit_broadcast(mod, arg, rhs);
|
||||
if(op_==MUL && is_float)
|
||||
return builder.create_fmul(lhs, rhs, name);
|
||||
return builder.create_fmul(arg, rhs, name);
|
||||
if(op_==MUL && is_int)
|
||||
return builder.create_mul(lhs, rhs, name);
|
||||
return builder.create_mul(arg, rhs, name);
|
||||
if(op_==DIV && is_float)
|
||||
return builder.create_fdiv(lhs, rhs, name);
|
||||
return builder.create_fdiv(arg, rhs, name);
|
||||
if(op_==DIV && is_int && is_signed)
|
||||
return builder.create_sdiv(lhs, rhs, name);
|
||||
return builder.create_sdiv(arg, rhs, name);
|
||||
if(op_==DIV && is_int && !is_signed)
|
||||
return builder.create_udiv(lhs, rhs, name);
|
||||
return builder.create_udiv(arg, rhs, name);
|
||||
if(op_==MOD && is_float)
|
||||
return builder.create_frem(lhs, rhs, name);
|
||||
return builder.create_frem(arg, rhs, name);
|
||||
if(op_==MOD && is_int && is_signed)
|
||||
return builder.create_srem(lhs, rhs, name);
|
||||
return builder.create_srem(arg, rhs, name);
|
||||
if(op_==MOD && is_int && !is_signed)
|
||||
return builder.create_urem(lhs, rhs, name);
|
||||
return builder.create_urem(arg, rhs, name);
|
||||
if(op_==ADD && is_float)
|
||||
return builder.create_fadd(lhs, rhs, name);
|
||||
return builder.create_fadd(arg, rhs, name);
|
||||
if(op_==ADD && is_int)
|
||||
return builder.create_add(lhs, rhs);
|
||||
return builder.create_add(arg, rhs);
|
||||
if(op_==ADD && is_ptr)
|
||||
return builder.create_gep(lhs, {rhs});
|
||||
return builder.create_gep(arg, {rhs});
|
||||
if(op_==SUB && is_float)
|
||||
return builder.create_fsub(lhs, rhs, name);
|
||||
return builder.create_fsub(arg, rhs, name);
|
||||
if(op_==SUB && is_int)
|
||||
return builder.create_sub(lhs, rhs, name);
|
||||
return builder.create_sub(arg, rhs, name);
|
||||
if(op_==SUB && is_ptr)
|
||||
return builder.create_gep(lhs, {builder.create_neg(rhs)});
|
||||
return builder.create_gep(arg, {builder.create_neg(rhs)});
|
||||
if(op_==LEFT_SHIFT)
|
||||
return builder.create_shl(lhs, rhs, name);
|
||||
return builder.create_shl(arg, rhs, name);
|
||||
if(op_==RIGHT_SHIFT)
|
||||
return builder.create_ashr(lhs, rhs, name);
|
||||
return builder.create_ashr(arg, rhs, name);
|
||||
if(op_ == LT && is_float)
|
||||
return builder.create_fcmpOLT(lhs, rhs, name);
|
||||
return builder.create_fcmpOLT(arg, rhs, name);
|
||||
if(op_ == LT && is_int && is_signed)
|
||||
return builder.create_icmpSLT(lhs, rhs, name);
|
||||
return builder.create_icmpSLT(arg, rhs, name);
|
||||
if(op_ == LT && is_int && !is_signed)
|
||||
return builder.create_icmpULT(lhs, rhs, name);
|
||||
return builder.create_icmpULT(arg, rhs, name);
|
||||
if(op_ == GT && is_float)
|
||||
return builder.create_fcmpOGT(lhs, rhs, name);
|
||||
return builder.create_fcmpOGT(arg, rhs, name);
|
||||
if(op_ == GT && is_int && is_signed)
|
||||
return builder.create_icmpSGT(lhs, rhs, name);
|
||||
return builder.create_icmpSGT(arg, rhs, name);
|
||||
if(op_ == GT && is_int && !is_signed)
|
||||
return builder.create_icmpUGT(lhs, rhs, name);
|
||||
return builder.create_icmpUGT(arg, rhs, name);
|
||||
if(op_ == LE && is_float)
|
||||
return builder.create_fcmpOLE(lhs, rhs, name);
|
||||
return builder.create_fcmpOLE(arg, rhs, name);
|
||||
if(op_ == LE && is_int && is_signed)
|
||||
return builder.create_icmpSLE(lhs, rhs, name);
|
||||
return builder.create_icmpSLE(arg, rhs, name);
|
||||
if(op_ == LE && is_int && !is_signed)
|
||||
return builder.create_icmpULE(lhs, rhs, name);
|
||||
return builder.create_icmpULE(arg, rhs, name);
|
||||
if(op_ == GE && is_float)
|
||||
return builder.create_fcmpOGE(lhs, rhs, name);
|
||||
return builder.create_fcmpOGE(arg, rhs, name);
|
||||
if(op_ == GE && is_int && is_signed)
|
||||
return builder.create_icmpSGE(lhs, rhs, name);
|
||||
return builder.create_icmpSGE(arg, rhs, name);
|
||||
if(op_ == GE && is_int && !is_signed)
|
||||
return builder.create_icmpUGE(lhs, rhs, name);
|
||||
return builder.create_icmpUGE(arg, rhs, name);
|
||||
if(op_ == EQ && is_float)
|
||||
return builder.create_fcmpOEQ(lhs, rhs, name);
|
||||
return builder.create_fcmpOEQ(arg, rhs, name);
|
||||
if(op_ == EQ && is_int)
|
||||
return builder.create_icmpEQ(lhs, rhs, name);
|
||||
return builder.create_icmpEQ(arg, rhs, name);
|
||||
if(op_ == NE && is_float)
|
||||
return builder.create_fcmpONE(lhs, rhs, name);
|
||||
return builder.create_fcmpONE(arg, rhs, name);
|
||||
if(op_ == NE && is_int)
|
||||
return builder.create_icmpNE(lhs, rhs, name);
|
||||
return builder.create_icmpNE(arg, rhs, name);
|
||||
if(op_ == AND)
|
||||
return builder.create_and(lhs, rhs, name);
|
||||
return builder.create_and(arg, rhs, name);
|
||||
if(op_ == XOR)
|
||||
return builder.create_xor(lhs, rhs, name);
|
||||
return builder.create_xor(arg, rhs, name);
|
||||
if(op_ == OR)
|
||||
return builder.create_or(lhs, rhs, name);
|
||||
return builder.create_or(arg, rhs, name);
|
||||
if(op_ == LAND)
|
||||
return builder.create_and(lhs, rhs, name);
|
||||
return builder.create_and(arg, rhs, name);
|
||||
if(op_ == LOR)
|
||||
return builder.create_or(lhs, rhs, name);
|
||||
return builder.create_or(arg, rhs, name);
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
@@ -433,6 +438,12 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
ir::value *A = A_->codegen(mod);
|
||||
ir::value *B = B_->codegen(mod);
|
||||
ir::value *C = C_->codegen(mod);
|
||||
// unsigned M = A->get_type()->get_tile_shapes()[0];
|
||||
// unsigned N = B->get_type()->get_tile_shapes()[1];
|
||||
// ir::type *scalar_ty = A->get_type()->get_scalar_ty();
|
||||
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
|
||||
// ir::value *tmp = ir::undef_value::get(tile_ty);
|
||||
// implicit_broadcast(mod, tmp, C);
|
||||
return mod->get_builder().create_matmul(A, B, C);
|
||||
}
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include "codegen/selection.h"
|
||||
#include "codegen/tune.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "ir/context.h"
|
||||
@@ -143,10 +144,148 @@ Value* selection::llvm_value(ir::value *v, LLVMContext &ctx) {
|
||||
throw std::runtime_error("unknown conversion from ir::value to Value");
|
||||
}
|
||||
|
||||
/* lower tile to a set of llvm::Value's */
|
||||
//void selection::lower_tile(ir::value *v) {
|
||||
// Grid construction
|
||||
std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes, IRBuilder<> &builder){
|
||||
size_t dim = shapes.size();
|
||||
std::vector<Value*> result(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
Constant *dim_k = builder.getInt32(shapes[k]);
|
||||
Value *rem = builder.CreateURem(trailing, dim_k);
|
||||
trailing = builder.CreateUDiv(trailing, dim_k);
|
||||
result[k] = rem;
|
||||
}
|
||||
result[dim - 1] = trailing;
|
||||
return result;
|
||||
}
|
||||
|
||||
//}
|
||||
void selection::init_axes(ir::instruction *instr, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = instr->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = *params_->get_param(instr, "p0.d" + str_i);
|
||||
warp_size[i] = *params_->get_param(instr, "p1.d" + str_i);
|
||||
n_warps[i] = *params_->get_param(instr, "p2.d" + str_i);
|
||||
}
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
|
||||
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
|
||||
// Create axes
|
||||
std::vector<distributed_axis> axes(dim);
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
||||
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
||||
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
||||
thread_id = builder.CreateMul(thread_id, contiguous_k);
|
||||
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
||||
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset));
|
||||
}
|
||||
axes[k] = {idx_list};
|
||||
}
|
||||
// Store axes
|
||||
axes_[instr] = axes;
|
||||
}
|
||||
|
||||
void selection::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<unsigned*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
||||
result += (shape > 1)?shape:0;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
// bind references
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
const auto& shapes = i->get_type()->get_tile_shapes();
|
||||
bool is_shared = dynamic_cast<ir::copy_to_shared_inst*>(i);
|
||||
if(is_shared)
|
||||
continue;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] == 1)
|
||||
continue;
|
||||
unsigned *x = params_->get_param(i, "p0.d" + std::to_string(d));
|
||||
ir::instruction *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
|
||||
r = i;
|
||||
}
|
||||
}
|
||||
// create grid
|
||||
for(auto &ref: references)
|
||||
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
|
||||
grids.push_back(ref.second);
|
||||
}
|
||||
|
||||
void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
|
||||
// fetch linear ID
|
||||
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
|
||||
Function *get_thread_id = Intrinsic::getDeclaration(mod, Intrinsic::nvvm_read_ptx_sreg_tid_x);
|
||||
Value *warp_size = builder.getInt32(32);
|
||||
Value *u_thread_id = builder.CreateCall(get_thread_id, {});
|
||||
Value *u_thread_warp_id = builder.CreateURem(u_thread_id, warp_size);
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
std::vector<ir::instruction*> grids;
|
||||
std::map<unsigned*, ir::instruction*> references;
|
||||
create_grids(grids, references, fn);
|
||||
for(ir::instruction* i: grids)
|
||||
init_axes(i, builder, u_thread_warp_id, u_warp_id);
|
||||
// create tile
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
bool is_shared = dynamic_cast<ir::copy_to_shared_inst*>(i);
|
||||
const auto& shapes = i->get_type()->get_tile_shapes();
|
||||
// create shared tile
|
||||
if(is_shared){
|
||||
tmap_.insert({i, new shared_tile(shapes)});
|
||||
}
|
||||
// create distributed tile
|
||||
else {
|
||||
const auto &shapes = i->get_type()->get_tile_shapes();
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned *x = params_->get_param(i, "p0.d" + std::to_string(d));
|
||||
axes[d] = axes_.at(references.at(x))[d];
|
||||
}
|
||||
else
|
||||
axes[d].values = {builder.getInt32(0)};
|
||||
}
|
||||
tmap_.insert({i, new distributed_tile(shapes, axes)});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void selection::lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder) {
|
||||
|
||||
}
|
||||
|
||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
LLVMContext &ctx = builder.getContext();
|
||||
std::cout << typeid(*src).name() << " " << src->get_type()->get_type_id() << std::endl;
|
||||
if(src->get_type()->is_tile_ty()) {
|
||||
std::cout << "tile instruction" << std::endl;
|
||||
lower_tile_instruction(src, builder);
|
||||
}
|
||||
else {
|
||||
Instruction *i = llvm_inst(src, ctx);
|
||||
vmap_[src] = i;
|
||||
builder.Insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
void selection::run(ir::module &src, Module &dst){
|
||||
vmap_.clear();
|
||||
@@ -166,14 +305,14 @@ void selection::run(ir::module &src, Module &dst){
|
||||
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn);
|
||||
bmap_[block] = dst_block;
|
||||
}
|
||||
// create grids
|
||||
dst_builder.SetInsertPoint(bmap_[fn->blocks()[0]]);
|
||||
init_grids(fn, dst_builder);
|
||||
// iterate through block
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
dst_builder.SetInsertPoint(bmap_[block]);
|
||||
for(ir::instruction *inst: block->get_inst_list()) {
|
||||
Instruction *dst_inst = llvm_inst(inst, dst_ctx);
|
||||
vmap_[inst] = dst_inst;
|
||||
dst_builder.Insert(dst_inst);
|
||||
}
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
lower_instruction(i, dst_builder);
|
||||
}
|
||||
// add phi operands
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
|
@@ -23,6 +23,7 @@ void tune::init_c_phi(ir::instruction *v) {
|
||||
for(unsigned k = 0; k < phi->get_type()->get_tile_shapes().size(); k++)
|
||||
if(dependencies_.find({op, k}) != dependencies_.end()
|
||||
|| dependencies_.find({phi, k}) != dependencies_.end()){
|
||||
std::cout << typeid(*op).name() << std::endl;
|
||||
add_constraint({phi, k}, {op, k});
|
||||
}
|
||||
}
|
||||
@@ -32,11 +33,12 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
if(dynamic_cast<ir::reshape_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++)
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(shapes[i] == 1)
|
||||
static_params_.insert({{v, i}, 1});
|
||||
else
|
||||
add_constraint({v, i}, {op, current++});
|
||||
}
|
||||
}
|
||||
else if(dynamic_cast<ir::splat_inst*>(v)){
|
||||
|
||||
@@ -58,8 +60,9 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
}
|
||||
else if(dynamic_cast<ir::user*>(v)){
|
||||
for(unsigned i = 0; i < shapes.size(); i ++)
|
||||
for(ir::value* op: v->ops())
|
||||
for(ir::value* op: v->ops()){
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,8 +85,8 @@ void tune::connected_components(node_t x, const std::vector<unsigned *> vals, st
|
||||
}
|
||||
}
|
||||
|
||||
void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
|
||||
result.clear();
|
||||
std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
||||
std::vector<unsigned *> result;
|
||||
std::set<unsigned*> seen;
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -92,6 +95,11 @@ void tune::get_params(ir::module &mod, std::vector<unsigned *> &result) {
|
||||
if(seen.insert(x.second).second && *x.second == 0){
|
||||
result.push_back(x.second);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, unsigned*> tune::get_params(ir::instruction* i) {
|
||||
return params_.at(i);
|
||||
}
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
@@ -117,9 +125,10 @@ void tune::run(ir::module &mod) {
|
||||
}
|
||||
}
|
||||
|
||||
bool tune::check_constraints(ir::module &mod, std::map<ir::value *, std::vector<std::string>> &errors) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
/* grids */
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<unsigned*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
unsigned result = 0;
|
||||
for(unsigned shape: v->get_type()->get_tile_shapes()) {
|
||||
@@ -127,8 +136,7 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
}
|
||||
return result;
|
||||
};
|
||||
using std::to_string;
|
||||
std::map<unsigned*, ir::instruction*> references;
|
||||
// bind references
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
@@ -137,16 +145,25 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
if(*param.second == 1)
|
||||
continue;
|
||||
ir::instruction *&r = references[param.second];
|
||||
if(!r && get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
|
||||
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
|
||||
r = i;
|
||||
}
|
||||
}
|
||||
|
||||
// extract unique instructions in order
|
||||
std::vector<ir::instruction*> grids;
|
||||
// create grid
|
||||
for(auto &ref: references)
|
||||
if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
|
||||
grids.push_back(ref.second);
|
||||
}
|
||||
|
||||
|
||||
bool tune::check_constraints(ir::module &mod, std::map<ir::value *, std::vector<std::string>> &errors) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
using std::to_string;
|
||||
|
||||
// initialize grids
|
||||
std::map<unsigned*, ir::instruction*> references;
|
||||
std::vector<ir::instruction*> grids;
|
||||
create_grids(grids, references, fn);
|
||||
|
||||
// number of warps
|
||||
int num_warps = 1;
|
||||
|
@@ -16,7 +16,6 @@ builder::builder(context &ctx):
|
||||
//===----------------------------------------------------------------------===//
|
||||
// utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void builder::set_insert_point(basic_block::iterator it){
|
||||
block_ = (*it)->get_parent();
|
||||
insert_point_ = it;
|
||||
|
@@ -85,18 +85,24 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
ir::value *pred = get_value(name, preds.front());
|
||||
result = make_phi(pred->get_type(), 1, block);
|
||||
set_value(name, block, result);
|
||||
add_phi_operands(name, (ir::phi_node*&)result);
|
||||
result = add_phi_operands(name, (ir::phi_node*&)result);
|
||||
}
|
||||
set_value(name, block, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
||||
ir::basic_block* save_block = builder_.get_insert_block();
|
||||
ir::basic_block::iterator save_pt = builder_.get_insert_point();
|
||||
val_key_t key(name, block);
|
||||
if(values_.find(key) != values_.end()){
|
||||
return values_.at(key);
|
||||
}
|
||||
return get_value_recursive(name, block);
|
||||
ir::value *result = get_value_recursive(name, block);
|
||||
builder_.set_insert_point(save_block);
|
||||
if(save_pt != save_block->end())
|
||||
builder_.set_insert_point(save_pt);
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name) {
|
||||
|
Reference in New Issue
Block a user