[tests] basic test for reduction in Python passes

This commit is contained in:
Philippe Tillet
2019-09-11 17:35:56 -04:00
parent 2781cdcf93
commit 04a0fbd8e3
10 changed files with 120 additions and 22 deletions

View File

@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
friend class LValAssigner;
public:
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
::Type *Convert();
static int encodeRed(int ax, int tag);
static void decodeRed(int info, int& ax, int& tag);
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void ReduceOpTypeChecking();
void TransOpTypeChecking();
void UnaryArithmOpTypeChecking();
void CastOpTypeChecking();
protected:
UnaryOp(int op, Expr* operand, QualType type=nullptr)
: Expr(operand->Tok(), type), op_(op) {
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
: Expr(operand->Tok(), type), op_(op), info_(info) {
operand_ = operand;
if (op_ != Token::CAST && op_ != Token::ADDR) {
operand_ = MayCast(operand);
@@ -441,6 +444,7 @@ protected:
}
int op_;
int info_;
Expr* operand_;
};

View File

@@ -180,9 +180,7 @@ public:
PLUS,
MINUS,
CAST,
REDUCE_ADD,
REDUCE_MAX,
REDUCE_MIN,
REDUCE,
// For preprocessor
PP_IF,

View File

@@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
unsigned per_thread = op_tile->axis(axis).values.size();
unsigned depth = shape_ax / per_thread;
for(unsigned i = depth/2; i > 0; i >>= 1){
// current indices
indices_t current(write_idx.size(), builder.getInt32(0));

View File

@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
}
void BinaryOp::MaskedDerefOpTypeChecking() {
// auto lhsTileType = lhs_->Type()->ToTile();
// auto rhsTileType = rhs_->Type()->ToTile();
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsType = lhsScalType->ToArithm();
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
* Unary Operators
*/
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
ret->pool_ = &unaryOpPool;
ret->TypeChecking();
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
}
// Pack a reduction descriptor into one int:
// bits [0,16) hold the reduction axis, bits [16,32) hold the operation tag
// (e.g. Token::ADD). Each field is masked to its 16-bit slot so that
// decodeRed() always round-trips, and the packing is done in unsigned
// arithmetic to avoid undefined behavior from shifting into the sign bit.
int UnaryOp::encodeRed(int ax, int tag) {
  unsigned packed = (static_cast<unsigned>(ax) & 0xFFFFu)
                  | ((static_cast<unsigned>(tag) & 0xFFFFu) << 16);
  return static_cast<int>(packed);
}
// Unpack a reduction descriptor produced by encodeRed():
// the low 16 bits are the axis, the high 16 bits are the operation tag.
void UnaryOp::decodeRed(int info, int& ax, int& tag) {
  unsigned packed = static_cast<unsigned>(info);
  ax = static_cast<int>(packed & 0xFFFFu);
  tag = static_cast<int>(packed >> 16);
}
bool UnaryOp::IsLVal() {
// Only deref('*') could be lvalue;
return op_ == Token::DEREF;
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
case '^':
return TransOpTypeChecking();
case Token::REDUCE:
return ReduceOpTypeChecking();
default:
assert(false);
}
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
}
// Type-check a Token::REDUCE unary expression: the operand must be a tile,
// and the result type is the operand's tile type with the reduced axis
// removed from its shape.
void UnaryOp::ReduceOpTypeChecking() {
  int red_axis, red_tag;
  // red_tag (the reduction operation) is decoded but not needed for typing.
  decodeRed(info_, red_axis, red_tag);
  auto opTileType = operand_->Type()->ToTile();
  if(!opTileType)
    Error(this, "array expected for reduction operation");
  auto retShape = opTileType->Shape();
  // Drop the reduced dimension from the result shape.
  retShape.erase(retShape.begin() + red_axis);
  type_ = TileType::New(retShape, opTileType->Derived());
}
void UnaryOp::TransOpTypeChecking() {
auto tileType = operand_->Type()->ToTile();

View File

@@ -174,6 +174,11 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
case '!': return set_ret(bld_->create_not(op));
case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
case '^': return set_ret(bld_->create_trans(op));
case Token::REDUCE: {
int ax, tag;
UnaryOp::decodeRed(unary->info_, ax, tag);
return set_ret(bld_->create_reduce(op, ax));
}
default: error_not_implemented();
}
return error_not_implemented();
@@ -412,16 +417,41 @@ void Generator::Gen(ir::module *mod) {
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
if(src->get_type() == dst_ty)
return src;
if(dst_ty->is_tile_ty()) {
ir::type *src_ty = src->get_type();
auto dst_shapes = dst_ty->get_tile_shapes();
if(!src_ty->is_tile_ty())
return bld_->create_splat(src, dst_shapes);
auto src_shapes = src_ty->get_tile_shapes();
if(src_shapes.size() != dst_shapes.size())
return bld_->create_reshape(src, dst_shapes);
else
if(src_shapes.size() != dst_shapes.size()){
unsigned src_numel = 1;
for(unsigned s: src_shapes)
src_numel *= s;
unsigned dst_numel = 1;
for(unsigned s: dst_shapes)
dst_numel *= s;
if(src_numel == dst_numel)
return bld_->create_reshape(src, dst_shapes);
else {
auto padded_shapes = src_shapes;
while(padded_shapes.size() != dst_shapes.size())
padded_shapes.insert(padded_shapes.begin(), 1);
// check that broadcast is legal
for(size_t d = 0; d < padded_shapes.size(); d++){
if(dst_shapes[d] != padded_shapes[d] &&
padded_shapes[d] != 1)
should_not_happen();
}
// pad and broadcast
ir::value *padded = bld_->create_reshape(src, padded_shapes);
return bld_->create_broadcast(padded, dst_shapes);
}
}
else{
return bld_->create_broadcast(src, dst_shapes);
}
}
return src;
}

View File

@@ -453,7 +453,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
TileType::ShapeInt shape;
size_t i = 0;
const Token* tok;
std::vector<std::pair<int, int>> redList;
std::vector<std::pair<int, int>> redInfo;
do {
tok = ts_.Next();
switch(tok->tag_) {
@@ -465,10 +465,13 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
shape.push_back(1);
break;
// case Token::ADD:
// case Token::SUB:
// redList.push_back({i, tok->tag_});
// break;
case Token::ADD:
case Token::SUB:{
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
}
default:
Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
@@ -479,8 +482,21 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
if(lhsShape.size() > i)
Error(tok, "broadcasting not using all operand axes");
// create ret tile
TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
Expr* res = lhs;
for(auto r: redInfo){
shape.erase(shape.begin() + r.first);
Type *retType;
if(shape.empty())
retType = lhsQual.GetPtr();
else
retType = TileType::New(shape, lhsQual);
res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
}
if(!shape.empty()){
TileType *retType = TileType::New(shape, lhsQual);
res = UnaryOp::New(Token::CAST, res, retType);
}
return res;
}

View File

@@ -204,6 +204,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&alignment_info, &grids);
codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
// ir::print(module, std::cout);
// run passes
peephole.run(module);
dce.run(module);

View File

@@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16),
int rm[TM] = ridm * TM + 0 ... TM;
int rn[TN] = ridn * TN + 0 ... TN;
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
TYPE* py[TM, TN] = Y + rm[:, newaxis];
TYPE* py[TM] = Y + rm;
*py = (*px)[:, +];
}
)";

View File

@@ -37,6 +37,12 @@ void init_rand(std::vector<T>& x) {
x[i] = static_cast<T>((double)rand()/RAND_MAX);
}
// Set every element of `x` to zero, leaving its size unchanged.
// @param x vector to clear in place.
template<class T>
void init_zeros(std::vector<T>& x) {
  // assign() overwrites all existing elements in one call,
  // replacing the hand-rolled index loop.
  x.assign(x.size(), static_cast<T>(0));
}
namespace aux{

View File

@@ -15,15 +15,26 @@ namespace drv = triton::driver;
namespace rt = triton::runtime;
// Reference CPU reduction used to validate the GPU kernel:
// sums each row of the column-major M-by-N matrix `x`, i.e.
// y[m] = sum over n of x[m + n*M].
// @param y output vector of length M (must be pre-sized by the caller)
// @param x input matrix, column-major, M*N elements
template<class T>
void cpu_ref(std::vector<T> &y, const std::vector<T> &x, int M, int N) {
  for(int row = 0; row < M; ++row){
    T row_sum{0};
    for(int col = 0; col < N; ++col){
      row_sum = row_sum + x[row + col * M];
    }
    y[row] = row_sum;
  }
}
bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
typedef float NumericT;
std::string ty = "float";
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
std::vector<NumericT> hy(M);
std::vector<NumericT> ry(M);
std::vector<NumericT> hx(M*N);
srand(0);
init_rand(hy);
init_zeros(hy);
init_rand(hx);
auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
@@ -35,8 +46,11 @@ bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
opt.defines.push_back({"TN", {std::to_string(N)}});
opt.num_warps = {nwarp};
rt::function function(src::reduce2d, opt);
function({&*dy, &*dx, M, N, M}, grid2d(M, N), stream);
function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
cpu_ref(ry, hx, M, N);
return testing::diff(hy, ry);
}
int main() {