[PYTHON][EXAMPLES] Tentative support for einsum with transpositions

This commit is contained in:
Philippe Tillet
2019-10-25 19:01:21 -04:00
parent 8bd87fa19d
commit 76adcb755a
9 changed files with 137 additions and 46 deletions

View File

@@ -451,6 +451,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
QualType lhsQual = lhsTile->Derived();
// create ret shape
TileType::ShapeInt shape;
TileType::ShapeInt axVec;
size_t i = 0;
const Token* tok;
std::vector<std::pair<int, int>> redInfo;
@@ -469,10 +470,22 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
case Token::SUB:
case Token::MAX:
case Token::MIN:{
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
int info = UnaryOp::encodeRed(i, tok->tag_);
redInfo.push_back({i, info});
shape.push_back(lhsShape[i++]);
break;
}
case '^':{
Expr* expr = ParseConditionalExpr();
EnsureInteger(expr);
int ax = Evaluator<long>().Eval(expr);
axVec.push_back(ax);
if(ax < 0 || ax >= lhsShape.size())
Error(tok, "unknown axis %d in transposition", ax);
shape.push_back(lhsShape[ax]);
i++;
break;
}
default:
@@ -481,8 +494,19 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
}
}while(ts_.Try(','));
ts_.Expect(']');
// transposition mode
std::set<int> axSet(axVec.begin(), axVec.end());
if(!axSet.empty()){
if(axSet.size()!=lhsShape.size())
Error(tok, "transposition must address all axes of input array");
return TransOp::New(axVec, lhs);
}
// broadcasting mode
if(lhsShape.size() > i)
Error(tok, "broadcasting not using all operand axes");
// create ret tile
Expr* res = lhs;
for(auto r: redInfo){
@@ -553,7 +577,15 @@ Expr* Parser::ParseUnaryExpr() {
case '-': return ParseUnaryOp(tok, Token::MINUS);
case '~': return ParseUnaryOp(tok, '~');
case '!': return ParseUnaryOp(tok, '!');
case '^': return ParseUnaryOp(tok, Token::XOR);
case '^': {
auto operand = ParseCastExpr();
TileType::ShapeInt shape = operand->Type()->ToTile()->Shape();
TransOp::PermInt perm(shape.size());
for(int d = 0; d < shape.size(); d++)
perm[d] = d;
std::rotate(perm.begin(), perm.begin() + 1, perm.end());
return TransOp::New(perm, operand);
}
default:
ts_.PutBack();
return ParsePostfixExpr();