[codegen][selection] adding support for reduction along arbitrary axis
This commit is contained in:
@@ -43,6 +43,7 @@ public:
|
||||
virtual void set_value(indices_t idx, llvm::Value *v) = 0;
|
||||
virtual llvm::Value* get_value(indices_t idx) = 0;
|
||||
llvm::Type *get_ty() const { return ty_; }
|
||||
shapes_t get_shapes() const { return shapes_; }
|
||||
|
||||
protected:
|
||||
llvm::Type *ty_;
|
||||
@@ -54,7 +55,6 @@ private:
|
||||
void extract_constant(llvm::Value *arg, llvm::Value *&non_cst, llvm::Value *&cst);
|
||||
void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);
|
||||
|
||||
llvm::Value* shared_offset(indices_t idx);
|
||||
|
||||
public:
|
||||
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
|
||||
@@ -65,6 +65,7 @@ public:
|
||||
llvm::Value* get_value(indices_t idx);
|
||||
llvm::Value* get_pointer() { return ptr_; }
|
||||
llvm::Value* get_offset() { return offset_; }
|
||||
static llvm::Value* shared_offset(llvm::IRBuilder<>& builder, const shapes_t& shapes, indices_t idx);
|
||||
|
||||
private:
|
||||
llvm::Value *ptr_;
|
||||
|
@@ -131,11 +131,11 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
|
||||
}
|
||||
|
||||
|
||||
Value* shared_tile::shared_offset(indices_t idx) {
|
||||
Value *result = builder_.getInt32(0);
|
||||
result = builder_.CreateAdd(result, idx[0]);
|
||||
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
|
||||
Value *result = builder.getInt32(0);
|
||||
result = builder.CreateAdd(result, idx[0]);
|
||||
for(size_t i = 1; i < idx.size(); i++)
|
||||
result = builder_.CreateAdd(result, builder_.CreateMul(idx[i], builder_.getInt32(shapes_[i-1])));
|
||||
result = builder.CreateAdd(result, builder.CreateMul(idx[i], builder.getInt32(shapes[i-1])));
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRB
|
||||
}
|
||||
|
||||
void shared_tile::set_value(indices_t idx, Value *value) {
|
||||
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(idx));
|
||||
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, idx));
|
||||
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
|
||||
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
|
||||
builder_.CreateStore(value, ptr);
|
||||
@@ -176,7 +176,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
// if(isa<Instruction>(non_cst_idx.front())){
|
||||
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
|
||||
// }
|
||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
|
||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, non_cst_idx));
|
||||
if(vector_size_ > 1){
|
||||
Type *vec_ty = VectorType::get(ty, vector_size);
|
||||
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
||||
@@ -184,7 +184,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
||||
}
|
||||
// builder_.SetInsertPoint(store);
|
||||
}
|
||||
Value *offset = shared_offset(cst_idx);
|
||||
Value *offset = shared_offset(builder_, shapes_, cst_idx);
|
||||
Value *div = offset;
|
||||
if(vector_size_ > 1)
|
||||
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
||||
@@ -824,23 +824,39 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
return;
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
|
||||
Value *partial = nullptr;
|
||||
std::map<indices_t, Value*> partial;
|
||||
distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
size_t axis = 0;
|
||||
unsigned num_warps = params_->get_num_threads() / 32;
|
||||
std::vector<unsigned> shapes = op->get_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
if(shapes.empty())
|
||||
shapes.push_back(1);
|
||||
|
||||
// reduce within thread
|
||||
op->for_each([&](indices_t idx){
|
||||
indices_t pidx = idx;
|
||||
pidx.erase(pidx.begin() + axis);
|
||||
if(pidx.empty())
|
||||
pidx.push_back(builder.getInt32(0));
|
||||
Value *current = op->get_value(idx);
|
||||
if(partial == nullptr)
|
||||
partial = current;
|
||||
if(partial.find(pidx) == partial.end())
|
||||
partial[pidx] = current;
|
||||
else
|
||||
partial = builder.CreateFAdd(partial, current);
|
||||
partial[pidx] = builder.CreateFAdd(partial[pidx], current);
|
||||
});
|
||||
|
||||
// reduce within warp
|
||||
Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
|
||||
for (int i = 16; i > 0; i >>= 1){
|
||||
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), partial,
|
||||
builder.getInt32(i), builder.getInt32(0x1f)});
|
||||
partial = builder.CreateFAdd(partial, rhs);
|
||||
for (int i = 16; i > 0; i >>= 1)
|
||||
for(auto& x: partial)
|
||||
{
|
||||
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), x.second,
|
||||
builder.getInt32(i),
|
||||
builder.getInt32(0x1f)});
|
||||
x.second = builder.CreateFAdd(x.second, rhs);
|
||||
}
|
||||
|
||||
// reduce within block
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
|
||||
@@ -853,10 +869,15 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||
Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
|
||||
Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
|
||||
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, warp_id);
|
||||
builder.CreateStore(partial, write_ptr);
|
||||
for(auto& x: partial){
|
||||
Value *offset = shared_tile::shared_offset(builder, shapes, x.first);
|
||||
offset = builder.CreateAdd(offset, builder.CreateMul(warp_id, builder.getInt32(shapes[0])));
|
||||
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, offset);
|
||||
builder.CreateStore(x.second, write_ptr);
|
||||
}
|
||||
builder.CreateBr(partial_reduce_done);
|
||||
builder.SetInsertPoint(partial_reduce_done);
|
||||
|
||||
// Final reduction with the first warp
|
||||
tgt_->add_barrier(module, builder);
|
||||
BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
|
||||
@@ -865,11 +886,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
final_reduce_do, final_reduce_done);
|
||||
builder.SetInsertPoint(final_reduce_do);
|
||||
Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
|
||||
Value *result = builder.CreateLoad(read_ptr);
|
||||
BasicBlock *read_shmem_do = BasicBlock::Create(ctx, "read_shmem_do", fn);
|
||||
BasicBlock *read_shmem_done = BasicBlock::Create(ctx, "read_shmem_done", fn);
|
||||
builder.CreateCondBr(builder.CreateICmpULT(id_in_warp, builder.getInt32(num_warps)),
|
||||
read_shmem_do, read_shmem_done);
|
||||
builder.SetInsertPoint(read_shmem_do);
|
||||
Value *loaded= builder.CreateLoad(read_ptr);
|
||||
builder.CreateBr(read_shmem_done);
|
||||
builder.SetInsertPoint(read_shmem_done);
|
||||
Value *result = builder.CreatePHI(loaded->getType(), 2);
|
||||
((PHINode*)result)->addIncoming(ConstantFP::get(loaded->getType(), (double)0), final_reduce_do);
|
||||
((PHINode*)result)->addIncoming(loaded, read_shmem_do);
|
||||
for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
|
||||
Value *rhs = builder.CreateCall(shfl, {result, builder.getInt32(i),
|
||||
builder.getInt32(0x1f), builder.getInt32(0xffffffff)});
|
||||
builder.CreateFAdd(result, rhs);
|
||||
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), result,
|
||||
builder.getInt32(i), builder.getInt32(0x1f)});
|
||||
result = builder.CreateFAdd(result, rhs);
|
||||
}
|
||||
builder.CreateStore(result, read_ptr);
|
||||
builder.CreateBr(final_reduce_done);
|
||||
|
@@ -42,17 +42,25 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
}
|
||||
|
||||
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
if(dynamic_cast<ir::reduce_inst*>(x))
|
||||
return 32;
|
||||
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
if(dynamic_cast<ir::reduce_inst*>(x)){
|
||||
size_t shape = 1;
|
||||
if(x->get_type()->is_tile_ty()){
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
for(auto x: shapes)
|
||||
shape *= x->get_value();
|
||||
}
|
||||
size_t n_warps = params_->get_num_threads() / 32;
|
||||
return shape * num_bytes * n_warps;
|
||||
}
|
||||
unsigned pad = is_ld_padded(x);
|
||||
if(pad > 0){
|
||||
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
|
||||
result += pad * result / ld;
|
||||
num_bytes += pad * num_bytes / ld;
|
||||
}
|
||||
if(buffer_info_->is_double(x))
|
||||
result *= 2;
|
||||
return result;
|
||||
num_bytes *= 2;
|
||||
return num_bytes;
|
||||
}
|
||||
|
||||
void shmem_allocation::run(){
|
||||
|
@@ -227,7 +227,7 @@ void tune::run(ir::module &mod) {
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
@@ -381,7 +381,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
errors[i].push_back("HMMA must have only 4 fragments per warp");
|
||||
}
|
||||
int num_threads = get_req_num_threads(i);
|
||||
if(num_threads % 32 != 0)
|
||||
if(num_threads % 64 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
|
@@ -219,9 +219,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
best.perf = perf;
|
||||
best.params = params;
|
||||
}
|
||||
// for(size_t i = 0; i < params.size(); i++)
|
||||
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user