[tests] basic test for reduction in python passes
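This change adds an explicit reduction path through the frontend and code generator: the Token enum collapses REDUCE_ADD/REDUCE_MAX/REDUCE_MIN into a single REDUCE tag, UnaryOp gains an info word encoding the reduced axis and the reduction operator, the parser turns '+' and '-' subscripts into REDUCE nodes, the generator lowers them with create_reduce, GenBroadcastOp learns to left-pad source shapes before broadcasting, and a reduce2d test with a host-side reference is added.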
@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
   friend class LValAssigner;
 
 public:
-  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
+  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
   virtual ~UnaryOp() {}
   virtual void Accept(Visitor* v);
   virtual bool IsLVal();
   ::Type *Convert();
+  static int encodeRed(int ax, int tag);
+  static void decodeRed(int info, int& ax, int& tag);
   void TypeChecking();
   void IncDecOpTypeChecking();
   void AddrOpTypeChecking();
   void DerefOpTypeChecking();
+  void ReduceOpTypeChecking();
   void TransOpTypeChecking();
   void UnaryArithmOpTypeChecking();
   void CastOpTypeChecking();
 
 protected:
-  UnaryOp(int op, Expr* operand, QualType type=nullptr)
-      : Expr(operand->Tok(), type), op_(op) {
+  UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
+      : Expr(operand->Tok(), type), op_(op), info_(info) {
     operand_ = operand;
     if (op_ != Token::CAST && op_ != Token::ADDR) {
       operand_ = MayCast(operand);
@@ -441,6 +444,7 @@ protected:
   }
 
   int op_;
+  int info_;
   Expr* operand_;
 };
 
@@ -180,9 +180,7 @@ public:
   PLUS,
   MINUS,
   CAST,
-  REDUCE_ADD,
-  REDUCE_MAX,
-  REDUCE_MIN,
+  REDUCE,
 
   // For preprocessor
   PP_IF,
@@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
   tgt_->add_barrier(module, builder);
   builder.CreateStore(result, write_ptr);
   // build result
-  unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
+  unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = op_tile->axis(axis).values.size();
+  unsigned depth = shape_ax / per_thread;
   for(unsigned i = depth/2; i > 0; i >>= 1){
     // current indices
     indices_t current(write_idx.size(), builder.getInt32(0));
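Note: the reduction tree depth is now derived from the IR instead of the wpt.d<axis> tuning parameter: depth = shape_ax / per_thread, i.e. the number of per-thread chunks along the reduced axis. For illustration, a reduced axis of 32 elements with 8 values per thread gives depth = 4, so the halving loop runs for i = 2 and i = 1, two combine rounds.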
@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
 }
 
 void BinaryOp::MaskedDerefOpTypeChecking() {
+  // auto lhsTileType = lhs_->Type()->ToTile();
+  // auto rhsTileType = rhs_->Type()->ToTile();
   ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   auto lhsType = lhsScalType->ToArithm();
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
  * Unary Operators
  */
 
-UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
-  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
+UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
+  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
   ret->pool_ = &unaryOpPool;
 
   ret->TypeChecking();
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
 }
 
 
+int UnaryOp::encodeRed(int ax, int tag) {
+  int result = 0;
+  result |= ax;
+  result |= tag << 16;
+  return result;
+}
+
+void UnaryOp::decodeRed(int info, int& ax, int& tag) {
+  ax = info & 0x0000FFFF;
+  tag = (info & 0xFFFF0000) >> 16;
+}
+
 bool UnaryOp::IsLVal() {
   // Only deref('*') could be lvalue;
   return op_ == Token::DEREF;
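Note: encodeRed/decodeRed pack the reduction descriptor into a single int: the axis in the low 16 bits and the operator token in the high 16 bits. A minimal round-trip sketch (the ADD constant below is a stand-in for the real Token value, not taken from this codebase):

    #include <cassert>

    static int encodeRed(int ax, int tag) { return ax | (tag << 16); }

    static void decodeRed(int info, int& ax, int& tag) {
      ax = info & 0x0000FFFF;
      tag = (info & 0xFFFF0000) >> 16;
    }

    int main() {
      const int ADD = 42;            // placeholder for Token::ADD
      int info = encodeRed(1, ADD);  // reduce along axis 1 with '+'
      int ax, tag;
      decodeRed(info, ax, tag);
      assert(ax == 1 && tag == ADD);
      return 0;
    }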
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
   case '^':
     return TransOpTypeChecking();
 
+  case Token::REDUCE:
+    return ReduceOpTypeChecking();
+
   default:
     assert(false);
   }
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
   type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
 }
 
+void UnaryOp::ReduceOpTypeChecking() {
+  int ax, tag;
+  decodeRed(info_, ax, tag);
+  auto tileType = operand_->Type()->ToTile();
+  if(!tileType)
+    Error(this, "array expected for reduction operation");
+  auto shape = tileType->Shape();
+  shape.erase(shape.begin() + ax);
+  type_ = TileType::New(shape, tileType->Derived());
+}
 
 void UnaryOp::TransOpTypeChecking() {
   auto tileType = operand_->Type()->ToTile();
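Note: ReduceOpTypeChecking derives the result type by dropping the reduced axis from the operand's tile shape, so reducing a [TM, TN] tile along ax = 1 yields a [TM] tile. A tiny sketch of that shape bookkeeping, using plain std::vector instead of the real TileType API:

    #include <cassert>
    #include <vector>

    // Mirror of shape.erase(shape.begin() + ax) in ReduceOpTypeChecking().
    std::vector<int> reduce_shape(std::vector<int> shape, int ax) {
      shape.erase(shape.begin() + ax);
      return shape;
    }

    int main() {
      std::vector<int> s = {16, 32};                        // [TM, TN]
      assert(reduce_shape(s, 1) == std::vector<int>{16});   // [TM]
      return 0;
    }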
@@ -174,6 +174,11 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
   case '!': return set_ret(bld_->create_not(op));
   case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
   case '^': return set_ret(bld_->create_trans(op));
+  case Token::REDUCE: {
+    int ax, tag;
+    UnaryOp::decodeRed(unary->info_, ax, tag);
+    return set_ret(bld_->create_reduce(op, ax));
+  }
   default: error_not_implemented();
   }
   return error_not_implemented();
@@ -412,16 +417,41 @@ void Generator::Gen(ir::module *mod) {
 
 
 ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
+  if(src->get_type() == dst_ty)
+    return src;
   if(dst_ty->is_tile_ty()) {
     ir::type *src_ty = src->get_type();
     auto dst_shapes = dst_ty->get_tile_shapes();
     if(!src_ty->is_tile_ty())
      return bld_->create_splat(src, dst_shapes);
    auto src_shapes = src_ty->get_tile_shapes();
-    if(src_shapes.size() != dst_shapes.size())
-      return bld_->create_reshape(src, dst_shapes);
-    else
+    if(src_shapes.size() != dst_shapes.size()){
+      unsigned src_numel = 1;
+      for(unsigned s: src_shapes)
+        src_numel *= s;
+      unsigned dst_numel = 1;
+      for(unsigned s: dst_shapes)
+        dst_numel *= s;
+      if(src_numel == dst_numel)
+        return bld_->create_reshape(src, dst_shapes);
+      else {
+        auto padded_shapes = src_shapes;
+        while(padded_shapes.size() != dst_shapes.size())
+          padded_shapes.insert(padded_shapes.begin(), 1);
+        // check that broadcast is legal
+        for(size_t d = 0; d < padded_shapes.size(); d++){
+          if(dst_shapes[d] != padded_shapes[d] &&
+             padded_shapes[d] != 1)
+            should_not_happen();
+        }
+        // pad and broadcast
+        ir::value *padded = bld_->create_reshape(src, padded_shapes);
+        return bld_->create_broadcast(padded, dst_shapes);
+      }
+    }
+    else{
       return bld_->create_broadcast(src, dst_shapes);
+    }
   }
   return src;
 }
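Note: the broadcasting rule above is the usual left-padding one. When ranks differ but element counts match, the source is merely reshaped; otherwise the source shape is padded with leading 1s and every padded dimension must either equal the destination extent or be 1. A standalone sketch of that legality check over plain shape vectors (function name and types are illustrative, not the generator's API):

    #include <cassert>
    #include <vector>

    // True if src can be broadcast to dst after left-padding src with 1s,
    // mirroring the check in GenBroadcastOp.
    bool broadcast_legal(std::vector<unsigned> src,
                         const std::vector<unsigned>& dst) {
      while(src.size() < dst.size())
        src.insert(src.begin(), 1);
      if(src.size() != dst.size())
        return false;
      for(size_t d = 0; d < src.size(); d++)
        if(dst[d] != src[d] && src[d] != 1)
          return false;
      return true;
    }

    int main() {
      assert(broadcast_legal({32}, {16, 32}));    // {32} pads to {1, 32}
      assert(!broadcast_legal({16}, {16, 32}));   // {16} pads to {1, 16}
      return 0;
    }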
@@ -453,7 +453,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
   TileType::ShapeInt shape;
   size_t i = 0;
   const Token* tok;
-  std::vector<std::pair<int, int>> redList;
+  std::vector<std::pair<int, int>> redInfo;
   do {
     tok = ts_.Next();
     switch(tok->tag_) {
@@ -465,10 +465,13 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
       shape.push_back(1);
       break;
 
-    // case Token::ADD:
-    // case Token::SUB:
-    //   redList.push_back({i, tok->tag_});
-    //   break;
+    case Token::ADD:
+    case Token::SUB:{
+      int info = UnaryOp::encodeRed(i, tok->tag_);
+      redInfo.push_back({i, info});
+      shape.push_back(lhsShape[i++]);
+      break;
+    }
 
     default:
       Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
|
|||||||
if(lhsShape.size() > i)
|
if(lhsShape.size() > i)
|
||||||
Error(tok, "broadcasting not using all operand axes");
|
Error(tok, "broadcasting not using all operand axes");
|
||||||
// create ret tile
|
// create ret tile
|
||||||
TileType *retType = TileType::New(shape, lhsQual);
|
Expr* res = lhs;
|
||||||
return UnaryOp::New(Token::CAST, lhs, retType);
|
for(auto r: redInfo){
|
||||||
|
shape.erase(shape.begin() + r.first);
|
||||||
|
Type *retType;
|
||||||
|
if(shape.empty())
|
||||||
|
retType = lhsQual.GetPtr();
|
||||||
|
else
|
||||||
|
retType = TileType::New(shape, lhsQual);
|
||||||
|
res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
|
||||||
|
}
|
||||||
|
if(!shape.empty()){
|
||||||
|
TileType *retType = TileType::New(shape, lhsQual);
|
||||||
|
res = UnaryOp::New(Token::CAST, res, retType);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
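Note: with this change a subscript list is no longer lowered to a single broadcast CAST. A '+' or '-' at dimension i records {i, encodeRed(i, tok->tag_)} in redInfo while still pushing the operand's extent for that dimension, and after the loop one Token::REDUCE UnaryOp is built per recorded axis, falling back to a scalar type once the shape becomes empty; a trailing CAST is emitted only if a non-empty tile shape remains. For a 2-D operand, (*px)[:, +] therefore becomes a single REDUCE node over axis 1 with a 1-D result tile, which is why py shrinks from [TM, TN] to [TM] in the test kernel below.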
@@ -204,6 +204,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&alignment_info, &grids);
   codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
+  // ir::print(module, std::cout);
   // run passes
   peephole.run(module);
   dce.run(module);
@@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16),
   int rm[TM] = ridm * TM + 0 ... TM;
   int rn[TN] = ridn * TN + 0 ... TN;
   TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
-  TYPE* py[TM, TN] = Y + rm[:, newaxis];
+  TYPE* py[TM] = Y + rm;
   *py = (*px)[:, +];
 }
 )";
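Note: in the Triton-C source above, (*px)[:, +] loads the [TM, TN] tile and reduces it with '+' along its second axis, so the destination py is now a plain TM-wide pointer range rather than a [TM, TN] tile. The expected result is the row sum of the column-major input, y[m] = sum over n of x[m + n*M], which is exactly what the cpu_ref helper added in the next hunk computes.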
@@ -37,6 +37,12 @@ void init_rand(std::vector<T>& x) {
     x[i] = static_cast<T>((double)rand()/RAND_MAX);
 }
 
+template<class T>
+void init_zeros(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 0;
+}
+
 
 
 namespace aux{
@@ -15,15 +15,26 @@ namespace drv = triton::driver;
 namespace rt = triton::runtime;
 
 
+template<class T>
+void cpu_ref(std::vector<T> &y, const std::vector<T> &x, int M, int N) {
+  for(int m = 0; m < M; m++){
+    T acc = 0;
+    for(int n = 0; n < N; n++)
+      acc = acc + x[m + n*M];
+    y[m] = acc;
+  }
+}
+
 bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
   typedef float NumericT;
   std::string ty = "float";
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   std::vector<NumericT> hy(M);
+  std::vector<NumericT> ry(M);
   std::vector<NumericT> hx(M*N);
   srand(0);
-  init_rand(hy);
+  init_zeros(hy);
   init_rand(hx);
   auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
   auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
@@ -35,8 +46,11 @@ bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
   opt.defines.push_back({"TN", {std::to_string(N)}});
   opt.num_warps = {nwarp};
   rt::function function(src::reduce2d, opt);
-  function({&*dy, &*dx, M, N, M}, grid2d(M, N), stream);
+  function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream);
   stream->synchronize();
+  stream->read(&*dy, true, 0, hy);
+  cpu_ref(ry, hx, M, N);
+  return testing::diff(hy, ry);
 }
 
 int main() {
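Note: the test now checks the device result against a host reference: hy starts zeroed, the kernel reduces dx into dy, dy is read back into hy, cpu_ref fills ry on the host, and testing::diff(hy, ry) compares the two M-length vectors (its exact convention is not shown in this diff). A minimal sketch of such a comparison, assuming a simple tolerance test rather than whatever testing::diff actually implements:

    #include <cmath>
    #include <vector>

    // Hypothetical stand-in for testing::diff: true if every element of the
    // device result hy matches the reference ry within a small tolerance.
    template<class T>
    bool approx_equal(const std::vector<T>& hy, const std::vector<T>& ry,
                      double tol = 1e-5) {
      if(hy.size() != ry.size())
        return false;
      for(size_t i = 0; i < hy.size(); i++)
        if(std::fabs(double(hy[i]) - double(ry[i])) >
           tol * (1.0 + std::fabs(double(ry[i]))))
          return false;
      return true;
    }

    int main() {
      std::vector<float> a = {1.f, 2.f}, b = {1.f, 2.f};
      return approx_equal(a, b) ? 0 : 1;
    }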