[tests] basic test for reduction in python passes
This commit is contained in:
@@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
|
||||
unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
|
||||
unsigned per_thread = op_tile->axis(axis).values.size();
|
||||
unsigned depth = shape_ax / per_thread;
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
|
Reference in New Issue
Block a user