[test] added support for max, min reduction and made it easy to add more
@@ -136,7 +136,7 @@ public:
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
-  value *create_reduce(value *A, unsigned axis, const std::string &name = "");
+  value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");
@@ -611,19 +611,28 @@ public:
 };
 
 class reduce_inst: public builtin_inst {
+public:
+  enum op_t{
+    ADD, SUB, MAX, MIN,
+    FADD, FSUB, FMAX, FMIN
+  };
+
 private:
   static type* get_res_type(value *arg, unsigned axis);
+  static std::string to_str(op_t op);
 
 private:
-  reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
+  reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
   std::string repr_impl() const { return "red<" + std::to_string(axis_) + ">"; }
 
 public:
-  static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
   unsigned get_axis() const { return axis_; }
+  op_t get_op() const { return op_; }
 
 private:
   unsigned axis_;
+  op_t op_;
 };
 
 class select_inst: public builtin_inst {
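
With the op_t enum in place, a reduction now carries both its axis and its combining operator. As a usage illustration (not part of the diff), a minimal sketch of building reductions through the new builder API; the builder `bld` and the tile-typed value `x` are assumed to exist:

    // Sketch only: reduce a tile along axis 1 with the new op-aware API.
    ir::value *row_max = bld.create_reduce(x, ir::reduce_inst::MAX, /*axis=*/1, "row_max");
    ir::value *row_sum = bld.create_reduce(x, ir::reduce_inst::FADD, /*axis=*/1, "row_sum");
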
@@ -131,6 +131,8 @@ public:
 
   // TILE ARITHMETICS BEGIN
   NEWAXIS,
+  MAX,
+  MIN,
   // TILE ARITHMETICS END
 
   ALIGNAS, // _Alignas
@@ -60,15 +60,6 @@ void grids::init_c_graph(ir::instruction *v) {
   else if(dynamic_cast<ir::downcast_inst*>(v))
     return;
   else if(dynamic_cast<ir::reduce_inst*>(v)) {
-//    unsigned axis = reduce->get_axis();
-//    ir::value *arg = reduce->get_operand(0);
-//    auto in_shapes = arg->get_type()->get_tile_shapes();
-//    unsigned current = 0;
-//    for(unsigned i = 0; i < in_shapes.size(); i++){
-//      if(i == axis)
-//        continue;
-//      add_constraint({reduce, current++}, {arg, i});
-//    }
     return;
   }
   else
@@ -305,7 +296,6 @@ void grids::run(ir::module &mod) {
     for(size_t d = 0; d < shapes.size(); d++){
      std::string str_d = std::to_string(d);
       effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value();
-      std::cout << shapes[d] << " " << params_.at(i).at("mts.d" + str_d)->get_value() << " " << params_.at(i).at("nts.d" + str_d)->get_value() << std::endl;
     }
     if(num_threads != effective_num_threads)
       throw std::runtime_error("cannot create a kernel with this amount of warps");
@@ -925,30 +925,47 @@ void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function
 void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
   Module *module = fn->getParent();
   std::map<indices_t, Value*> partial;
-  ir::value *op = x->get_operand(0);
-  distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
+  ir::value *arg = x->get_operand(0);
+  distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
+  ir::reduce_inst::op_t op = x->get_op();
+  auto accumulate = [&](Value* x, Value *y) -> Value* {
+    switch(op) {
+      case ir::reduce_inst::ADD: return builder.CreateAdd(x, y);
+      case ir::reduce_inst::SUB: return builder.CreateSub(x, y);
+      case ir::reduce_inst::MAX: return builder.CreateSelect(builder.CreateICmpSGT(x, y), x, y);
+      case ir::reduce_inst::MIN: return builder.CreateSelect(builder.CreateICmpSLT(x, y), x, y);
+      case ir::reduce_inst::FADD: return builder.CreateFAdd(x, y);
+      case ir::reduce_inst::FSUB: return builder.CreateFSub(x, y);
+      case ir::reduce_inst::FMAX: return builder.CreateSelect(builder.CreateFCmpOGT(x, y), x, y);
+      case ir::reduce_inst::FMIN: return builder.CreateSelect(builder.CreateFCmpOLT(x, y), x, y);
+      default: break;
+    }
+    assert(false);
+    return nullptr;
+  };
 
   unsigned axis = x->get_axis();
 
   // reduce within thread
-  op_tile->for_each([&](indices_t idx) {
+  arg_tile->for_each([&](indices_t idx) {
     indices_t pidx = idx;
     pidx[axis] = builder.getInt32(0);
-    Value *current = op_tile->get_value(idx);
+    Value *current = arg_tile->get_value(idx);
     // current partial result is not initialized -- create
     if(partial.find(pidx) == partial.end())
       partial[pidx] = current;
     // current partial result is initialized -- accumulate
     else
-      partial[pidx] = builder.CreateFAdd(partial[pidx], current);
+      partial[pidx] = accumulate(partial[pidx], current);
   });
 
   // depth
-  unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
-  unsigned per_thread = op_tile->axis(axis).values.size();
+  unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = arg_tile->axis(axis).values.size();
   unsigned depth = shape_ax / per_thread;
 
   // shapes
-  auto shared_shapes = op_tile->get_shapes();
+  auto shared_shapes = arg_tile->get_shapes();
   shared_shapes[axis] = depth;
 
   // reduce within blocks
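
The "reduce within thread" loop above keys each partial result by the element's indices with the reduced axis collapsed to zero, so all elements that differ only along that axis fold into one slot. A minimal host-side model of that phase (names and types are illustrative, not the codegen's):

    #include <map>
    #include <vector>

    using indices_t = std::vector<int>;

    std::map<indices_t, float> partial_reduce(const std::map<indices_t, float> &vals,
                                              unsigned axis, float (*acc)(float, float)) {
      std::map<indices_t, float> partial;
      for (auto &kv : vals) {
        indices_t pidx = kv.first;
        pidx[axis] = 0;                            // collapse the reduced axis
        auto it = partial.find(pidx);
        if (it == partial.end())
          partial[pidx] = kv.second;               // first element seeds the slot
        else
          it->second = acc(it->second, kv.second); // later elements accumulate
      }
      return partial;
    }
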
@@ -957,7 +974,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
   Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
   for(auto& x: partial) {
     // current element being computed
-    Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
+    Value *lane = axes_.at(params_->get_param_group(arg, axis)).thread_id;
     Value *&result = x.second;
     indices_t write_idx = x.first;
     write_idx[axis] = lane;
@@ -981,7 +998,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
     tgt_->add_barrier(module, builder);
     Value *next = builder.CreateLoad(read_ptr);
     // accumulate
-    result = builder.CreateFAdd(result, next);
+    result = accumulate(result, next);
     // write back
     builder.CreateStore(result, write_ptr);
   }
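
The barrier/load/accumulate/store sequence above is the inner step of the cross-thread reduction through shared memory; the surrounding loop bounds fall outside the visible hunk, but the usual pattern is a log-step tree in which each thread folds in the partial stored `i` lanes away and `i` halves each step. A host-side sketch of that pattern, assuming a power-of-two depth (an analogue, not the codegen itself):

    #include <functional>
    #include <vector>

    float tree_reduce(std::vector<float> buf, const std::function<float(float, float)> &acc) {
      for (size_t i = buf.size() / 2; i > 0; i /= 2)
        for (size_t lane = 0; lane < i; lane++)  // on the GPU: one thread per lane, barrier per step
          buf[lane] = acc(buf[lane], buf[lane + i]);
      return buf[0];
    }
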
@@ -323,8 +323,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
   return insert(sqrt_inst::create(A, name));
 }
 
-value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
-  return insert(reduce_inst::create(A, axis, name));
+value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) {
+  return insert(reduce_inst::create(A, op, axis, name));
 }
 
 value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
@@ -615,6 +615,23 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
 //===----------------------------------------------------------------------===//
 // reduce instructions
 //===----------------------------------------------------------------------===//
 
+std::string reduce_inst::to_str(op_t op) {
+  switch (op) {
+    case ADD: return "+";
+    case SUB: return "-";
+    case MAX: return "imax";
+    case MIN: return "imin";
+    case FADD: return "+";
+    case FSUB: return "-";
+    case FMAX: return "fmax";
+    case FMIN: return "fmin";
+    default: break;
+  }
+  assert(false);
+  return "";
+}
+
 type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
   shapes.erase(shapes.begin() + axis);
@@ -625,14 +642,15 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   return tile_type::get(scalar_ty, shapes);
 }
 
-reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
+reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
   : builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
+    op_(op),
     axis_(axis){
   set_operand(0, arg);
 }
 
-instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
-  return new reduce_inst(arg, axis, name, next);
+instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
+  return new reduce_inst(arg, op, axis, name, next);
 }
 
@@ -154,12 +154,24 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
   error_not_implemented();
 }
 
+ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
+  using ir::reduce_inst;
+  switch(tag){
+    case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD;
+    case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB;
+    case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX;
+    case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
+    default: break;
+  }
+  should_not_happen();
+  return reduce_inst::op_t();
+}
+
 void Generator::VisitUnaryOp(UnaryOp* unary) {
   // recursion
   Visit(unary->operand_);
-  ir::value* op = ret_;
+  ir::value* arg = ret_;
+  ir::type *arg_ty = arg->get_type();
+  ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
   // return
   switch (unary->op_) {
     case Token::PREFIX_INC: return error_not_implemented();
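
The mapping follows directly from the switch: the same surface token selects the integer or floating-point variant depending on the operand's scalar type. For example (results follow from the cases above):

    ir::reduce_inst::op_t a = reduce_op(Token::MAX, /*is_float=*/true);   // reduce_inst::FMAX
    ir::reduce_inst::op_t b = reduce_op(Token::ADD, /*is_float=*/false);  // reduce_inst::ADD
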
@@ -167,17 +179,19 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
     case Token::POSTFIX_INC: return error_not_implemented();
     case Token::POSTFIX_DEC: return error_not_implemented();
     case Token::ADDR: return error_not_implemented();
-    case Token::DEREF: return set_ret(bld_->create_load(op));
+    case Token::DEREF: return set_ret(bld_->create_load(arg));
     case Token::PLUS: return error_not_implemented();
     case Token::MINUS: return error_not_implemented();
-    case '~': return set_ret(bld_->create_neg(op));
-    case '!': return set_ret(bld_->create_not(op));
-    case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
-    case '^': return set_ret(bld_->create_trans(op));
+    case '~': return set_ret(bld_->create_neg(arg));
+    case '!': return set_ret(bld_->create_not(arg));
+    case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
+    case '^': return set_ret(bld_->create_trans(arg));
     case Token::REDUCE: {
       int ax, tag;
       UnaryOp::decodeRed(unary->info_, ax, tag);
-      return set_ret(bld_->create_reduce(op, ax));
+      bool is_float = arg_scal_ty->is_floating_point_ty();
+      ir::reduce_inst::op_t op = reduce_op(tag, is_float);
+      return set_ret(bld_->create_reduce(arg, op, ax));
     }
     default: error_not_implemented();
   }
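
UnaryOp::encodeRed and UnaryOp::decodeRed pack the reduction axis and the operator token into the single info_ field of the AST node. The real bit layout is defined in UnaryOp and is not part of this diff; a hypothetical packing that satisfies the round-trip used here:

    // Hypothetical layout, for illustration only -- the actual one lives in UnaryOp.
    int encodeRed(int ax, int tag) { return (tag << 16) | (ax & 0xffff); }
    void decodeRed(int info, int &ax, int &tag) { ax = info & 0xffff; tag = info >> 16; }
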
@@ -466,7 +466,9 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
       break;
 
     case Token::ADD:
-    case Token::SUB:{
+    case Token::SUB:
+    case Token::MAX:
+    case Token::MIN:{
       int info = UnaryOp::encodeRed(i, tok->tag_);
       redInfo.push_back({i, info});
       shape.push_back(lhsShape[i++]);
@@ -54,6 +54,8 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
   { "_Noreturn", Token::NORETURN },
   { "_Static_assert", Token::STATIC_ASSERT },
   { "_Thread_local", Token::THREAD },
+  { "max", Token::MAX },
+  { "min", Token::MIN },
 };
 
 const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
@@ -157,6 +157,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
   for(auto it: opt_space_.defines)
     cpp.AddMacro(it.first, &opt.defines.at(it.first));
   cpp.Process(tokens);
+//  tokens.Print(stdout);
   // parse
   Parser parser(tokens);
   parser.Parse();
@@ -200,7 +201,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&alignment_info, &grids);
   codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
-  ir::print(module, std::cout);
+//  ir::print(module, std::cout);
   // run passes
   peephole.run(module);
   dce.run(module);
@@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16),
   int rm[TM] = ridm * TM + 0 ... TM;
   int rn[TN] = ridn * TN + 0 ... TN;
   TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
-  TYPE* py[TY] = Y + rm;
+  TYPE* py[TY] = Y + RY;
   *py = (*px)[RED];
 }
)";
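
RED and RY are preprocessor macros supplied by the test harness: RY names the index vector that survives the reduction, and RED spells the reduction descriptor, with the operator lexeme in the reduced slot and ":" (keep) everywhere else. For a max-reduction along axis 0 the kernel's last statement expands to `*py = (*px)[max, :];`. A sketch mirroring the RED construction in do_test below:

    #include <string>

    // e.g. make_red(0, "max") == "max, :" and make_red(1, "+") == ":, +"
    std::string make_red(int axis, const std::string &op_lexeme) {
      std::string red;
      for (int n = 0; n < 2; n++) {
        if (n > 0)
          red += ", ";
        red += (n == axis) ? op_lexeme : ":";
      }
      return red;
    }
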
@@ -43,6 +43,34 @@ void init_zeros(std::vector<T>& x) {
     x[i] = 0;
 }
 
+enum reduce_op_t {
+  ADD,
+  MAX,
+  MIN
+};
+
+std::string to_str(reduce_op_t op) {
+  switch (op) {
+    case ADD: return "+";
+    case MAX: return "max";
+    case MIN: return "min";
+    default: break;
+  }
+  assert(false);
+  return "";
+}
+
+template<class T>
+std::function<T(T,T)> get_accumulator(reduce_op_t op) {
+  switch (op) {
+    case ADD: return [](T x, T y) { return x + y; };
+    case MAX: return [](T x, T y) { return std::max(x, y); };
+    case MIN: return [](T x, T y) { return std::min(x, y); };
+    default: break;
+  }
+  assert(false);
+  return std::function<T(T,T)>();
+}
+
 namespace aux{
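
A quick usage sketch of the two helpers (results follow from the definitions above):

    auto acc = get_accumulator<float>(MAX);
    float m = acc(1.5f, 2.5f);           // 2.5f
    std::string lexeme = to_str(MAX);    // "max"
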
@@ -70,6 +98,23 @@ auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
   return os << ")";
 }
 
+template<class Ch, class Tr, class T>
+std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, const std::vector<T>& vec) {
+  os << "{";
+  for(size_t i = 0; i < vec.size(); i++){
+    if(i > 0)
+      os << ", ";
+    os << vec[i];
+  }
+  os << "}";
+  return os;
+}
+
+template<class Ch, class Tr>
+std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, reduce_op_t op) {
+  return os << to_str(op);
+}
+
 namespace testing {
@@ -2,6 +2,7 @@
 #include <cstring>
 #include <sstream>
 #include <cstdio>
+#include <functional>
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/tools/bench.hpp"
@@ -40,58 +41,66 @@ int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
 }
 
 template<class T>
-void reduce_nd(std::vector<T> &y, const std::vector<T> &x, size_t axis, const std::vector<int>& shapes) {
+void reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
   assert(axis <= shapes.size() - 1);
   // remove shape at index axis to get outer dimensions
   std::vector<int> outer = shapes;
   outer.erase(outer.begin() + axis);
   // retrieve shape at index axis to get inner dimension
   int inner = shapes[axis];
+  // accumulation function
+  auto acc = get_accumulator<T>(op);
   // iterate over outer dimensions
   _loop_nest(outer, [&](const std::vector<int>& y_idx) {
-    T acc = 0;
     auto x_idx = y_idx;
     x_idx.insert(x_idx.begin() + axis, 0);
+    // seed with the first element along the axis so that min/max reduce correctly
+    T ret = x[offset(x_idx, shapes)];
     // accumulate over inner dimensions
-    for(int z = 0; z < inner; z++){
+    for(int z = 1; z < inner; z++){
       x_idx[axis] = z;
-      acc = acc + x[offset(x_idx, shapes)];
+      ret = acc(ret, x[offset(x_idx, shapes)]);
     }
-    y[offset(y_idx, outer)] = acc;
+    y[offset(y_idx, outer)] = ret;
   });
 }
 
-bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
+bool do_test(drv::stream* stream, std::vector<int> shape, int axis, reduce_op_t op, int nwarp){
   typedef float NumericT;
   std::string ty = "float";
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
-  std::vector<NumericT> hy(M);
-  std::vector<NumericT> ry(M);
-  std::vector<NumericT> hx(M*N);
+  size_t axy = (axis == 0) ? 1 : 0;
+  std::string RY = (axis == 0) ? "rn" : "rm";
+  std::vector<NumericT> hy(shape[axy]);
+  std::vector<NumericT> ry(shape[axy]);
+  std::vector<NumericT> hx(shape[0]*shape[1]);
   srand(0);
   init_zeros(hy);
   init_rand(hx);
-  for(int i = 0; i < M; i++)
-  for(int j = 0; j < N; j++)
-    hx[i + j*M] = i+j;
   auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
   auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
   stream->write(&*dy, true, 0, hy);
   stream->write(&*dx, true, 0, hx);
   rt::function::options_space_t opt;
   opt.defines.push_back({"TYPE", {ty}});
-  opt.defines.push_back({"TM", {std::to_string(M)}});
-  opt.defines.push_back({"TN", {std::to_string(N)}});
-  opt.defines.push_back({"TY", {std::to_string(M)}});
-  opt.defines.push_back({"RED", {"+, :"}});
+  opt.defines.push_back({"TM", {std::to_string(shape[0])}});
+  opt.defines.push_back({"TN", {std::to_string(shape[1])}});
+  opt.defines.push_back({"TY", {std::to_string(shape[axy])}});
+  opt.defines.push_back({"RY", {RY}});
+  std::string RED = "";
+  for(int n = 0; n < 2; n++){
+    if(n > 0)
+      RED += ", ";
+    RED += (n==axis) ? to_str(op) : ":";
+  }
+  opt.defines.push_back({"RED", {RED}});
   opt.num_warps = {nwarp};
   rt::function function(src::reduce2d, opt);
-  function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream);
+  function({&*dx, &*dy, shape[0], shape[1], shape[0]}, grid2d(shape[0], shape[1]), stream);
   stream->synchronize();
   stream->read(&*dy, true, 0, hy);
-  reduce_nd(ry, hx, 0, {M, N});
+  reduce_nd(ry, hx, op, axis, shape);
   return testing::diff(hy, ry);
 }
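
As a usage illustration of the updated CPU reference (shapes are made up for the example):

    // ry keeps the 32 outer entries; each is the max over the 64 entries along axis 1.
    std::vector<float> hx(32 * 64);
    init_rand(hx);
    std::vector<float> ry(32);
    reduce_nd(ry, hx, MAX, /*axis=*/1, {32, 64});
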
@@ -100,17 +109,21 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
   triton::driver::stream* stream = triton::driver::stream::create(context);
   // shapes to benchmark
-  typedef std::tuple<int, int, std::string> config_t;
+  typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
   std::vector<config_t> configs = {
-    config_t{32, 32, "+"}
+    config_t{{32, 32}, 0, MAX},
+    config_t{{32, 32}, 1, ADD},
+    config_t{{32, 64}, 0, ADD},
+    config_t{{64, 32}, 1, ADD}
   };
   // does the work
-  int M, N;
-  std::string op;
+  int axis;
+  std::vector<int> shape;
+  reduce_op_t op;
   for(const auto& c: configs){
-    std::tie(M, N, op) = c;
+    std::tie(shape, axis, op) = c;
     std::cout << "Testing " << c << " ... " << std::flush;
-    if(do_test(stream, M, N, op, 1))
+    if(do_test(stream, shape, axis, op, 1))
       std::cout << " Pass! " << std::endl;
     else
       std::cout << " Fail! " << std::endl;