diff --git a/include/codegen/selection.h b/include/codegen/selection.h index 73c72f120..729c36adb 100644 --- a/include/codegen/selection.h +++ b/include/codegen/selection.h @@ -70,23 +70,40 @@ private: void init_indices(); public: - distributed_tile(llvm::Type *ty, const shapes_t& shapes, const axes_t &axes, llvm::IRBuilder<> &builder); - void set_vectorized_iteration() { vectorized_ = true; } - void unset_vectorized_iteration() { vectorized_ = false; } - void set_value(indices_t idx, llvm::Value *v); + distributed_tile(llvm::Type *ty, const shapes_t& shapes, const axes_t &axes); + virtual void for_each(std::function fn) = 0; + +protected: + axes_t axes_; + indices_map_t indices_; + values_t values_; +}; + +class serialized_distributed_tile: public distributed_tile { +public: + using distributed_tile::distributed_tile; + +public: + void set_value(indices_t, llvm::Value *); + llvm::Value* get_value(indices_t idx); + void for_each(std::function fn); +}; + +class vectorized_distributed_tile: public distributed_tile { +private: + llvm::Type *make_vector_ty(llvm::Type *ty, size_t vector_size); + +public: + vectorized_distributed_tile(llvm::Type *ty, const shapes_t& shapes, const axes_t &axes, llvm::IRBuilder<> &builder); + void set_value(indices_t, llvm::Value *); llvm::Value* get_value(indices_t idx); void for_each(std::function fn); private: - axes_t axes_; - indices_map_t indices_; - values_t values_; - size_t vector_size_; llvm::IRBuilder<> &builder_; - bool vectorized_; + size_t vector_size_; }; - class selection{ typedef std::map vmap_t; typedef std::map tmap_t; diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 501e25f49..2510f89e2 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -34,17 +34,43 @@ void distributed_tile::init_indices() { } } -distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder) - : tile(ty, shapes), axes_(axes), builder_(builder), vectorized_(true) { + +distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes) + : tile(ty, shapes), axes_(axes) { init_indices(); for(size_t i = 0; i < indices_.size(); i++) - values_.push_back(UndefValue::get(ty)); - // vectorization - vector_size_ = 1; - if(ty->isVectorTy()) - vector_size_ = ty->getVectorNumElements(); + values_.push_back(UndefValue::get(ty_)); } +/* Serialized distributed tile */ +void serialized_distributed_tile::set_value(indices_t idx, Value *v) { + values_[indices_[idx]] = v; +} + +void serialized_distributed_tile::get_value(indices_t idx) { + return values_[indices_[idx]]; +} + +void serialized_distributed_tile::for_each(std::function fn) { + for(auto &idx: indices_) + fn(idx.first); +} + +/* Vectorized distributed tile */ +llvm::Type *vectorized_distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) { + if(vector_size == 1) + return ty; + return VectorType::get(ty, vector_size); +} + +vectorized_distributed_tile::vectorized_distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder) + : distributed_tile(make_vector_ty(ty, axes[0].contiguous), shapes), axes_(axes), builder_(builder) { + vector_size_ = 1; + if(ty_->isVectorTy()) + vector_size_ = ty_->getVectorNumElements(); +} + + void distributed_tile::set_value(indices_t idx, Value *v) { unsigned value_idx = indices_[idx]; Value *&result = values_[value_idx/vector_size_*vector_size_]; @@ -54,6 +80,7 @@ void distributed_tile::set_value(indices_t idx, Value *v) { } // insert scalar in vector else { + std::cout << v->getType()->getScalarType()->getTypeID() << " " << result->getType()->getScalarType()->getTypeID() << std::endl; assert(vector_size_==1 || result->getType()->isVectorTy()); assert(v->getType()->getScalarType() == result->getType()->getScalarType()); result = builder_.CreateInsertElement(result, v, value_idx % vector_size_); @@ -63,7 +90,7 @@ void distributed_tile::set_value(indices_t idx, Value *v) { Value* distributed_tile::get_value(indices_t idx) { unsigned value_idx = indices_[idx]; Value *&result = values_[value_idx/vector_size_*vector_size_]; - if(vectorized_ || vector_size_ == 1) { + if(vectorize_ || vector_size_ == 1) { assert(value_idx % vector_size_ == 0); return result; } @@ -77,7 +104,7 @@ Value* distributed_tile::get_value(indices_t idx) { void distributed_tile::for_each(std::function fn) { for(auto &idx: indices_) { - if(!vectorized_ || (idx.second % vector_size_ == 0)) + if(!vectorize_ || (idx.second % vector_size_ == 0)) fn(idx.first); } } @@ -142,6 +169,8 @@ shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRB void shared_tile::set_value(indices_t idx, Value *value) { Value *ptr = builder_.CreateGEP(ptr_, shared_offset(idx)); + unsigned addr_space = ptr->getType()->getPointerAddressSpace(); + ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); builder_.CreateStore(value, ptr); } @@ -159,13 +188,6 @@ Value* shared_tile::get_value(indices_t idx) { return builder_.CreateLoad(ptr); } -/* helper to make vector type */ -llvm::Type *selection::make_vector_ty(llvm::Type *ty, size_t vector_size) { - if(vector_size == 1) - return ty; - return VectorType::get(ty, vector_size); -} - /* convert ir::type to Type */ Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) { // function @@ -422,16 +444,15 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, axes[d].values = {builder.getInt32(0)}; } } - distributed_tile *T = new distributed_tile(make_vector_ty(ty, axes[0].contiguous), shapes, axes, builder); + bool vectorize = dynamic_cast(v); + distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize); tmap_.insert({v, T}); // constant range if(dynamic_cast(v)){ - T->unset_vectorized_iteration(); T->for_each([&](indices_t idx){ assert(idx.size() == 1); T->set_value(idx, idx[0]); }); - T->set_vectorized_iteration(); } } @@ -498,7 +519,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & else if(dynamic_cast(ins)) { ir::value* in = ins->get_operand(0); distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); - in_tile->unset_vectorized_iteration(); result->for_each([&](indices_t out_idx){ indices_t in_idx; for(size_t k = 0; k < shapes.size(); k++){ @@ -507,7 +527,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & } result->set_value(out_idx, in_tile->get_value(in_idx)); }); - in_tile->set_vectorized_iteration(); } // splat else if(dynamic_cast(ins)) {