[code generation]: more progress for instruction selection

This commit is contained in:
Philippe Tillet
2019-01-26 02:05:56 -05:00
parent e2de27dfe2
commit e522b06be2
8 changed files with 91 additions and 34 deletions

View File

@@ -32,10 +32,7 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
int32 k;\ int32 k;\
fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\ fp32* pa[32, 8] = a + rx[:, newaxis] + rka[newaxis, :]*M;\
fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\ fp32* pb[32, 8] = b + ry[:, newaxis] + rkb[newaxis, :]*K;\
fp32* pc[32, 32];\ fp32* pc[32, 32] = c + rx[:, newaxis] + ry[newaxis, :]*M;\
for(k = 0; k < K; k = k + 8){\
}\
pc = c + rx[:, newaxis] + ry[newaxis, :];\
*pc = C;\ *pc = C;\
}\ }\
"; ";
@@ -59,15 +56,13 @@ int main() {
tune.run(module); tune.run(module);
std::vector<unsigned> params = { std::vector<unsigned> params = {
// asm // asm
2, 16, 1, 2, 8, 1,
// bsn // bsn
2, 16, 1, 4, 4, 1,
// pa // pa
1, 2, 4, 2, 4, 1,
// pb // pb
1, 2, 4, 1, 8, 1,
// c
2, 16, 1, 1, 2, 4
}; };
std::map<tdl::ir::value*, std::vector<std::string>> errors; std::map<tdl::ir::value*, std::vector<std::string>> errors;
unsigned i = 0; unsigned i = 0;
@@ -75,11 +70,11 @@ int main() {
for(unsigned *x: tune.get_params(module)) for(unsigned *x: tune.get_params(module))
*x = params[i++]; *x = params[i++];
tune.check_constraints(module, errors); tune.check_constraints(module, errors);
// std::cout << "errors: " << errors.size() << std::endl; std::cout << "errors: " << errors.size() << std::endl;
// for(auto &x: errors){ for(auto &x: errors){
// for(auto &e: x.second) for(auto &e: x.second)
// std::cout << e << std::endl; std::cout << e << std::endl;
// } }
shared.run(module); shared.run(module);
liveness.run(module); liveness.run(module);
allocation.run(); allocation.run();

View File

@@ -114,9 +114,15 @@ public:
const slice_enum_t type_; const slice_enum_t type_;
}; };
class named_expression;
class expression: public node{ class expression: public node{
public: public:
virtual ir::value* codegen(ir::module *) const = 0; virtual ir::value* codegen(ir::module *) const = 0;
named_expression *lvalue() const { return lvalue_; }
protected:
named_expression *lvalue_;
}; };
class postfix_expression: public expression{ class postfix_expression: public expression{
@@ -163,10 +169,9 @@ private:
const list<slice*>* slices_; const list<slice*>* slices_;
}; };
class named_expression: public expression { class named_expression: public expression {
public: public:
named_expression(node *id): id_((const identifier*)id) { } named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; }
const identifier *id() const { return id_; } const identifier *id() const { return id_; }
ir::value* codegen(ir::module * mod) const; ir::value* codegen(ir::module * mod) const;
@@ -228,7 +233,10 @@ private:
public: public:
unary_operator(UNARY_OP_T op, node *arg) unary_operator(UNARY_OP_T op, node *arg)
: op_(op), : op_(op),
arg_((expression*)arg) { } arg_((expression*)arg) {
if(op == DEREF)
this->lvalue_ = arg_->lvalue();
}
UNARY_OP_T get_op() const { return op_; } UNARY_OP_T get_op() const { return op_; }
ir::value* codegen(ir::module *mod) const; ir::value* codegen(ir::module *mod) const;

View File

@@ -32,9 +32,10 @@ protected:
typedef std::vector<unsigned> shapes_t; typedef std::vector<unsigned> shapes_t;
public: public:
tile(const shapes_t &shapes): shapes_(shapes){ } tile(llvm::Type *ty, const shapes_t &shapes): shapes_(shapes){ }
private: private:
llvm::Type *ty_;
shapes_t shapes_; shapes_t shapes_;
}; };
@@ -46,13 +47,20 @@ public:
class distributed_tile: public tile{ class distributed_tile: public tile{
typedef std::vector<distributed_axis> axes_t; typedef std::vector<distributed_axis> axes_t;
typedef std::vector<llvm::Value*> indices_t;
typedef std::map<indices_t, unsigned> indices_map_t;
typedef std::vector<llvm::Value*> values_t;
private:
void init_indices();
public: public:
distributed_tile(const shapes_t& shapes, const axes_t &axes) distributed_tile(llvm::Type *ty, const shapes_t& shapes, const axes_t &axes);
: tile(shapes), axes_(axes) {}
private: private:
axes_t axes_; axes_t axes_;
indices_map_t indices_;
values_t values_;
}; };

View File

@@ -26,6 +26,8 @@ public:
const basic_block *get_parent() const { return parent_; } const basic_block *get_parent() const { return parent_; }
basic_block *get_parent() { return parent_; } basic_block *get_parent() { return parent_; }
void erase_from_parent(); void erase_from_parent();
// helpers
bool has_tile_result_or_op();
private: private:
basic_block *parent_; basic_block *parent_;

View File

@@ -523,7 +523,8 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
mod->set_value(x->id()->name(), rvalue); mod->set_value(x->id()->name(), rvalue);
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){ else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
assert(x->get_op()==DEREF); assert(x->get_op()==DEREF);
ir::value *ptr = x->codegen(mod); assert(x->lvalue());
ir::value *ptr = x->lvalue()->codegen(mod);
mod->get_builder().create_store(ptr, rvalue); mod->get_builder().create_store(ptr, rvalue);
} }
return rvalue; return rvalue;

View File

@@ -13,6 +13,33 @@ namespace codegen{
using namespace llvm; using namespace llvm;
/* Distributed Tile */
/* Enumerate every combination of per-axis index values (axis 0 varies
 * fastest) and record, for each combination, its position in the
 * enumeration order inside `indices_`.
 *
 * - If any axis has no values, the product is empty and nothing is recorded.
 * - If there are no axes at all, the single empty combination is recorded
 *   with position 0. */
void distributed_tile::init_indices() {
  const size_t rank = axes_.size();
  // An empty axis makes the whole product empty; bail out before indexing
  // into its (non-existent) values.
  for(size_t d = 0; d < rank; d++)
    if(axes_[d].values.empty())
      return;
  // Odometer over the axes: id[d] is the current position along axis d.
  std::vector<size_t> id(rank, 0);
  while(true) {
    indices_t current;
    current.reserve(rank);
    for(size_t d = 0; d < rank; d++)
      current.push_back(axes_[d].values[id[d]]);
    // Read the size *before* operator[] may insert the key: prior to C++17
    // the relative evaluation order of the two sides of `m[k] = m.size()`
    // is unspecified, which could yield an off-by-one position.
    const unsigned linear = static_cast<unsigned>(indices_.size());
    indices_[current] = linear;
    // Advance the odometer, carrying into the next axis on wrap-around.
    size_t k = 0;
    while(k < rank && ++id[k] == axes_[k].values.size())
      id[k++] = 0;
    // Carry ran past the last axis: every combination has been visited.
    if(k == rank)
      return;
  }
}
/* Build a distributed tile of element type `ty`.
 *
 * Forwards `ty` and `shapes` to the `tile` base, keeps a copy of `axes`,
 * enumerates all index combinations via init_indices(), and creates one
 * undefined LLVM value of type `ty` per recorded combination. */
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes)
    : tile(ty, shapes), axes_(axes) {
  init_indices();
  // Use the constructor argument rather than the base-class `ty_` member:
  // that member is private to `tile`, and the `tile` constructor does not
  // store its `ty` argument, so reading it here would be invalid.
  values_.assign(indices_.size(), UndefValue::get(ty));
}
/* convert ir::type to Type */ /* convert ir::type to Type */
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
@@ -186,7 +213,7 @@ void selection::init_axes(ir::instruction *instr, IRBuilder<> &builder, Value *u
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k]; unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset)); idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset));
} }
axes[k] = {idx_list}; axes[k] = distributed_axis{idx_list};
} }
// Store axes // Store axes
axes_[instr] = axes; axes_[instr] = axes;
@@ -230,6 +257,7 @@ void selection::create_grids(std::vector<ir::instruction*> &grids,
void selection::init_grids(ir::function *fn, IRBuilder<> &builder){ void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
// fetch linear ID // fetch linear ID
Module *mod = builder.GetInsertBlock()->getParent()->getParent(); Module *mod = builder.GetInsertBlock()->getParent()->getParent();
LLVMContext &ctx = builder.getContext();
Function *get_thread_id = Intrinsic::getDeclaration(mod, Intrinsic::nvvm_read_ptx_sreg_tid_x); Function *get_thread_id = Intrinsic::getDeclaration(mod, Intrinsic::nvvm_read_ptx_sreg_tid_x);
Value *warp_size = builder.getInt32(32); Value *warp_size = builder.getInt32(32);
Value *u_thread_id = builder.CreateCall(get_thread_id, {}); Value *u_thread_id = builder.CreateCall(get_thread_id, {});
@@ -248,9 +276,10 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
continue; continue;
bool is_shared = dynamic_cast<ir::copy_to_shared_inst*>(i); bool is_shared = dynamic_cast<ir::copy_to_shared_inst*>(i);
const auto& shapes = i->get_type()->get_tile_shapes(); const auto& shapes = i->get_type()->get_tile_shapes();
Type* ty = llvm_type(i->get_type(), ctx);
// create shared tile // create shared tile
if(is_shared){ if(is_shared){
tmap_.insert({i, new shared_tile(shapes)}); tmap_.insert({i, new shared_tile(ty, shapes)});
} }
// create distributed tile // create distributed tile
else { else {
@@ -264,20 +293,18 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder){
else else
axes[d].values = {builder.getInt32(0)}; axes[d].values = {builder.getInt32(0)};
} }
tmap_.insert({i, new distributed_tile(shapes, axes)}); tmap_.insert({i, new distributed_tile(ty, shapes, axes)});
} }
} }
} }
void selection::lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder) { void selection::lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder) {
std::cout << typeid(*src).name() << std::endl;
} }
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) { void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
LLVMContext &ctx = builder.getContext(); LLVMContext &ctx = builder.getContext();
std::cout << typeid(*src).name() << " " << src->get_type()->get_type_id() << std::endl; if(src->has_tile_result_or_op()) {
if(src->get_type()->is_tile_ty()) {
std::cout << "tile instruction" << std::endl;
lower_tile_instruction(src, builder); lower_tile_instruction(src, builder);
} }
else { else {

View File

@@ -29,7 +29,13 @@ void tune::init_c_phi(ir::instruction *v) {
} }
void tune::init_c_graph(ir::instruction *v) { void tune::init_c_graph(ir::instruction *v) {
const auto& shapes = v->get_type()->get_tile_shapes(); // Reference shape
std::vector<unsigned> shapes;
if(auto *store = dynamic_cast<ir::store_inst*>(v))
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else
shapes = v->get_type()->get_tile_shapes();
// Reshape
if(dynamic_cast<ir::reshape_inst*>(v)){ if(dynamic_cast<ir::reshape_inst*>(v)){
ir::value *op = v->get_operand(0); ir::value *op = v->get_operand(0);
unsigned current = 0; unsigned current = 0;
@@ -40,9 +46,11 @@ void tune::init_c_graph(ir::instruction *v) {
add_constraint({v, i}, {op, current++}); add_constraint({v, i}, {op, current++});
} }
} }
// Splat
else if(dynamic_cast<ir::splat_inst*>(v)){ else if(dynamic_cast<ir::splat_inst*>(v)){
} }
// Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){ else if(dynamic_cast<ir::broadcast_inst*>(v)){
ir::value *op = v->get_operand(0); ir::value *op = v->get_operand(0);
ir::type *op_ty = op->get_type(); ir::type *op_ty = op->get_type();
@@ -51,13 +59,14 @@ void tune::init_c_graph(ir::instruction *v) {
if(op_shapes[i] == shapes[i] && v != op) if(op_shapes[i] == shapes[i] && v != op)
add_constraint({v, i}, {op, i}); add_constraint({v, i}, {op, i});
} }
} }
// Matrix multiplication
else if(dynamic_cast<ir::matmul_inst*>(v)){ else if(dynamic_cast<ir::matmul_inst*>(v)){
ir::value *D = v->get_operand(2); ir::value *D = v->get_operand(2);
add_constraint({v, 0}, {D, 0}); add_constraint({v, 0}, {D, 0});
add_constraint({v, 1}, {D, 1}); add_constraint({v, 1}, {D, 1});
} }
// Element-wise
else if(dynamic_cast<ir::user*>(v)){ else if(dynamic_cast<ir::user*>(v)){
for(unsigned i = 0; i < shapes.size(); i ++) for(unsigned i = 0; i < shapes.size(); i ++)
for(ir::value* op: v->ops()){ for(ir::value* op: v->ops()){
@@ -102,18 +111,19 @@ std::map<std::string, unsigned*> tune::get_params(ir::instruction* i) {
return params_.at(i); return params_.at(i);
} }
void tune::run(ir::module &mod) { void tune::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){ for(ir::function *fn: mod.get_function_list()){
// Build constraints graph // Build constraints graph
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()) for(ir::instruction *i : block->get_inst_list())
if(i->get_type()->is_tile_ty()){ if(i->has_tile_result_or_op()){
init_c_graph(i); init_c_graph(i);
} }
// Build phi constraints // Build phi constraints
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()) for(ir::instruction *i : block->get_inst_list())
if(i->get_type()->is_tile_ty()) if(i->has_tile_result_or_op())
init_c_phi(i); init_c_phi(i);
// Layout parameters // Layout parameters
while(!nodes_.empty()){ while(!nodes_.empty()){

View File

@@ -25,6 +25,12 @@ void instruction::erase_from_parent() {
parent_->erase(this); parent_->erase(this);
} }
bool instruction::has_tile_result_or_op() {
bool result = get_type()->is_tile_ty();
for(ir::value *v: ops())
result |= v->get_type()->is_tile_ty();
return result;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// phi_node classes // phi_node classes