[codegen][selection] adding support for reduction along arbitrary axis

This commit is contained in:
Philippe Tillet
2019-08-02 20:56:40 -07:00
parent d9945692a9
commit 6be532c6a2
5 changed files with 73 additions and 33 deletions

View File

@@ -43,6 +43,7 @@ public:
virtual void set_value(indices_t idx, llvm::Value *v) = 0;
virtual llvm::Value* get_value(indices_t idx) = 0;
llvm::Type *get_ty() const { return ty_; }
shapes_t get_shapes() const { return shapes_; }
protected:
llvm::Type *ty_;
@@ -54,7 +55,6 @@ private:
void extract_constant(llvm::Value *arg, llvm::Value *&non_cst, llvm::Value *&cst);
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
llvm::Value* shared_offset(indices_t idx);
public:
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
@@ -65,6 +65,7 @@ public:
llvm::Value* get_value(indices_t idx);
llvm::Value* get_pointer() { return ptr_; }
llvm::Value* get_offset() { return offset_; }
static llvm::Value* shared_offset(llvm::IRBuilder<>& builder, const shapes_t& shapes, indices_t idx);
private:
llvm::Value *ptr_;

View File

@@ -131,11 +131,11 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
}
Value* shared_tile::shared_offset(indices_t idx) {
Value *result = builder_.getInt32(0);
result = builder_.CreateAdd(result, idx[0]);
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[0]);
for(size_t i = 1; i < idx.size(); i++)
result = builder_.CreateAdd(result, builder_.CreateMul(idx[i], builder_.getInt32(shapes_[i-1])));
result = builder.CreateAdd(result, builder.CreateMul(idx[i], builder.getInt32(shapes[i-1])));
return result;
}
@@ -145,7 +145,7 @@ shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRB
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(idx));
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
@@ -176,7 +176,7 @@ Value* shared_tile::get_value(indices_t idx) {
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
@@ -184,7 +184,7 @@ Value* shared_tile::get_value(indices_t idx) {
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(cst_idx);
Value *offset = shared_offset(builder_, shapes_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
@@ -824,23 +824,39 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
return;
}
if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
Value *partial = nullptr;
std::map<indices_t, Value*> partial;
distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
size_t axis = 0;
unsigned num_warps = params_->get_num_threads() / 32;
std::vector<unsigned> shapes = op->get_shapes();
shapes.erase(shapes.begin() + axis);
if(shapes.empty())
shapes.push_back(1);
// reduce within thread
op->for_each([&](indices_t idx){
indices_t pidx = idx;
pidx.erase(pidx.begin() + axis);
if(pidx.empty())
pidx.push_back(builder.getInt32(0));
Value *current = op->get_value(idx);
if(partial == nullptr)
partial = current;
if(partial.find(pidx) == partial.end())
partial[pidx] = current;
else
partial = builder.CreateFAdd(partial, current);
partial[pidx] = builder.CreateFAdd(partial[pidx], current);
});
// reduce within warp
Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
for (int i = 16; i > 0; i >>= 1){
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), partial,
builder.getInt32(i), builder.getInt32(0x1f)});
partial = builder.CreateFAdd(partial, rhs);
for (int i = 16; i > 0; i >>= 1)
for(auto& x: partial)
{
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), x.second,
builder.getInt32(i),
builder.getInt32(0x1f)});
x.second = builder.CreateFAdd(x.second, rhs);
}
// reduce within block
Value *tid = tgt_->get_local_id(module, builder, 0);
BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
@@ -853,10 +869,15 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, warp_id);
builder.CreateStore(partial, write_ptr);
for(auto& x: partial){
Value *offset = shared_tile::shared_offset(builder, shapes, x.first);
offset = builder.CreateAdd(offset, builder.CreateMul(warp_id, builder.getInt32(shapes[0])));
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, offset);
builder.CreateStore(x.second, write_ptr);
}
builder.CreateBr(partial_reduce_done);
builder.SetInsertPoint(partial_reduce_done);
// Final reduction with the first warp
tgt_->add_barrier(module, builder);
BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
@@ -865,11 +886,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
final_reduce_do, final_reduce_done);
builder.SetInsertPoint(final_reduce_do);
Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
Value *result = builder.CreateLoad(read_ptr);
BasicBlock *read_shmem_do = BasicBlock::Create(ctx, "read_shmem_do", fn);
BasicBlock *read_shmem_done = BasicBlock::Create(ctx, "read_shmem_done", fn);
builder.CreateCondBr(builder.CreateICmpULT(id_in_warp, builder.getInt32(num_warps)),
read_shmem_do, read_shmem_done);
builder.SetInsertPoint(read_shmem_do);
Value *loaded= builder.CreateLoad(read_ptr);
builder.CreateBr(read_shmem_done);
builder.SetInsertPoint(read_shmem_done);
Value *result = builder.CreatePHI(loaded->getType(), 2);
((PHINode*)result)->addIncoming(ConstantFP::get(loaded->getType(), (double)0), final_reduce_do);
((PHINode*)result)->addIncoming(loaded, read_shmem_do);
for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
Value *rhs = builder.CreateCall(shfl, {result, builder.getInt32(i),
builder.getInt32(0x1f), builder.getInt32(0xffffffff)});
builder.CreateFAdd(result, rhs);
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), result,
builder.getInt32(i), builder.getInt32(0x1f)});
result = builder.CreateFAdd(result, rhs);
}
builder.CreateStore(result, read_ptr);
builder.CreateBr(final_reduce_done);

View File

@@ -42,17 +42,25 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
}
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
if(dynamic_cast<ir::reduce_inst*>(x))
return 32;
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
if(dynamic_cast<ir::reduce_inst*>(x)){
size_t shape = 1;
if(x->get_type()->is_tile_ty()){
auto shapes = x->get_type()->get_tile_shapes();
for(auto x: shapes)
shape *= x->get_value();
}
size_t n_warps = params_->get_num_threads() / 32;
return shape * num_bytes * n_warps;
}
unsigned pad = is_ld_padded(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
result += pad * result / ld;
num_bytes += pad * num_bytes / ld;
}
if(buffer_info_->is_double(x))
result *= 2;
return result;
num_bytes *= 2;
return num_bytes;
}
void shmem_allocation::run(){

View File

@@ -227,7 +227,7 @@ void tune::run(ir::module &mod) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
@@ -381,7 +381,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
errors[i].push_back("HMMA must have only 4 fragments per warp");
}
int num_threads = get_req_num_threads(i);
if(num_threads % 32 != 0)
if(num_threads % 64 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");

View File

@@ -219,9 +219,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf;
best.params = params;
}
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
};