Merge branch 'c-reduction'
This commit is contained in:
@@ -136,7 +136,7 @@ public:
|
|||||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||||
value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
|
value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
|
||||||
value *create_sqrt(value *A, const std::string &name = "");
|
value *create_sqrt(value *A, const std::string &name = "");
|
||||||
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
|
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
|
||||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||||
// Intrinsics
|
// Intrinsics
|
||||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||||
|
@@ -611,19 +611,28 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
class reduce_inst: public builtin_inst {
|
class reduce_inst: public builtin_inst {
|
||||||
|
public:
|
||||||
|
enum op_t{
|
||||||
|
ADD, SUB, MAX, MIN,
|
||||||
|
FADD, FSUB, FMAX, FMIN
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static type* get_res_type(value *arg, unsigned axis);
|
static type* get_res_type(value *arg, unsigned axis);
|
||||||
|
static std::string to_str(op_t op);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
|
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
|
||||||
std::string repr_impl() const { return "reduce"; }
|
std::string repr_impl() const { return "red<" + std::to_string(axis_) + ">"; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||||
unsigned get_axis() const { return axis_; }
|
unsigned get_axis() const { return axis_; }
|
||||||
|
op_t get_op() const { return op_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned axis_;
|
unsigned axis_;
|
||||||
|
op_t op_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class select_inst: public builtin_inst {
|
class select_inst: public builtin_inst {
|
||||||
|
@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
|
|||||||
friend class LValAssigner;
|
friend class LValAssigner;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
|
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
|
||||||
virtual ~UnaryOp() {}
|
virtual ~UnaryOp() {}
|
||||||
virtual void Accept(Visitor* v);
|
virtual void Accept(Visitor* v);
|
||||||
virtual bool IsLVal();
|
virtual bool IsLVal();
|
||||||
::Type *Convert();
|
::Type *Convert();
|
||||||
|
static int encodeRed(int ax, int tag);
|
||||||
|
static void decodeRed(int info, int& ax, int& tag);
|
||||||
void TypeChecking();
|
void TypeChecking();
|
||||||
void IncDecOpTypeChecking();
|
void IncDecOpTypeChecking();
|
||||||
void AddrOpTypeChecking();
|
void AddrOpTypeChecking();
|
||||||
void DerefOpTypeChecking();
|
void DerefOpTypeChecking();
|
||||||
|
void ReduceOpTypeChecking();
|
||||||
void TransOpTypeChecking();
|
void TransOpTypeChecking();
|
||||||
void UnaryArithmOpTypeChecking();
|
void UnaryArithmOpTypeChecking();
|
||||||
void CastOpTypeChecking();
|
void CastOpTypeChecking();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
UnaryOp(int op, Expr* operand, QualType type=nullptr)
|
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
|
||||||
: Expr(operand->Tok(), type), op_(op) {
|
: Expr(operand->Tok(), type), op_(op), info_(info) {
|
||||||
operand_ = operand;
|
operand_ = operand;
|
||||||
if (op_ != Token::CAST && op_ != Token::ADDR) {
|
if (op_ != Token::CAST && op_ != Token::ADDR) {
|
||||||
operand_ = MayCast(operand);
|
operand_ = MayCast(operand);
|
||||||
@@ -441,6 +444,7 @@ protected:
|
|||||||
}
|
}
|
||||||
|
|
||||||
int op_;
|
int op_;
|
||||||
|
int info_;
|
||||||
Expr* operand_;
|
Expr* operand_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -131,6 +131,8 @@ public:
|
|||||||
|
|
||||||
// TILE ARITHMETICS BEGIN
|
// TILE ARITHMETICS BEGIN
|
||||||
NEWAXIS,
|
NEWAXIS,
|
||||||
|
MAX,
|
||||||
|
MIN,
|
||||||
// TILE ARITHMETICS END
|
// TILE ARITHMETICS END
|
||||||
|
|
||||||
ALIGNAS, // _Alignas
|
ALIGNAS, // _Alignas
|
||||||
@@ -180,6 +182,7 @@ public:
|
|||||||
PLUS,
|
PLUS,
|
||||||
MINUS,
|
MINUS,
|
||||||
CAST,
|
CAST,
|
||||||
|
REDUCE,
|
||||||
|
|
||||||
// For preprocessor
|
// For preprocessor
|
||||||
PP_IF,
|
PP_IF,
|
||||||
|
@@ -70,7 +70,7 @@ public:
|
|||||||
struct options_space_t {
|
struct options_space_t {
|
||||||
typedef std::pair<std::string, std::vector<std::string>> define_t;
|
typedef std::pair<std::string, std::vector<std::string>> define_t;
|
||||||
std::vector<define_t> defines;
|
std::vector<define_t> defines;
|
||||||
std::vector<size_t> num_warps;
|
std::vector<int> num_warps;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct options_t {
|
struct options_t {
|
||||||
|
@@ -59,16 +59,7 @@ void grids::init_c_graph(ir::instruction *v) {
|
|||||||
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||||
else if(dynamic_cast<ir::downcast_inst*>(v))
|
else if(dynamic_cast<ir::downcast_inst*>(v))
|
||||||
return;
|
return;
|
||||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
|
else if(dynamic_cast<ir::reduce_inst*>(v)) {
|
||||||
unsigned axis = reduce->get_axis();
|
|
||||||
ir::value *arg = reduce->get_operand(0);
|
|
||||||
auto in_shapes = arg->get_type()->get_tile_shapes();
|
|
||||||
unsigned current = 0;
|
|
||||||
for(unsigned i = 0; i < in_shapes.size(); i++){
|
|
||||||
if(i == axis)
|
|
||||||
continue;
|
|
||||||
add_constraint({reduce, current++}, {arg, i});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@@ -244,7 +235,6 @@ void grids::run(ir::module &mod) {
|
|||||||
unsigned size = i->get_type()->get_tile_num_elements();
|
unsigned size = i->get_type()->get_tile_num_elements();
|
||||||
/* HMMA parameters*/
|
/* HMMA parameters*/
|
||||||
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
|
||||||
|
|
||||||
/* fragments per warp */
|
/* fragments per warp */
|
||||||
// try to make things as square as possible to maximize data re-use
|
// try to make things as square as possible to maximize data re-use
|
||||||
std::vector<unsigned> fpw = {1, 1, 1};
|
std::vector<unsigned> fpw = {1, 1, 1};
|
||||||
@@ -285,7 +275,6 @@ void grids::run(ir::module &mod) {
|
|||||||
|
|
||||||
if(num_warps_ != effective_num_warps)
|
if(num_warps_ != effective_num_warps)
|
||||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Scan-line */
|
/* Scan-line */
|
||||||
|
@@ -923,52 +923,74 @@ void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function
|
|||||||
}
|
}
|
||||||
|
|
||||||
void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||||
ir::instruction *ins = (ir::instruction*)x;
|
|
||||||
Module *module = fn->getParent();
|
Module *module = fn->getParent();
|
||||||
std::map<indices_t, Value*> partial;
|
std::map<indices_t, Value*> partial;
|
||||||
ir::value *op = x->get_operand(0);
|
ir::value *arg = x->get_operand(0);
|
||||||
distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
|
distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
|
||||||
|
ir::reduce_inst::op_t op = x->get_op();
|
||||||
|
auto accumulate = [&](Value* x, Value *y) -> Value* {
|
||||||
|
switch(op) {
|
||||||
|
case ir::reduce_inst::ADD: return builder.CreateAdd(x, y);
|
||||||
|
case ir::reduce_inst::SUB: return builder.CreateSub(x, y);
|
||||||
|
case ir::reduce_inst::MAX: return builder.CreateMaximum(x, y);
|
||||||
|
case ir::reduce_inst::MIN: return builder.CreateMinimum(x, y);
|
||||||
|
case ir::reduce_inst::FADD: return builder.CreateFAdd(x, y);
|
||||||
|
case ir::reduce_inst::FSUB: return builder.CreateFSub(x, y);
|
||||||
|
case ir::reduce_inst::FMAX: return builder.CreateSelect(builder.CreateFCmpOGT(x, y), x, y);
|
||||||
|
case ir::reduce_inst::FMIN: return builder.CreateSelect(builder.CreateFCmpOLT(x, y), x, y);
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
assert(false);
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
unsigned axis = x->get_axis();
|
unsigned axis = x->get_axis();
|
||||||
|
|
||||||
// reduce within thread
|
// reduce within thread
|
||||||
op_tile->for_each([&](indices_t idx) {
|
arg_tile->for_each([&](indices_t idx) {
|
||||||
indices_t pidx = idx;
|
indices_t pidx = idx;
|
||||||
pidx.erase(pidx.begin() + axis);
|
pidx[axis] = builder.getInt32(0);
|
||||||
Value *current = op_tile->get_value(idx);
|
Value *current = arg_tile->get_value(idx);
|
||||||
// current partial result is not initialized -- create
|
// current partial result is not initialized -- create
|
||||||
if(partial.find(pidx) == partial.end())
|
if(partial.find(pidx) == partial.end())
|
||||||
partial[pidx] = current;
|
partial[pidx] = current;
|
||||||
// current partial result is initialized -- accumulate
|
// current partial result is initialized -- accumulate
|
||||||
else
|
else
|
||||||
partial[pidx] = builder.CreateFAdd(partial[pidx], current);
|
partial[pidx] = accumulate(partial[pidx], current);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// depth
|
||||||
|
unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
|
||||||
|
unsigned per_thread = arg_tile->axis(axis).values.size();
|
||||||
|
unsigned depth = shape_ax / per_thread;
|
||||||
|
|
||||||
|
// shapes
|
||||||
|
auto shared_shapes = arg_tile->get_shapes();
|
||||||
|
shared_shapes[axis] = depth;
|
||||||
|
|
||||||
// reduce within blocks
|
// reduce within blocks
|
||||||
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||||
Type *res_ty = builder.getFloatTy();
|
Type *res_ty = builder.getFloatTy();
|
||||||
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||||
for(auto& x: partial) {
|
for(auto& x: partial) {
|
||||||
// current element being computed
|
// current element being computed
|
||||||
Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
|
Value *lane = axes_.at(params_->get_param_group(arg, axis)).thread_id;
|
||||||
Value *&result = x.second;
|
Value *&result = x.second;
|
||||||
indices_t write_idx = x.first;
|
indices_t write_idx = x.first;
|
||||||
write_idx.insert(write_idx.begin() + axis, lane);
|
write_idx[axis] = lane;
|
||||||
|
|
||||||
// shared memory write pointer
|
// shared memory write pointer
|
||||||
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
|
Value *write_offset = shared_tile::shared_offset(builder, shared_shapes, write_idx);
|
||||||
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
|
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
|
||||||
|
|
||||||
// initialize shared memory
|
// initialize shared memory
|
||||||
tgt_->add_barrier(module, builder);
|
tgt_->add_barrier(module, builder);
|
||||||
builder.CreateStore(result, write_ptr);
|
builder.CreateStore(result, write_ptr);
|
||||||
// build result
|
// build result
|
||||||
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
|
|
||||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||||
// current indices
|
// current indices
|
||||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||||
current[axis] = builder.getInt32(i);
|
current[axis] = builder.getInt32(i);
|
||||||
// shared memory offset
|
// shared memory offset
|
||||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
|
Value *read_offset = shared_tile::shared_offset(builder, shared_shapes, current);
|
||||||
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
|
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
|
||||||
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
|
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
|
||||||
// shared memory read pointer
|
// shared memory read pointer
|
||||||
@@ -976,25 +998,21 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
|||||||
tgt_->add_barrier(module, builder);
|
tgt_->add_barrier(module, builder);
|
||||||
Value *next = builder.CreateLoad(read_ptr);
|
Value *next = builder.CreateLoad(read_ptr);
|
||||||
// accumulate
|
// accumulate
|
||||||
result = builder.CreateFAdd(result, next);
|
result = accumulate(result, next);
|
||||||
// write back
|
// write back
|
||||||
builder.CreateStore(result, write_ptr);
|
builder.CreateStore(result, write_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// result is on the first lane of shared memory
|
|
||||||
indices_t final = write_idx;
|
|
||||||
final[axis] = builder.getInt32(0);
|
|
||||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
|
|
||||||
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
|
|
||||||
tgt_->add_barrier(module, builder);
|
|
||||||
result = builder.CreateLoad(read_ptr);
|
|
||||||
if(tmap_.find(ins) == tmap_.end())
|
|
||||||
vmap_[ins] = result;
|
|
||||||
else{
|
|
||||||
distributed_tile *ti = (distributed_tile*)tmap_[ins];
|
|
||||||
ti->set_value(x.first, result);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
tgt_->add_barrier(module, builder);
|
||||||
|
|
||||||
|
distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
|
||||||
|
x_tile->for_each([&](indices_t idx) {
|
||||||
|
indices_t red_idx = idx;
|
||||||
|
red_idx.insert(red_idx.begin() + axis, builder.getInt32(0));
|
||||||
|
Value *read_offset = shared_tile::shared_offset(builder, shared_shapes, red_idx);
|
||||||
|
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
|
||||||
|
x_tile->set_value(idx, builder.CreateLoad(read_ptr));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||||
|
@@ -323,8 +323,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
|
|||||||
return insert(sqrt_inst::create(A, name));
|
return insert(sqrt_inst::create(A, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
|
value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) {
|
||||||
return insert(reduce_inst::create(A, axis, name));
|
return insert(reduce_inst::create(A, op, axis, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
|
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
|
||||||
|
@@ -615,6 +615,23 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// reduce instructions
|
// reduce instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
std::string reduce_inst::to_str(op_t op) {
|
||||||
|
switch (op) {
|
||||||
|
case ADD: return "+";
|
||||||
|
case SUB: return "-";
|
||||||
|
case MAX: return "imax";
|
||||||
|
case MIN: return "imin";
|
||||||
|
case FADD: return "+";
|
||||||
|
case FSUB: return "-";
|
||||||
|
case FMAX: return "fmax";
|
||||||
|
case FMIN: return "fmin";
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
assert(false);
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
||||||
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
|
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
|
||||||
shapes.erase(shapes.begin() + axis);
|
shapes.erase(shapes.begin() + axis);
|
||||||
@@ -625,14 +642,15 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
|||||||
return tile_type::get(scalar_ty, shapes);
|
return tile_type::get(scalar_ty, shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
|
reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
|
||||||
: builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
|
: builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
|
||||||
|
op_(op),
|
||||||
axis_(axis){
|
axis_(axis){
|
||||||
set_operand(0, arg);
|
set_operand(0, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
|
instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
|
||||||
return new reduce_inst(arg, axis, name, next);
|
return new reduce_inst(arg, op, axis, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BinaryOp::MaskedDerefOpTypeChecking() {
|
void BinaryOp::MaskedDerefOpTypeChecking() {
|
||||||
|
// auto lhsTileType = lhs_->Type()->ToTile();
|
||||||
|
// auto rhsTileType = rhs_->Type()->ToTile();
|
||||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
auto lhsType = lhsScalType->ToArithm();
|
auto lhsType = lhsScalType->ToArithm();
|
||||||
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
|
|||||||
* Unary Operators
|
* Unary Operators
|
||||||
*/
|
*/
|
||||||
|
|
||||||
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
|
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
|
||||||
auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
|
auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
|
||||||
ret->pool_ = &unaryOpPool;
|
ret->pool_ = &unaryOpPool;
|
||||||
|
|
||||||
ret->TypeChecking();
|
ret->TypeChecking();
|
||||||
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int UnaryOp::encodeRed(int ax, int tag) {
|
||||||
|
int result = 0;
|
||||||
|
result |= ax;
|
||||||
|
result |= tag << 16;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void UnaryOp::decodeRed(int info, int& ax, int& tag) {
|
||||||
|
ax = info & 0x0000FFFF;
|
||||||
|
tag = (info & 0xFFFF0000) >> 16;
|
||||||
|
}
|
||||||
|
|
||||||
bool UnaryOp::IsLVal() {
|
bool UnaryOp::IsLVal() {
|
||||||
// Only deref('*') could be lvalue;
|
// Only deref('*') could be lvalue;
|
||||||
return op_ == Token::DEREF;
|
return op_ == Token::DEREF;
|
||||||
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
|
|||||||
case '^':
|
case '^':
|
||||||
return TransOpTypeChecking();
|
return TransOpTypeChecking();
|
||||||
|
|
||||||
|
case Token::REDUCE:
|
||||||
|
return ReduceOpTypeChecking();
|
||||||
|
|
||||||
default:
|
default:
|
||||||
assert(false);
|
assert(false);
|
||||||
}
|
}
|
||||||
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
|
|||||||
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
|
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UnaryOp::ReduceOpTypeChecking() {
|
||||||
|
int ax, tag;
|
||||||
|
decodeRed(info_, ax, tag);
|
||||||
|
auto tileType = operand_->Type()->ToTile();
|
||||||
|
if(!tileType)
|
||||||
|
Error(this, "array expected for reduction operation");
|
||||||
|
auto shape = tileType->Shape();
|
||||||
|
shape.erase(shape.begin() + ax);
|
||||||
|
type_ = TileType::New(shape, tileType->Derived());
|
||||||
|
}
|
||||||
|
|
||||||
void UnaryOp::TransOpTypeChecking() {
|
void UnaryOp::TransOpTypeChecking() {
|
||||||
auto tileType = operand_->Type()->ToTile();
|
auto tileType = operand_->Type()->ToTile();
|
||||||
|
@@ -154,12 +154,24 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
error_not_implemented();
|
error_not_implemented();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
|
||||||
|
using ir::reduce_inst;
|
||||||
|
switch(tag){
|
||||||
|
case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD;
|
||||||
|
case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB;
|
||||||
|
case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX;
|
||||||
|
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
should_not_happen();
|
||||||
|
return reduce_inst::op_t();
|
||||||
|
}
|
||||||
void Generator::VisitUnaryOp(UnaryOp* unary) {
|
void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||||
|
|
||||||
// recursion
|
// recursion
|
||||||
Visit(unary->operand_);
|
Visit(unary->operand_);
|
||||||
ir::value* op = ret_;
|
ir::value* arg = ret_;
|
||||||
|
ir::type *arg_ty = arg->get_type();
|
||||||
|
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
|
||||||
// return
|
// return
|
||||||
switch (unary->op_) {
|
switch (unary->op_) {
|
||||||
case Token::PREFIX_INC: return error_not_implemented();
|
case Token::PREFIX_INC: return error_not_implemented();
|
||||||
@@ -167,13 +179,20 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
|||||||
case Token::POSTFIX_INC: return error_not_implemented();
|
case Token::POSTFIX_INC: return error_not_implemented();
|
||||||
case Token::POSTFIX_DEC: return error_not_implemented();
|
case Token::POSTFIX_DEC: return error_not_implemented();
|
||||||
case Token::ADDR: return error_not_implemented();
|
case Token::ADDR: return error_not_implemented();
|
||||||
case Token::DEREF: return set_ret(bld_->create_load(op));
|
case Token::DEREF: return set_ret(bld_->create_load(arg));
|
||||||
case Token::PLUS: return error_not_implemented();
|
case Token::PLUS: return error_not_implemented();
|
||||||
case Token::MINUS: return error_not_implemented();
|
case Token::MINUS: return error_not_implemented();
|
||||||
case '~': return set_ret(bld_->create_neg(op));
|
case '~': return set_ret(bld_->create_neg(arg));
|
||||||
case '!': return set_ret(bld_->create_not(op));
|
case '!': return set_ret(bld_->create_not(arg));
|
||||||
case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
|
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||||
case '^': return set_ret(bld_->create_trans(op));
|
case '^': return set_ret(bld_->create_trans(arg));
|
||||||
|
case Token::REDUCE: {
|
||||||
|
int ax, tag;
|
||||||
|
UnaryOp::decodeRed(unary->info_, ax, tag);
|
||||||
|
bool is_float = arg_scal_ty->is_floating_point_ty();
|
||||||
|
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
|
||||||
|
return set_ret(bld_->create_reduce(arg, op, ax));
|
||||||
|
}
|
||||||
default: error_not_implemented();
|
default: error_not_implemented();
|
||||||
}
|
}
|
||||||
return error_not_implemented();
|
return error_not_implemented();
|
||||||
@@ -412,16 +431,41 @@ void Generator::Gen(ir::module *mod) {
|
|||||||
|
|
||||||
|
|
||||||
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
|
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
|
||||||
|
if(src->get_type() == dst_ty)
|
||||||
|
return src;
|
||||||
if(dst_ty->is_tile_ty()) {
|
if(dst_ty->is_tile_ty()) {
|
||||||
ir::type *src_ty = src->get_type();
|
ir::type *src_ty = src->get_type();
|
||||||
auto dst_shapes = dst_ty->get_tile_shapes();
|
auto dst_shapes = dst_ty->get_tile_shapes();
|
||||||
if(!src_ty->is_tile_ty())
|
if(!src_ty->is_tile_ty())
|
||||||
return bld_->create_splat(src, dst_shapes);
|
return bld_->create_splat(src, dst_shapes);
|
||||||
auto src_shapes = src_ty->get_tile_shapes();
|
auto src_shapes = src_ty->get_tile_shapes();
|
||||||
if(src_shapes.size() != dst_shapes.size())
|
if(src_shapes.size() != dst_shapes.size()){
|
||||||
return bld_->create_reshape(src, dst_shapes);
|
unsigned src_numel = 1;
|
||||||
else
|
for(unsigned s: src_shapes)
|
||||||
|
src_numel *= s;
|
||||||
|
unsigned dst_numel = 1;
|
||||||
|
for(unsigned s: dst_shapes)
|
||||||
|
dst_numel *= s;
|
||||||
|
if(src_numel == dst_numel)
|
||||||
|
return bld_->create_reshape(src, dst_shapes);
|
||||||
|
else {
|
||||||
|
auto padded_shapes = src_shapes;
|
||||||
|
while(padded_shapes.size() != dst_shapes.size())
|
||||||
|
padded_shapes.insert(padded_shapes.begin(), 1);
|
||||||
|
// check that broadcast is legal
|
||||||
|
for(size_t d = 0; d < padded_shapes.size(); d++){
|
||||||
|
if(dst_shapes[d] != padded_shapes[d] &&
|
||||||
|
padded_shapes[d] != 1)
|
||||||
|
should_not_happen();
|
||||||
|
}
|
||||||
|
// pad and broadcast
|
||||||
|
ir::value *padded = bld_->create_reshape(src, padded_shapes);
|
||||||
|
return bld_->create_broadcast(padded, dst_shapes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else{
|
||||||
return bld_->create_broadcast(src, dst_shapes);
|
return bld_->create_broadcast(src, dst_shapes);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return src;
|
return src;
|
||||||
}
|
}
|
||||||
|
@@ -453,21 +453,52 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
|||||||
TileType::ShapeInt shape;
|
TileType::ShapeInt shape;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
const Token* tok;
|
const Token* tok;
|
||||||
|
std::vector<std::pair<int, int>> redInfo;
|
||||||
do {
|
do {
|
||||||
tok = ts_.Next();
|
tok = ts_.Next();
|
||||||
if(tok->tag_ == ':')
|
switch(tok->tag_) {
|
||||||
shape.push_back(lhsShape[i++]);
|
case ':':
|
||||||
else if(tok->tag_ == Token::NEWAXIS)
|
shape.push_back(lhsShape[i++]);
|
||||||
shape.push_back(1);
|
break;
|
||||||
else
|
|
||||||
Error(tok, "only ':' and newaxis are supported in subscripts");
|
case Token::NEWAXIS:
|
||||||
|
shape.push_back(1);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case Token::ADD:
|
||||||
|
case Token::SUB:
|
||||||
|
case Token::MAX:
|
||||||
|
case Token::MIN:{
|
||||||
|
int info = UnaryOp::encodeRed(i, tok->tag_);
|
||||||
|
redInfo.push_back({i, info});
|
||||||
|
shape.push_back(lhsShape[i++]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
}while(ts_.Try(','));
|
}while(ts_.Try(','));
|
||||||
ts_.Expect(']');
|
ts_.Expect(']');
|
||||||
if(lhsShape.size() > i)
|
if(lhsShape.size() > i)
|
||||||
Error(tok, "broadcasting not using all operand axes");
|
Error(tok, "broadcasting not using all operand axes");
|
||||||
// create ret tile
|
// create ret tile
|
||||||
TileType *retType = TileType::New(shape, lhsQual);
|
Expr* res = lhs;
|
||||||
return UnaryOp::New(Token::CAST, lhs, retType);
|
for(auto r: redInfo){
|
||||||
|
shape.erase(shape.begin() + r.first);
|
||||||
|
Type *retType;
|
||||||
|
if(shape.empty())
|
||||||
|
retType = lhsQual.GetPtr();
|
||||||
|
else
|
||||||
|
retType = TileType::New(shape, lhsQual);
|
||||||
|
res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
|
||||||
|
}
|
||||||
|
if(!shape.empty()){
|
||||||
|
TileType *retType = TileType::New(shape, lhsQual);
|
||||||
|
res = UnaryOp::New(Token::CAST, res, retType);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -54,6 +54,8 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
|
|||||||
{ "_Noreturn", Token::NORETURN },
|
{ "_Noreturn", Token::NORETURN },
|
||||||
{ "_Static_assert", Token::STATIC_ASSERT },
|
{ "_Static_assert", Token::STATIC_ASSERT },
|
||||||
{ "_Thread_local", Token::THREAD },
|
{ "_Thread_local", Token::THREAD },
|
||||||
|
{ "max", Token::MAX },
|
||||||
|
{ "min", Token::MIN },
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
||||||
|
@@ -157,6 +157,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
|||||||
for(auto it: opt_space_.defines)
|
for(auto it: opt_space_.defines)
|
||||||
cpp.AddMacro(it.first, &opt.defines.at(it.first));
|
cpp.AddMacro(it.first, &opt.defines.at(it.first));
|
||||||
cpp.Process(tokens);
|
cpp.Process(tokens);
|
||||||
|
// tokens.Print(stdout);
|
||||||
// parse
|
// parse
|
||||||
Parser parser(tokens);
|
Parser parser(tokens);
|
||||||
parser.Parse();
|
parser.Parse();
|
||||||
@@ -164,11 +165,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
|||||||
auto ir = make_ir(parser);
|
auto ir = make_ir(parser);
|
||||||
// binary code-gen
|
// binary code-gen
|
||||||
std::unique_ptr<driver::module> bin;
|
std::unique_ptr<driver::module> bin;
|
||||||
try{
|
bin = make_bin(*ir, stream->context(), opt);
|
||||||
bin = make_bin(*ir, stream->context(), opt);
|
|
||||||
}catch(const std::runtime_error& e) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// kernel uses too much resources
|
// kernel uses too much resources
|
||||||
if(!bin)
|
if(!bin)
|
||||||
return;
|
return;
|
||||||
@@ -204,6 +201,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
codegen::transform::peephole peephole;
|
codegen::transform::peephole peephole;
|
||||||
codegen::transform::reassociate reassociate(&alignment_info, &grids);
|
codegen::transform::reassociate reassociate(&alignment_info, &grids);
|
||||||
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
|
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
|
||||||
|
// ir::print(module, std::cout);
|
||||||
// run passes
|
// run passes
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
|
27
tests/common/src/reduce.h
Normal file
27
tests/common/src/reduce.h
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
namespace src {
|
||||||
|
|
||||||
|
const char *reduce1d =
|
||||||
|
R"(
|
||||||
|
void reduce1d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
|
TYPE * Y __noalias __readonly __aligned(16),
|
||||||
|
int N) {
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
|
||||||
|
const char *reduce2d =
|
||||||
|
R"(
|
||||||
|
void reduce2d(TYPE * X __noalias __readonly __aligned(16),
|
||||||
|
TYPE * Y __noalias __writeonly __aligned(16),
|
||||||
|
int M, int N, int ldx) {
|
||||||
|
int ridm = get_program_id(0);
|
||||||
|
int ridn = get_program_id(1);
|
||||||
|
int rm[TM] = ridm * TM + 0 ... TM;
|
||||||
|
int rn[TN] = ridn * TN + 0 ... TN;
|
||||||
|
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
|
||||||
|
TYPE* py[TY] = Y + RY;
|
||||||
|
*py = (*px)[RED];
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
}
|
@@ -9,6 +9,10 @@
|
|||||||
namespace drv = triton::driver;
|
namespace drv = triton::driver;
|
||||||
namespace rt = triton::runtime;
|
namespace rt = triton::runtime;
|
||||||
|
|
||||||
|
/* ------------------------
|
||||||
|
* Launch Grid
|
||||||
|
* ------------------------ */
|
||||||
|
|
||||||
inline size_t ceil(size_t x, size_t y) {
|
inline size_t ceil(size_t x, size_t y) {
|
||||||
return (x + y - 1) / y;
|
return (x + y - 1) / y;
|
||||||
};
|
};
|
||||||
@@ -26,12 +30,116 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* ------------------------
|
||||||
|
* Tensor Initialization
|
||||||
|
* ------------------------ */
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
void init_rand(std::vector<T>& x) {
|
||||||
|
for(size_t i = 0; i < x.size(); i++)
|
||||||
|
x[i] = static_cast<T>((double)rand()/RAND_MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
void init_zeros(std::vector<T>& x) {
|
||||||
|
for(size_t i = 0; i < x.size(); i++)
|
||||||
|
x[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ------------------------
|
||||||
|
* Loop Nests
|
||||||
|
* ------------------------ */
|
||||||
|
|
||||||
|
void _loop_nest(std::vector<int> const & ranges,
|
||||||
|
std::function<void(std::vector<int> const &)> const & f){
|
||||||
|
int D = ranges.size();
|
||||||
|
std::vector<int> values(D, 0);
|
||||||
|
// Start with innermost loop
|
||||||
|
int i = D - 1;
|
||||||
|
while(true){
|
||||||
|
// Execute function
|
||||||
|
f(values);
|
||||||
|
while(values[i]++ == ranges[i] - 1){
|
||||||
|
if(i == 0)
|
||||||
|
return;
|
||||||
|
values[i--] = 0;
|
||||||
|
}
|
||||||
|
i = D - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* -----------------------
|
||||||
|
* TENSOR INDEXING
|
||||||
|
* ----------------------- */
|
||||||
|
|
||||||
enum order_t {
|
enum order_t {
|
||||||
ROWMAJOR,
|
ROWMAJOR,
|
||||||
COLMAJOR
|
COLMAJOR
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
|
||||||
|
int result = idx[0];
|
||||||
|
for(int i = 1; i < idx.size(); i++)
|
||||||
|
result += idx[i]*shapes[i-1];
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* -----------------------
|
||||||
|
* REDUCTION HELPERS
|
||||||
|
* ----------------------- */
|
||||||
|
|
||||||
|
enum reduce_op_t {
|
||||||
|
ADD,
|
||||||
|
MAX,
|
||||||
|
MIN
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string to_str(reduce_op_t op) {
|
||||||
|
switch (op) {
|
||||||
|
case ADD: return "+";
|
||||||
|
case MAX: return "max";
|
||||||
|
case MIN: return "min";
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
assert(false);
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
std::function<T(T,T)> get_accumulator(reduce_op_t op) {
|
||||||
|
switch (op) {
|
||||||
|
case ADD: return [](T x, T y) { return x + y; };
|
||||||
|
case MAX: return [](T x, T y) { return std::max(x, y); };
|
||||||
|
case MIN: return [](T x, T y) { return std::min(x, y); };
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
assert(false);
|
||||||
|
return std::function<T(T,T)>();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* -----------------------
|
||||||
|
* TENSOR COMPARISON
|
||||||
|
* ----------------------- */
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
|
||||||
|
if(hc.size() != rc.size())
|
||||||
|
return false;
|
||||||
|
for(size_t i = 0; i < hc.size(); i++)
|
||||||
|
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||||
|
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* -----------------------
|
||||||
|
* PRETTY PRINTING
|
||||||
|
* ----------------------- */
|
||||||
|
|
||||||
namespace aux{
|
namespace aux{
|
||||||
template<std::size_t...> struct seq{};
|
template<std::size_t...> struct seq{};
|
||||||
|
|
||||||
@@ -57,22 +165,23 @@ auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
|
|||||||
return os << ")";
|
return os << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class Ch, class Tr, class T>
|
||||||
namespace testing {
|
std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, const std::vector<T>& vec) {
|
||||||
|
os << "{";
|
||||||
template<class T>
|
for(size_t i = 0; i < vec.size(); i++){
|
||||||
bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
|
if(i > 0)
|
||||||
if(hc.size() != rc.size())
|
os << ", ";
|
||||||
return false;
|
os << vec[i];
|
||||||
for(size_t i = 0; i < hc.size(); i++)
|
}
|
||||||
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
|
os << "}";
|
||||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
return os;
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class Ch, class Tr>
|
||||||
|
std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, reduce_op_t op) {
|
||||||
|
return os << to_str(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
foreach(PROG dot)
|
foreach(PROG dot reduce)
|
||||||
set(TARGET unit_${PROG})
|
set(TARGET unit_${PROG})
|
||||||
add_executable(${TARGET} ${PROG}.cc)
|
add_executable(${TARGET} ${PROG}.cc)
|
||||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
|
||||||
|
@@ -50,7 +50,7 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
|
bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, int nwarp){
|
||||||
typedef float NumericT;
|
typedef float NumericT;
|
||||||
std::string ty = "float";
|
std::string ty = "float";
|
||||||
size_t dt_nbytes = sizeof(NumericT);
|
size_t dt_nbytes = sizeof(NumericT);
|
||||||
@@ -62,12 +62,9 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
|
|||||||
int32_t ldb = BT ? N : K;
|
int32_t ldb = BT ? N : K;
|
||||||
int32_t ldc = M;
|
int32_t ldc = M;
|
||||||
srand(0);
|
srand(0);
|
||||||
for(size_t i = 0; i < ha.size(); i++)
|
init_rand(ha);
|
||||||
ha[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
|
init_rand(hb);
|
||||||
for(size_t i = 0; i < hb.size(); i++)
|
init_rand(hc);
|
||||||
hb[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
|
|
||||||
for(size_t i = 0; i < hc.size(); i++)
|
|
||||||
hc[i] = static_cast<NumericT>((double)0);
|
|
||||||
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
|
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
|
||||||
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
|
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
|
||||||
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
|
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
|
||||||
@@ -94,7 +91,7 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
|
|||||||
stream->read(&*dc, true, 0, hc);
|
stream->read(&*dc, true, 0, hc);
|
||||||
std::vector<NumericT> rc(hc.size());
|
std::vector<NumericT> rc(hc.size());
|
||||||
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
||||||
return testing::diff(hc, rc);
|
return diff(hc, rc);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
106
tests/unit/reduce.cc
Normal file
106
tests/unit/reduce.cc
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
#include <iomanip>
|
||||||
|
#include <cstring>
|
||||||
|
#include <sstream>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <functional>
|
||||||
|
#include "triton/driver/backend.h"
|
||||||
|
#include "triton/driver/stream.h"
|
||||||
|
#include "triton/tools/bench.hpp"
|
||||||
|
#include "triton/external/half.hpp"
|
||||||
|
#include "triton/runtime/function.h"
|
||||||
|
#include "src/reduce.h"
|
||||||
|
#include "cuda/cublas.h"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
namespace drv = triton::driver;
|
||||||
|
namespace rt = triton::runtime;
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
void reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
|
||||||
|
assert(axis <= shapes.size() - 1);
|
||||||
|
// remove shape at index axis to get outer dimensions
|
||||||
|
std::vector<int> outer = shapes;
|
||||||
|
outer.erase(outer.begin() + axis);
|
||||||
|
// retrieve shape at index axis to get inner dimension
|
||||||
|
int inner = shapes[axis];
|
||||||
|
// accumualtion function
|
||||||
|
auto acc = get_accumulator<T>(op);
|
||||||
|
// iterate over outer dimensions
|
||||||
|
_loop_nest(outer, [&](const std::vector<int>& y_idx) {
|
||||||
|
T ret = 0;
|
||||||
|
auto x_idx = y_idx;
|
||||||
|
x_idx.insert(x_idx.begin() + axis, 0);
|
||||||
|
// accumulate over inner dimensions
|
||||||
|
for(int z = 0; z < inner; z++){
|
||||||
|
x_idx[axis] = z;
|
||||||
|
ret = acc(ret, x[offset(x_idx, shapes)]);
|
||||||
|
}
|
||||||
|
y[offset(y_idx, outer)] = ret;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool do_test(drv::stream* stream, std::vector<int> shape, int axis, reduce_op_t op, int nwarp){
|
||||||
|
typedef float NumericT;
|
||||||
|
std::string ty = "float";
|
||||||
|
size_t dt_nbytes = sizeof(NumericT);
|
||||||
|
drv::context* context = stream->context();
|
||||||
|
size_t axy = (axis == 0) ? 1 : 0;
|
||||||
|
std::string RY = (axis == 0) ? "rn" : "rm";
|
||||||
|
std::vector<NumericT> hy(shape[axy]);
|
||||||
|
std::vector<NumericT> ry(shape[axy]);
|
||||||
|
std::vector<NumericT> hx(shape[0]*shape[1]);
|
||||||
|
srand(0);
|
||||||
|
init_zeros(hy);
|
||||||
|
init_rand(hx);
|
||||||
|
auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
|
||||||
|
auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
|
||||||
|
stream->write(&*dy, true, 0, hy);
|
||||||
|
stream->write(&*dx, true, 0, hx);
|
||||||
|
rt::function::options_space_t opt;
|
||||||
|
opt.defines.push_back({"TYPE", {ty}});
|
||||||
|
opt.defines.push_back({"TM", {std::to_string(shape[0])}});
|
||||||
|
opt.defines.push_back({"TN", {std::to_string(shape[1])}});
|
||||||
|
opt.defines.push_back({"TY", {std::to_string(shape[axy])}});
|
||||||
|
opt.defines.push_back({"RY", {RY}});
|
||||||
|
std::string RED = "";
|
||||||
|
for(int n = 0; n < 2; n++){
|
||||||
|
if(n > 0)
|
||||||
|
RED += ", ";
|
||||||
|
RED += (n==axis) ? to_str(op) : ":";
|
||||||
|
}
|
||||||
|
opt.defines.push_back({"RED", {RED}});
|
||||||
|
opt.num_warps = {nwarp};
|
||||||
|
rt::function function(src::reduce2d, opt);
|
||||||
|
function({&*dx, &*dy, shape[0], shape[1], shape[0]}, grid2d(shape[0], shape[1]), stream);
|
||||||
|
stream->synchronize();
|
||||||
|
stream->read(&*dy, true, 0, hy);
|
||||||
|
reduce_nd(ry, hx, op, axis, shape);
|
||||||
|
return diff(hy, ry);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// initialize default compute device
|
||||||
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
|
// shapes to benchmark
|
||||||
|
typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
|
||||||
|
std::vector<config_t> configs = {
|
||||||
|
config_t{{32, 32}, 0, MAX},
|
||||||
|
config_t{{32, 32}, 1, ADD},
|
||||||
|
config_t{{32, 64}, 0, ADD},
|
||||||
|
config_t{{64, 32}, 1, ADD}
|
||||||
|
};
|
||||||
|
// does the work
|
||||||
|
int axis;
|
||||||
|
std::vector<int> shape;
|
||||||
|
reduce_op_t op;
|
||||||
|
for(const auto& c: configs){
|
||||||
|
std::tie(shape, axis, op) = c;
|
||||||
|
std::cout << "Testing " << c << " ... " << std::flush;
|
||||||
|
if(do_test(stream, shape, axis, op, 1))
|
||||||
|
std::cout << " Pass! " << std::endl;
|
||||||
|
else
|
||||||
|
std::cout << " Fail! " << std::endl;
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user