[tests] basic test for reduction in python passes

This commit is contained in:
Philippe Tillet
2019-09-11 17:35:56 -04:00
parent 2781cdcf93
commit 04a0fbd8e3
10 changed files with 120 additions and 22 deletions

View File

@@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
unsigned per_thread = op_tile->axis(axis).values.size();
unsigned depth = shape_ax / per_thread;
for(unsigned i = depth/2; i > 0; i >>= 1){
// current indices
indices_t current(write_idx.size(), builder.getInt32(0));