[tests] basic test for reduction in python passes
@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
   friend class LValAssigner;
 
 public:
-  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
+  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
   virtual ~UnaryOp() {}
   virtual void Accept(Visitor* v);
   virtual bool IsLVal();
   ::Type *Convert();
+  static int encodeRed(int ax, int tag);
+  static void decodeRed(int info, int& ax, int& tag);
   void TypeChecking();
   void IncDecOpTypeChecking();
   void AddrOpTypeChecking();
   void DerefOpTypeChecking();
+  void ReduceOpTypeChecking();
   void TransOpTypeChecking();
   void UnaryArithmOpTypeChecking();
   void CastOpTypeChecking();
 
 protected:
-  UnaryOp(int op, Expr* operand, QualType type=nullptr)
-      : Expr(operand->Tok(), type), op_(op) {
+  UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
+      : Expr(operand->Tok(), type), op_(op), info_(info) {
     operand_ = operand;
     if (op_ != Token::CAST && op_ != Token::ADDR) {
       operand_ = MayCast(operand);
@@ -441,6 +444,7 @@ protected:
   }
 
   int op_;
+  int info_;
   Expr* operand_;
 };
@@ -180,9 +180,7 @@ public:
   PLUS,
   MINUS,
   CAST,
-  REDUCE_ADD,
-  REDUCE_MAX,
-  REDUCE_MIN,
+  REDUCE,
 
   // For preprocessor
   PP_IF,
@@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
   tgt_->add_barrier(module, builder);
   builder.CreateStore(result, write_ptr);
   // build result
-  unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
+  unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = op_tile->axis(axis).values.size();
+  unsigned depth = shape_ax / per_thread;
   for(unsigned i = depth/2; i > 0; i >>= 1){
     // current indices
     indices_t current(write_idx.size(), builder.getInt32(0));
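[note] The removed line took `depth` from the "wpt" tuning parameter; the new code derives it directly from the tile shape along the reduced axis divided by the number of values each thread holds. A minimal CPU model of the halving loop that follows (illustrative names, assumes the lane count is a power of two):

    #include <vector>
    // Each of `depth` slots holds one partial result; every step folds the
    // upper half onto the lower half (the GPU inserts a barrier between
    // steps), so slot 0 ends up with the full reduction.
    float tree_reduce(std::vector<float> lanes) {
      for(unsigned i = lanes.size()/2; i > 0; i >>= 1)
        for(unsigned j = 0; j < i; j++)
          lanes[j] += lanes[j + i];
      return lanes[0];
    }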
@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
 }
 
 void BinaryOp::MaskedDerefOpTypeChecking() {
+  // auto lhsTileType = lhs_->Type()->ToTile();
+  // auto rhsTileType = rhs_->Type()->ToTile();
   ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   auto lhsType = lhsScalType->ToArithm();
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
  * Unary Operators
  */
 
-UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
-  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
+UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
+  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
   ret->pool_ = &unaryOpPool;
 
   ret->TypeChecking();
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
 }
 
 
+int UnaryOp::encodeRed(int ax, int tag) {
+  int result = 0;
+  result |= ax;
+  result |= tag << 16;
+  return result;
+}
+
+void UnaryOp::decodeRed(int info, int& ax, int& tag) {
+  ax = info & 0x0000FFFF;
+  tag = (info & 0xFFFF0000) >> 16;
+}
+
 bool UnaryOp::IsLVal() {
   // Only deref('*') could be lvalue;
   return op_ == Token::DEREF;
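[note] encodeRed/decodeRed pack the reduction axis into bits 0-15 and the operator's token tag into bits 16-31 of a single int, so one `info` field carries both through UnaryOp. A self-contained round trip, assuming both values fit in 16 bits:

    #include <cassert>
    int encode(int ax, int tag) { return ax | (tag << 16); }
    void decode(int info, int& ax, int& tag) {
      ax  = info & 0x0000FFFF;          // low 16 bits: axis
      tag = (info & 0xFFFF0000) >> 16;  // high 16 bits: operator tag
    }
    int main() {
      int ax, tag;
      decode(encode(1, '+'), ax, tag);
      assert(ax == 1 && tag == '+');
    }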
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
   case '^':
     return TransOpTypeChecking();
 
+  case Token::REDUCE:
+    return ReduceOpTypeChecking();
+
   default:
     assert(false);
   }
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
   type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
 }
 
+void UnaryOp::ReduceOpTypeChecking() {
+  int ax, tag;
+  decodeRed(info_, ax, tag);
+  auto tileType = operand_->Type()->ToTile();
+  if(!tileType)
+    Error(this, "array expected for reduction operation");
+  auto shape = tileType->Shape();
+  shape.erase(shape.begin() + ax);
+  type_ = TileType::New(shape, tileType->Derived());
+}
 
 void UnaryOp::TransOpTypeChecking() {
   auto tileType = operand_->Type()->ToTile();
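[note] ReduceOpTypeChecking types the result by erasing the reduced axis from the operand's tile shape: reducing a {TM, TN} tile along axis 1 yields a {TM} tile. The shape rule in isolation (hypothetical helper, not the AST API):

    #include <cassert>
    #include <vector>
    std::vector<int> reduced_shape(std::vector<int> shape, int ax) {
      assert(0 <= ax && ax < (int)shape.size());
      shape.erase(shape.begin() + ax);  // same erase as above
      return shape;                     // {16, 32}, ax=1  ->  {16}
    }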
@@ -174,6 +174,11 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
   case '!': return set_ret(bld_->create_not(op));
   case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
   case '^': return set_ret(bld_->create_trans(op));
+  case Token::REDUCE: {
+    int ax, tag;
+    UnaryOp::decodeRed(unary->info_, ax, tag);
+    return set_ret(bld_->create_reduce(op, ax));
+  }
   default: error_not_implemented();
   }
   return error_not_implemented();
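[note] decodeRed recovers both fields here, but only the axis is forwarded to create_reduce; the operator tag goes unused in codegen for now, presumably because the IR-level reduce models a single (additive) operation at this stage.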
@@ -412,16 +417,41 @@ void Generator::Gen(ir::module *mod) {
 
 ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
   if(src->get_type() == dst_ty)
     return src;
   if(dst_ty->is_tile_ty()) {
     ir::type *src_ty = src->get_type();
     auto dst_shapes = dst_ty->get_tile_shapes();
     if(!src_ty->is_tile_ty())
       return bld_->create_splat(src, dst_shapes);
     auto src_shapes = src_ty->get_tile_shapes();
-    if(src_shapes.size() != dst_shapes.size())
-      return bld_->create_reshape(src, dst_shapes);
-    else
+    if(src_shapes.size() != dst_shapes.size()){
+      unsigned src_numel = 1;
+      for(unsigned s: src_shapes)
+        src_numel *= s;
+      unsigned dst_numel = 1;
+      for(unsigned s: dst_shapes)
+        dst_numel *= s;
+      if(src_numel == dst_numel)
+        return bld_->create_reshape(src, dst_shapes);
+      else {
+        auto padded_shapes = src_shapes;
+        while(padded_shapes.size() != dst_shapes.size())
+          padded_shapes.insert(padded_shapes.begin(), 1);
+        // check that broadcast is legal
+        for(size_t d = 0; d < padded_shapes.size(); d++){
+          if(dst_shapes[d] != padded_shapes[d] &&
+             padded_shapes[d] != 1)
+            should_not_happen();
+        }
+        // pad and broadcast
+        ir::value *padded = bld_->create_reshape(src, padded_shapes);
+        return bld_->create_broadcast(padded, dst_shapes);
+      }
+    }
+    else{
       return bld_->create_broadcast(src, dst_shapes);
+    }
   }
   return src;
 }
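[note] The rewritten GenBroadcastOp separates two rank-mismatch cases: if the element counts already match it is a pure reshape, otherwise the source shape is left-padded with 1s and each padded dimension must equal the destination's or be 1 (numpy-style broadcasting). The legality rule as a standalone sketch:

    #include <vector>
    bool broadcastable(std::vector<unsigned> src, const std::vector<unsigned>& dst) {
      while(src.size() < dst.size())
        src.insert(src.begin(), 1);       // pad leading dims: {32} -> {1, 32}
      for(size_t d = 0; d < dst.size(); d++)
        if(src[d] != dst[d] && src[d] != 1)
          return false;                   // {2, 32} cannot broadcast to {16, 32}
      return true;                        // {1, 32} -> {16, 32} is legal
    }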
@@ -453,7 +453,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
   TileType::ShapeInt shape;
   size_t i = 0;
   const Token* tok;
-  std::vector<std::pair<int, int>> redList;
+  std::vector<std::pair<int, int>> redInfo;
   do {
     tok = ts_.Next();
     switch(tok->tag_) {
@@ -465,10 +465,13 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
       shape.push_back(1);
       break;
 
-    // case Token::ADD:
-    // case Token::SUB:
-    //   redList.push_back({i, tok->tag_});
-    //   break;
+    case Token::ADD:
+    case Token::SUB:{
+      int info = UnaryOp::encodeRed(i, tok->tag_);
+      redInfo.push_back({i, info});
+      shape.push_back(lhsShape[i++]);
+      break;
+    }
 
     default:
       Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
@@ -479,8 +482,21 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
   if(lhsShape.size() > i)
     Error(tok, "broadcasting not using all operand axes");
   // create ret tile
-  TileType *retType = TileType::New(shape, lhsQual);
-  return UnaryOp::New(Token::CAST, lhs, retType);
+  Expr* res = lhs;
+  for(auto r: redInfo){
+    shape.erase(shape.begin() + r.first);
+    Type *retType;
+    if(shape.empty())
+      retType = lhsQual.GetPtr();
+    else
+      retType = TileType::New(shape, lhsQual);
+    res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
+  }
+  if(!shape.empty()){
+    TileType *retType = TileType::New(shape, lhsQual);
+    res = UnaryOp::New(Token::CAST, res, retType);
+  }
+  return res;
 }
 
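[note] A `+` (or `-`) in a subscript now does real work: for each marked axis the parser emits a Token::REDUCE node whose `info` packs the axis and operator, erasing that axis from the running shape, and it only appends the old CAST when broadcast axes remain. A small model of the shape bookkeeping (assumed shapes, not the parser's types):

    #include <cstdio>
    #include <utility>
    #include <vector>
    int main() {
      std::vector<int> shape = {16, 32};                   // x declared [TM, TN]
      std::vector<std::pair<int,int>> redInfo = {{1, 0}};  // x[:, +] reduces axis 1
      for(auto r: redInfo)
        shape.erase(shape.begin() + r.first);              // each REDUCE drops an axis
      std::printf("rank after reduction: %zu\n", shape.size());  // 1, i.e. {16}
    }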
@@ -204,6 +204,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&alignment_info, &grids);
   codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
+  // ir::print(module, std::cout);
   // run passes
   peephole.run(module);
   dce.run(module);
@@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16),
   int rm[TM] = ridm * TM + 0 ... TM;
   int rn[TN] = ridn * TN + 0 ... TN;
   TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
-  TYPE* py[TM, TN] = Y + rm[:, newaxis];
+  TYPE* py[TM] = Y + rm;
   *py = (*px)[:, +];
 }
 )";
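[note] `(*px)[:, +]` keeps axis 0 and sum-reduces axis 1, so the output tile is one-dimensional; the old `TYPE* py[TM, TN]` declaration no longer matched the reduced result, hence the switch to `TYPE* py[TM] = Y + rm;`. The cpu_ref added in the test below checks the same column-major row sums, y[m] = sum over n of x[m + n*M].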
@@ -37,6 +37,12 @@ void init_rand(std::vector<T>& x) {
     x[i] = static_cast<T>((double)rand()/RAND_MAX);
 }
 
+template<class T>
+void init_zeros(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 0;
+}
+
 
 namespace aux{
@@ -15,15 +15,26 @@ namespace drv = triton::driver;
 namespace rt = triton::runtime;
 
 
+template<class T>
+void cpu_ref(std::vector<T> &y, const std::vector<T> &x, int M, int N) {
+  for(int m = 0; m < M; m++){
+    T acc = 0;
+    for(int n = 0; n < N; n++)
+      acc = acc + x[m + n*M];
+    y[m] = acc;
+  }
+}
+
 bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
   typedef float NumericT;
   std::string ty = "float";
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   std::vector<NumericT> hy(M);
+  std::vector<NumericT> ry(M);
   std::vector<NumericT> hx(M*N);
   srand(0);
-  init_rand(hy);
+  init_zeros(hy);
   init_rand(hx);
   auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
   auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
@@ -35,8 +46,11 @@ bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
   opt.defines.push_back({"TN", {std::to_string(N)}});
   opt.num_warps = {nwarp};
   rt::function function(src::reduce2d, opt);
-  function({&*dy, &*dx, M, N, M}, grid2d(M, N), stream);
+  function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream);
   stream->synchronize();
+  stream->read(&*dy, true, 0, hy);
+  cpu_ref(ry, hx, M, N);
+  return testing::diff(hy, ry);
 }
 
 int main() {