[codegen][selection] adding support for reduction along arbitrary axis

This commit is contained in:
Philippe Tillet
2019-08-02 20:56:40 -07:00
parent d9945692a9
commit 6be532c6a2
5 changed files with 73 additions and 33 deletions

View File

@@ -131,11 +131,11 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
}
// Builds the IR for a linear offset computed from a multi-dimensional index:
//   idx[0] + sum over i >= 1 of idx[i] * shapes[i-1]
// The offset is returned as an llvm Value; idx[0] is read unconditionally, so idx must be non-empty.
//
// NOTE(review): this span looks like a diff hunk whose +/- markers were stripped, so each
// statement shows up twice -- presumably the removed member form (builder_ / shapes_) followed
// by the added form that takes the builder and the shapes as explicit parameters. Confirm
// against the real source before treating either variant as current.
// NOTE(review): each idx[i] is scaled by shapes[i-1] alone, not by a running product of the
// preceding extents; that matches a row-major style linearization only up to rank 2 --
// confirm the intended layout for higher ranks.
Value* shared_tile::shared_offset(indices_t idx) {
Value *result = builder_.getInt32(0);
result = builder_.CreateAdd(result, idx[0]);
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[0]);
// Accumulate the contribution of every index after the first.
for(size_t i = 1; i < idx.size(); i++)
result = builder_.CreateAdd(result, builder_.CreateMul(idx[i], builder_.getInt32(shapes_[i-1])));
result = builder.CreateAdd(result, builder.CreateMul(idx[i], builder.getInt32(shapes[i-1])));
return result;
}
@@ -145,7 +145,7 @@ shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRB
}
void shared_tile::set_value(indices_t idx, Value *value) {
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(idx));
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx));
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
builder_.CreateStore(value, ptr);
@@ -176,7 +176,7 @@ Value* shared_tile::get_value(indices_t idx) {
// if(isa<Instruction>(non_cst_idx.front())){
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
// }
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx));
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(ty, vector_size);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
@@ -184,7 +184,7 @@ Value* shared_tile::get_value(indices_t idx) {
}
// builder_.SetInsertPoint(store);
}
Value *offset = shared_offset(cst_idx);
Value *offset = shared_offset(builder_, shapes_, cst_idx);
Value *div = offset;
if(vector_size_ > 1)
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
@@ -824,23 +824,39 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
return;
}
if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
Value *partial = nullptr;
std::map<indices_t, Value*> partial;
distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
size_t axis = 0;
unsigned num_warps = params_->get_num_threads() / 32;
std::vector<unsigned> shapes = op->get_shapes();
shapes.erase(shapes.begin() + axis);
if(shapes.empty())
shapes.push_back(1);
// reduce within thread
op->for_each([&](indices_t idx){
indices_t pidx = idx;
pidx.erase(pidx.begin() + axis);
if(pidx.empty())
pidx.push_back(builder.getInt32(0));
Value *current = op->get_value(idx);
if(partial == nullptr)
partial = current;
if(partial.find(pidx) == partial.end())
partial[pidx] = current;
else
partial = builder.CreateFAdd(partial, current);
partial[pidx] = builder.CreateFAdd(partial[pidx], current);
});
// reduce within warp
Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
for (int i = 16; i > 0; i >>= 1){
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), partial,
builder.getInt32(i), builder.getInt32(0x1f)});
partial = builder.CreateFAdd(partial, rhs);
for (int i = 16; i > 0; i >>= 1)
for(auto& x: partial)
{
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), x.second,
builder.getInt32(i),
builder.getInt32(0x1f)});
x.second = builder.CreateFAdd(x.second, rhs);
}
// reduce within block
Value *tid = tgt_->get_local_id(module, builder, 0);
BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
@@ -853,10 +869,15 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, warp_id);
builder.CreateStore(partial, write_ptr);
for(auto& x: partial){
Value *offset = shared_tile::shared_offset(builder, shapes, x.first);
offset = builder.CreateAdd(offset, builder.CreateMul(warp_id, builder.getInt32(shapes[0])));
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, offset);
builder.CreateStore(x.second, write_ptr);
}
builder.CreateBr(partial_reduce_done);
builder.SetInsertPoint(partial_reduce_done);
// Final reduction with the first warp
tgt_->add_barrier(module, builder);
BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
@@ -865,11 +886,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
final_reduce_do, final_reduce_done);
builder.SetInsertPoint(final_reduce_do);
Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
Value *result = builder.CreateLoad(read_ptr);
BasicBlock *read_shmem_do = BasicBlock::Create(ctx, "read_shmem_do", fn);
BasicBlock *read_shmem_done = BasicBlock::Create(ctx, "read_shmem_done", fn);
builder.CreateCondBr(builder.CreateICmpULT(id_in_warp, builder.getInt32(num_warps)),
read_shmem_do, read_shmem_done);
builder.SetInsertPoint(read_shmem_do);
Value *loaded= builder.CreateLoad(read_ptr);
builder.CreateBr(read_shmem_done);
builder.SetInsertPoint(read_shmem_done);
Value *result = builder.CreatePHI(loaded->getType(), 2);
((PHINode*)result)->addIncoming(ConstantFP::get(loaded->getType(), (double)0), final_reduce_do);
((PHINode*)result)->addIncoming(loaded, read_shmem_do);
for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
Value *rhs = builder.CreateCall(shfl, {result, builder.getInt32(i),
builder.getInt32(0x1f), builder.getInt32(0xffffffff)});
builder.CreateFAdd(result, rhs);
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), result,
builder.getInt32(i), builder.getInt32(0x1f)});
result = builder.CreateFAdd(result, rhs);
}
builder.CreateStore(result, read_ptr);
builder.CreateBr(final_reduce_done);