[tests] basic test for reduction in python passes

This commit is contained in:
Philippe Tillet
2019-09-11 17:35:56 -04:00
parent 2781cdcf93
commit 04a0fbd8e3
10 changed files with 120 additions and 22 deletions

View File

@@ -453,7 +453,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
TileType::ShapeInt shape;
size_t i = 0;
const Token* tok;
std::vector<std::pair<int, int>> redList;
std::vector<std::pair<int, int>> redInfo;
do {
tok = ts_.Next();
switch(tok->tag_) {
@@ -465,10 +465,13 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
shape.push_back(1);
break;
// case Token::ADD:
// case Token::SUB:
// redList.push_back({i, tok->tag_});
// break;
case Token::ADD:
case Token::SUB:{
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
}
default:
Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
@@ -479,8 +482,21 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
if(lhsShape.size() > i)
Error(tok, "broadcasting not using all operand axes");
// create ret tile
TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
Expr* res = lhs;
for(auto r: redInfo){
shape.erase(shape.begin() + r.first);
Type *retType;
if(shape.empty())
retType = lhsQual.GetPtr();
else
retType = TileType::New(shape, lhsQual);
res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
}
if(!shape.empty()){
TileType *retType = TileType::New(shape, lhsQual);
res = UnaryOp::New(Token::CAST, res, retType);
}
return res;
}