[test] added support for max, min reduction and made it easy to add more
This commit is contained in:
@@ -925,30 +925,47 @@ void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function
|
||||
void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
Module *module = fn->getParent();
|
||||
std::map<indices_t, Value*> partial;
|
||||
ir::value *op = x->get_operand(0);
|
||||
distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
|
||||
ir::value *arg = x->get_operand(0);
|
||||
distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
|
||||
ir::reduce_inst::op_t op = x->get_op();
|
||||
auto accumulate = [&](Value* x, Value *y) -> Value* {
|
||||
switch(op) {
|
||||
case ir::reduce_inst::ADD: return builder.CreateAdd(x, y);
|
||||
case ir::reduce_inst::SUB: return builder.CreateSub(x, y);
|
||||
case ir::reduce_inst::MAX: return builder.CreateMaximum(x, y);
|
||||
case ir::reduce_inst::MIN: return builder.CreateMinimum(x, y);
|
||||
case ir::reduce_inst::FADD: return builder.CreateFAdd(x, y);
|
||||
case ir::reduce_inst::FSUB: return builder.CreateFSub(x, y);
|
||||
case ir::reduce_inst::FMAX: return builder.CreateSelect(builder.CreateFCmpOGT(x, y), x, y);
|
||||
case ir::reduce_inst::FMIN: return builder.CreateSelect(builder.CreateFCmpOLT(x, y), x, y);
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
unsigned axis = x->get_axis();
|
||||
|
||||
// reduce within thread
|
||||
op_tile->for_each([&](indices_t idx) {
|
||||
arg_tile->for_each([&](indices_t idx) {
|
||||
indices_t pidx = idx;
|
||||
pidx[axis] = builder.getInt32(0);
|
||||
Value *current = op_tile->get_value(idx);
|
||||
Value *current = arg_tile->get_value(idx);
|
||||
// current partial result is not initialized -- create
|
||||
if(partial.find(pidx) == partial.end())
|
||||
partial[pidx] = current;
|
||||
// current partial result is initialized -- accumulate
|
||||
else
|
||||
partial[pidx] = builder.CreateFAdd(partial[pidx], current);
|
||||
partial[pidx] = accumulate(partial[pidx], current);
|
||||
});
|
||||
|
||||
// depth
|
||||
unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
|
||||
unsigned per_thread = op_tile->axis(axis).values.size();
|
||||
unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
|
||||
unsigned per_thread = arg_tile->axis(axis).values.size();
|
||||
unsigned depth = shape_ax / per_thread;
|
||||
|
||||
// shapes
|
||||
auto shared_shapes = op_tile->get_shapes();
|
||||
auto shared_shapes = arg_tile->get_shapes();
|
||||
shared_shapes[axis] = depth;
|
||||
|
||||
// reduce within blocks
|
||||
@@ -957,7 +974,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||
for(auto& x: partial) {
|
||||
// current element being computed
|
||||
Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
|
||||
Value *lane = axes_.at(params_->get_param_group(arg, axis)).thread_id;
|
||||
Value *&result = x.second;
|
||||
indices_t write_idx = x.first;
|
||||
write_idx[axis] = lane;
|
||||
@@ -981,7 +998,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
tgt_->add_barrier(module, builder);
|
||||
Value *next = builder.CreateLoad(read_ptr);
|
||||
// accumulate
|
||||
result = builder.CreateFAdd(result, next);
|
||||
result = accumulate(result, next);
|
||||
// write back
|
||||
builder.CreateStore(result, write_ptr);
|
||||
}
|
||||
|
Reference in New Issue
Block a user