[tests] basic test for reduction in python passes

Philippe Tillet
2019-09-11 17:35:56 -04:00
parent 2781cdcf93
commit 04a0fbd8e3
10 changed files with 120 additions and 22 deletions

View File

@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
   friend class LValAssigner;
 public:
-  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
+  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
   virtual ~UnaryOp() {}
   virtual void Accept(Visitor* v);
   virtual bool IsLVal();
   ::Type *Convert();
+  static int encodeRed(int ax, int tag);
+  static void decodeRed(int info, int& ax, int& tag);
   void TypeChecking();
   void IncDecOpTypeChecking();
   void AddrOpTypeChecking();
   void DerefOpTypeChecking();
+  void ReduceOpTypeChecking();
   void TransOpTypeChecking();
   void UnaryArithmOpTypeChecking();
   void CastOpTypeChecking();
 protected:
-  UnaryOp(int op, Expr* operand, QualType type=nullptr)
-    : Expr(operand->Tok(), type), op_(op) {
+  UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
+    : Expr(operand->Tok(), type), op_(op), info_(info) {
     operand_ = operand;
     if (op_ != Token::CAST && op_ != Token::ADDR) {
       operand_ = MayCast(operand);
@@ -441,6 +444,7 @@ protected:
   }
   int op_;
+  int info_;
   Expr* operand_;
 };

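The new info field packs the reduction axis and the reduction operator's token tag into one integer; encodeRed/decodeRed (implemented later in this commit) use a 16-bit split. A standalone illustration of that packing, not part of the diff:

    // Round-trip of the axis/tag packing used by UnaryOp::encodeRed/decodeRed:
    // the axis lives in the low 16 bits, the token tag in the high 16 bits.
    #include <cassert>

    static int encodeRed(int ax, int tag) { return ax | (tag << 16); }
    static void decodeRed(int info, int &ax, int &tag) {
      ax  = info & 0x0000FFFF;
      tag = (info & 0xFFFF0000) >> 16;
    }

    int main() {
      int ax, tag;
      decodeRed(encodeRed(/*ax=*/1, /*tag=*/'+'), ax, tag);
      assert(ax == 1 && tag == '+');
      return 0;
    }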
View File

@@ -180,9 +180,7 @@ public:
     PLUS,
     MINUS,
     CAST,
-    REDUCE_ADD,
-    REDUCE_MAX,
-    REDUCE_MIN,
+    REDUCE,
     // For preprocessor
     PP_IF,

View File

@@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
   tgt_->add_barrier(module, builder);
   builder.CreateStore(result, write_ptr);
   // build result
-  unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
+  unsigned shape_ax = op->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = op_tile->axis(axis).values.size();
+  unsigned depth = shape_ax / per_thread;
   for(unsigned i = depth/2; i > 0; i >>= 1){
     // current indices
     indices_t current(write_idx.size(), builder.getInt32(0));

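In the lowering above, depth is the number of partial results stored along the reduced axis, and the loop halves the active range each step, so the reduction finishes in log2(depth) barrier-separated stages. A plain-CPU sketch of the same halving pattern (illustrative only, not the generated code):

    // depth partial results are combined pairwise, buf[j] += buf[j + i],
    // with i halving each iteration, mirroring the loop structure above.
    #include <cstdio>

    int main() {
      float buf[8] = {1, 2, 3, 4, 5, 6, 7, 8};    // depth == 8 partial sums
      unsigned depth = 8;
      for (unsigned i = depth / 2; i > 0; i >>= 1)
        for (unsigned j = 0; j < i; j++)
          buf[j] += buf[j + i];
      std::printf("sum = %g\n", buf[0]);          // prints 36
      return 0;
    }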
View File

@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
 }

 void BinaryOp::MaskedDerefOpTypeChecking() {
+  // auto lhsTileType = lhs_->Type()->ToTile();
+  // auto rhsTileType = rhs_->Type()->ToTile();
   ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   auto lhsType = lhsScalType->ToArithm();
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
  * Unary Operators
  */

-UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
-  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
+UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
+  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
   ret->pool_ = &unaryOpPool;
   ret->TypeChecking();
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
 }

+int UnaryOp::encodeRed(int ax, int tag) {
+  int result = 0;
+  result |= ax;
+  result |= tag << 16;
+  return result;
+}
+
+void UnaryOp::decodeRed(int info, int& ax, int& tag) {
+  ax = info & 0x0000FFFF;
+  tag = (info & 0xFFFF0000) >> 16;
+}
+
 bool UnaryOp::IsLVal() {
   // Only deref('*') could be lvalue;
   return op_ == Token::DEREF;
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
   case '^':
     return TransOpTypeChecking();

+  case Token::REDUCE:
+    return ReduceOpTypeChecking();
+
   default:
     assert(false);
   }
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
   type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
 }

+void UnaryOp::ReduceOpTypeChecking() {
+  int ax, tag;
+  decodeRed(info_, ax, tag);
+  auto tileType = operand_->Type()->ToTile();
+  if(!tileType)
+    Error(this, "array expected for reduction operation");
+  auto shape = tileType->Shape();
+  shape.erase(shape.begin() + ax);
+  type_ = TileType::New(shape, tileType->Derived());
+}
+
 void UnaryOp::TransOpTypeChecking() {
   auto tileType = operand_->Type()->ToTile();

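ReduceOpTypeChecking above derives the result type by dropping the reduced axis from the operand's tile shape (the parser later falls back to a scalar type when no axes remain). A minimal sketch of that shape rule:

    // Reducing axis ax of a two-dimensional tile drops that dimension from the
    // result shape, e.g. a [16, 32] tile reduced over axis 1 becomes [16].
    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> shape = {16, 32};     // operand tile shape
      int ax = 1;                            // reduce along the second axis
      shape.erase(shape.begin() + ax);       // result shape: {16}
      assert((shape == std::vector<int>{16}));
      return 0;
    }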
View File

@@ -174,6 +174,11 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
     case '!': return set_ret(bld_->create_not(op));
     case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
     case '^': return set_ret(bld_->create_trans(op));
+    case Token::REDUCE: {
+      int ax, tag;
+      UnaryOp::decodeRed(unary->info_, ax, tag);
+      return set_ret(bld_->create_reduce(op, ax));
+    }
     default: error_not_implemented();
   }
   return error_not_implemented();
@@ -412,16 +417,41 @@ void Generator::Gen(ir::module *mod) {

 ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
+  if(src->get_type() == dst_ty)
+    return src;
   if(dst_ty->is_tile_ty()) {
     ir::type *src_ty = src->get_type();
     auto dst_shapes = dst_ty->get_tile_shapes();
     if(!src_ty->is_tile_ty())
       return bld_->create_splat(src, dst_shapes);
     auto src_shapes = src_ty->get_tile_shapes();
-    if(src_shapes.size() != dst_shapes.size())
-      return bld_->create_reshape(src, dst_shapes);
-    else
+    if(src_shapes.size() != dst_shapes.size()){
+      unsigned src_numel = 1;
+      for(unsigned s: src_shapes)
+        src_numel *= s;
+      unsigned dst_numel = 1;
+      for(unsigned s: dst_shapes)
+        dst_numel *= s;
+      if(src_numel == dst_numel)
+        return bld_->create_reshape(src, dst_shapes);
+      else {
+        auto padded_shapes = src_shapes;
+        while(padded_shapes.size() != dst_shapes.size())
+          padded_shapes.insert(padded_shapes.begin(), 1);
+        // check that broadcast is legal
+        for(size_t d = 0; d < padded_shapes.size(); d++){
+          if(dst_shapes[d] != padded_shapes[d] &&
+             padded_shapes[d] != 1)
+            should_not_happen();
+        }
+        // pad and broadcast
+        ir::value *padded = bld_->create_reshape(src, padded_shapes);
+        return bld_->create_broadcast(padded, dst_shapes);
+      }
+    }
+    else{
       return bld_->create_broadcast(src, dst_shapes);
+    }
   }
   return src;
 }

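GenBroadcastOp above now distinguishes two mismatched-rank cases: if the element counts agree it is a plain reshape, otherwise the source shape is left-padded with 1s and broadcast, with every padded dimension required to either match the destination or be 1. A small sketch of that compatibility rule (assumed helper for illustration, not the generator itself):

    // numpy-style check: after left-padding with 1s, each source dimension
    // must equal the destination dimension or be 1 to be broadcastable.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    static bool broadcastable(std::vector<unsigned> src,
                              const std::vector<unsigned> &dst) {
      while (src.size() < dst.size())
        src.insert(src.begin(), 1);          // pad leading dimensions
      for (std::size_t d = 0; d < dst.size(); d++)
        if (src[d] != dst[d] && src[d] != 1)
          return false;
      return true;
    }

    int main() {
      std::printf("%d %d\n",
                  broadcastable({32}, {16, 32}),   // 1: {32} -> {1, 32} -> {16, 32}
                  broadcastable({8},  {16, 32}));  // 0: 8 matches neither 32 nor 1
      return 0;
    }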
View File

@@ -453,7 +453,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
   TileType::ShapeInt shape;
   size_t i = 0;
   const Token* tok;
-  std::vector<std::pair<int, int>> redList;
+  std::vector<std::pair<int, int>> redInfo;
   do {
     tok = ts_.Next();
     switch(tok->tag_) {
@@ -465,10 +465,13 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
       shape.push_back(1);
       break;

-    // case Token::ADD:
-    // case Token::SUB:
-    //   redList.push_back({i, tok->tag_});
-    //   break;
+    case Token::ADD:
+    case Token::SUB:{
+      int info = UnaryOp::encodeRed(i, tok->tag_);
+      redInfo.push_back({i, info});
+      shape.push_back(lhsShape[i++]);
+      break;
+    }

     default:
       Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
@@ -479,8 +482,21 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
   if(lhsShape.size() > i)
     Error(tok, "broadcasting not using all operand axes");
   // create ret tile
-  TileType *retType = TileType::New(shape, lhsQual);
-  return UnaryOp::New(Token::CAST, lhs, retType);
+  Expr* res = lhs;
+  for(auto r: redInfo){
+    shape.erase(shape.begin() + r.first);
+    Type *retType;
+    if(shape.empty())
+      retType = lhsQual.GetPtr();
+    else
+      retType = TileType::New(shape, lhsQual);
+    res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
+  }
+  if(!shape.empty()){
+    TileType *retType = TileType::New(shape, lhsQual);
+    res = UnaryOp::New(Token::CAST, res, retType);
+  }
+  return res;
 }

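In the subscript grammar above, ':' keeps an axis, newaxis inserts a size-1 axis, and '+'/'-' keep the axis while recording a reduction over it; once the subscript list is consumed, each recorded reduction wraps the expression in a Token::REDUCE UnaryOp and removes its axis from the result shape, so (*px)[:, +] on a [TM, TN] tile yields a [TM] value. A toy walk-through of that bookkeeping (not the parser code):

    // For lhs shape {TM, TN} and subscript {':', '+'}: both axes are kept while
    // scanning, then the recorded '+' reduction drops axis 1 from the result.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      std::vector<int> lhsShape = {16, 32};        // [TM, TN]
      std::vector<char> subscript = {':', '+'};
      std::vector<int> shape;
      std::vector<std::size_t> redAxes;
      for (std::size_t i = 0; i < subscript.size(); i++) {
        shape.push_back(lhsShape[i]);              // ':' and '+' both keep the axis
        if (subscript[i] == '+')
          redAxes.push_back(i);                    // '+' additionally records a reduction
      }
      for (std::size_t ax : redAxes)
        shape.erase(shape.begin() + ax);           // each reduction removes its axis
      assert((shape == std::vector<int>{16}));     // result shape: [TM]
      return 0;
    }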
View File

@@ -204,6 +204,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&alignment_info, &grids);
   codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
+  // ir::print(module, std::cout);
   // run passes
   peephole.run(module);
   dce.run(module);

View File

@@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16),
   int rm[TM] = ridm * TM + 0 ... TM;
   int rn[TN] = ridn * TN + 0 ... TN;
   TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
-  TYPE* py[TM, TN] = Y + rm[:, newaxis];
+  TYPE* py[TM] = Y + rm;
   *py = (*px)[:, +];
 }
 )";

View File

@@ -37,6 +37,12 @@ void init_rand(std::vector<T>& x) {
     x[i] = static_cast<T>((double)rand()/RAND_MAX);
 }

+template<class T>
+void init_zeros(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 0;
+}
+
 namespace aux{

View File

@@ -15,15 +15,26 @@ namespace drv = triton::driver;
 namespace rt = triton::runtime;

+template<class T>
+void cpu_ref(std::vector<T> &y, const std::vector<T> &x, int M, int N) {
+  for(int m = 0; m < M; m++){
+    T acc = 0;
+    for(int n = 0; n < N; n++)
+      acc = acc + x[m + n*M];
+    y[m] = acc;
+  }
+}
+
 bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
   typedef float NumericT;
   std::string ty = "float";
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   std::vector<NumericT> hy(M);
+  std::vector<NumericT> ry(M);
   std::vector<NumericT> hx(M*N);
   srand(0);
-  init_rand(hy);
+  init_zeros(hy);
   init_rand(hx);
   auto dy = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hy.size()*dt_nbytes));
   auto dx = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hx.size()*dt_nbytes));
@@ -35,8 +46,11 @@ bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){
   opt.defines.push_back({"TN", {std::to_string(N)}});
   opt.num_warps = {nwarp};
   rt::function function(src::reduce2d, opt);
-  function({&*dy, &*dx, M, N, M}, grid2d(M, N), stream);
+  function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream);
   stream->synchronize();
+  stream->read(&*dy, true, 0, hy);
+  cpu_ref(ry, hx, M, N);
+  return testing::diff(hy, ry);
 }

 int main() {