attempting vectorization

This commit is contained in:
Philippe Tillet
2019-02-10 18:29:25 -05:00
parent 4a0736ce20
commit 3d07e909c6
3 changed files with 69 additions and 14 deletions

View File

@@ -34,24 +34,52 @@ void distributed_tile::init_indices() {
}
}
/* Construct a distributed tile of LLVM type `ty` laid out over `shapes` / `axes`.
 *
 * ty      : type of each stored value; when it is a vector type, its element
 *           count becomes the tile's vector width.
 * shapes  : forwarded to the `tile` base.
 * axes    : kept in axes_ (consumed by init_indices()).
 * builder : IR builder kept by reference; set_value()/get_value() use it to
 *           emit insertelement / extractelement instructions.
 *
 * Only the builder-taking signature is kept: it is the one that initializes
 * builder_ and vectorized_, which set_value(), get_value() and for_each() read.
 * Vectorized iteration starts enabled (vectorized_ = true).
 */
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder)
    : tile(ty, shapes), axes_(axes), builder_(builder), vectorized_(true) {
  init_indices();
  // one undef placeholder of type `ty` per entry of indices_
  for(size_t i = 0; i < indices_.size(); i++)
    values_.push_back(UndefValue::get(ty));
  // vectorization: width is the element count of `ty` if it is a vector type, 1 otherwise
  vector_size_ = 1;
  if(ty->isVectorTy())
    vector_size_ = ty->getVectorNumElements();
}
/* Store `v` as the value of the element addressed by `idx`.
 *
 * The element's linear index is looked up in indices_ and rounded down to a
 * multiple of vector_size_ to find the backing slot in values_.
 *  - If `v` has the same type as the slot, it replaces the slot; the index is
 *    then expected (assert) to be the first of its group.
 *  - Otherwise `v` is inserted into the slot at lane `index % vector_size_`
 *    through builder_; its scalar type must match the slot's (assert).
 *
 * The earlier unconditional `values_[indices_[idx]] = v;` is dropped on
 * purpose: it overwrote the slot before the type comparison, so the insert
 * path could never run for the first lane of a group.
 */
void distributed_tile::set_value(indices_t idx, Value *v) {
  unsigned value_idx = indices_[idx];
  // slot shared by all lanes of the group containing value_idx
  Value *&result = values_[value_idx/vector_size_*vector_size_];
  if(v->getType() == result->getType()) {
    assert(value_idx % vector_size_ == 0);
    result = v;
  }
  // insert scalar in vector
  else {
    assert(vector_size_==1 || result->getType()->isVectorTy());
    assert(v->getType()->getScalarType() == result->getType()->getScalarType());
    result = builder_.CreateInsertElement(result, v, value_idx % vector_size_);
  }
}
/* Return the value of the element addressed by `idx`.
 *
 * The element's linear index is looked up in indices_ and rounded down to a
 * multiple of vector_size_ to find the backing slot in values_.
 *  - With vectorized iteration enabled, or when vector_size_ is 1, the slot
 *    itself is returned; the index is expected (assert) to be the first of
 *    its group.
 *  - Otherwise lane `index % vector_size_` is extracted from the slot through
 *    builder_.
 *
 * The earlier unconditional `return values_[indices_[idx]];` and the trailing
 * `return result;` are dropped: the former made this logic unreachable, the
 * latter was itself unreachable after the if/else.
 */
Value* distributed_tile::get_value(indices_t idx) {
  unsigned value_idx = indices_[idx];
  // slot shared by all lanes of the group containing value_idx (read-only here)
  Value *result = values_[value_idx/vector_size_*vector_size_];
  if(vectorized_ || vector_size_ == 1) {
    assert(value_idx % vector_size_ == 0);
    return result;
  }
  // extract scalar from vector
  assert(result->getType()->isVectorTy());
  return builder_.CreateExtractElement(result, value_idx % vector_size_);
}
/* Invoke `fn` once for each visited entry of indices_, passing the entry's key.
 *
 * With vectorized iteration enabled, only entries whose linear index
 * (the mapped value) is a multiple of vector_size_ are visited, i.e. one call
 * per vector-wide group. With it disabled, every entry is visited.
 *
 * The earlier unconditional loop is dropped: left in place it called `fn` a
 * second time for the entries the filtered loop also visits.
 */
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
  for(auto &idx: indices_) {
    if(!vectorized_ || (idx.second % vector_size_ == 0))
      fn(idx.first);
  }
}
/* Shared Tile */
@@ -121,12 +149,23 @@ Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
if(base_ptr == nullptr)
if(base_ptr == nullptr){
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
// Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vec_);
// Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerElementType());
// base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
}
Value *ptr = builder_.CreateGEP(base_ptr, shared_offset(cst_idx));
return builder_.CreateLoad(ptr);
}
/* Return `ty` widened to `vector_size` lanes; a lane count of 1 yields `ty` unchanged. */
llvm::Type *selection::make_vector_ty(llvm::Type *ty, size_t vector_size) {
  if(vector_size != 1)
    return VectorType::get(ty, vector_size);
  // single lane: no vector wrapper
  return ty;
}
/* convert ir::type to Type */
Type *selection::llvm_type(ir::type *ty, LLVMContext &ctx) {
// function
@@ -299,7 +338,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[params_->get_param(v, "p0.d" + str_k)] = distributed_axis{idx_list};
axes_[params_->get_param(v, "p0.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
}
}
@@ -378,17 +417,22 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
axes[d] = axes_.at(x);
}
else
else{
axes[d].contiguous = 1;
axes[d].values = {builder.getInt32(0)};
}
}
distributed_tile *T = new distributed_tile(ty, shapes, axes);
distributed_tile *T = new distributed_tile(make_vector_ty(ty, axes[0].contiguous), shapes, axes, builder);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v))
if(dynamic_cast<ir::constant*>(v)){
T->unset_vectorized_iteration();
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
});
T->set_vectorized_iteration();
}
}
}
@@ -454,6 +498,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
in_tile->unset_vectorized_iteration();
result->for_each([&](indices_t out_idx){
indices_t in_idx;
for(size_t k = 0; k < shapes.size(); k++){
@@ -462,6 +507,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
}
result->set_value(out_idx, in_tile->get_value(in_idx));
});
in_tile->set_vectorized_iteration();
}
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {