[PYTHON][EXAMPLES] Tentative support for einsum with transpositions
This commit is contained in:
@@ -451,6 +451,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
QualType lhsQual = lhsTile->Derived();
|
||||
// create ret shape
|
||||
TileType::ShapeInt shape;
|
||||
TileType::ShapeInt axVec;
|
||||
size_t i = 0;
|
||||
const Token* tok;
|
||||
std::vector<std::pair<int, int>> redInfo;
|
||||
@@ -469,10 +470,22 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
case Token::SUB:
|
||||
case Token::MAX:
|
||||
case Token::MIN:{
|
||||
int info = UnaryOp::encodeRed(i, tok->tag_);
|
||||
redInfo.push_back({i, info});
|
||||
shape.push_back(lhsShape[i++]);
|
||||
break;
|
||||
int info = UnaryOp::encodeRed(i, tok->tag_);
|
||||
redInfo.push_back({i, info});
|
||||
shape.push_back(lhsShape[i++]);
|
||||
break;
|
||||
}
|
||||
|
||||
case '^':{
|
||||
Expr* expr = ParseConditionalExpr();
|
||||
EnsureInteger(expr);
|
||||
int ax = Evaluator<long>().Eval(expr);
|
||||
axVec.push_back(ax);
|
||||
if(ax < 0 || ax >= lhsShape.size())
|
||||
Error(tok, "unknown axis %d in transposition", ax);
|
||||
shape.push_back(lhsShape[ax]);
|
||||
i++;
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
@@ -481,8 +494,19 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
}
|
||||
}while(ts_.Try(','));
|
||||
ts_.Expect(']');
|
||||
|
||||
// transposition mode
|
||||
std::set<int> axSet(axVec.begin(), axVec.end());
|
||||
if(!axSet.empty()){
|
||||
if(axSet.size()!=lhsShape.size())
|
||||
Error(tok, "transposition must address all axes of input array");
|
||||
return TransOp::New(axVec, lhs);
|
||||
}
|
||||
|
||||
// broadcasting mode
|
||||
if(lhsShape.size() > i)
|
||||
Error(tok, "broadcasting not using all operand axes");
|
||||
|
||||
// create ret tile
|
||||
Expr* res = lhs;
|
||||
for(auto r: redInfo){
|
||||
@@ -553,7 +577,15 @@ Expr* Parser::ParseUnaryExpr() {
|
||||
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
||||
case '~': return ParseUnaryOp(tok, '~');
|
||||
case '!': return ParseUnaryOp(tok, '!');
|
||||
case '^': return ParseUnaryOp(tok, Token::XOR);
|
||||
case '^': {
|
||||
auto operand = ParseCastExpr();
|
||||
TileType::ShapeInt shape = operand->Type()->ToTile()->Shape();
|
||||
TransOp::PermInt perm(shape.size());
|
||||
for(int d = 0; d < shape.size(); d++)
|
||||
perm[d] = d;
|
||||
std::rotate(perm.begin(), perm.begin() + 1, perm.end());
|
||||
return TransOp::New(perm, operand);
|
||||
}
|
||||
default:
|
||||
ts_.PutBack();
|
||||
return ParsePostfixExpr();
|
||||
|
Reference in New Issue
Block a user