[PYTHON][EXAMPLES] Tentative support for einsum with transpositions
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
|
||||
|
||||
static MemPoolImp<BinaryOp> binaryOpPool;
|
||||
static MemPoolImp<TransOp> transOpPool;
|
||||
static MemPoolImp<ConditionalOp> conditionalOpPool;
|
||||
static MemPoolImp<FuncCall> funcCallPool;
|
||||
static MemPoolImp<Declaration> initializationPool;
|
||||
@@ -78,6 +79,9 @@ void UnaryOp::Accept(Visitor* v) {
|
||||
v->VisitUnaryOp(this);
|
||||
}
|
||||
|
||||
void TransOp::Accept(Visitor* v) {
|
||||
v->VisitTransOp(this);
|
||||
}
|
||||
|
||||
void ConditionalOp::Accept(Visitor* v) {
|
||||
v->VisitConditionalOp(this);
|
||||
@@ -645,9 +649,6 @@ void UnaryOp::TypeChecking() {
|
||||
case Token::CAST:
|
||||
return CastOpTypeChecking();
|
||||
|
||||
case '^':
|
||||
return TransOpTypeChecking();
|
||||
|
||||
case Token::REDUCE:
|
||||
return ReduceOpTypeChecking();
|
||||
|
||||
@@ -702,15 +703,6 @@ void UnaryOp::ReduceOpTypeChecking() {
|
||||
type_ = TileType::New(shape, tileType->Derived());
|
||||
}
|
||||
|
||||
void UnaryOp::TransOpTypeChecking() {
|
||||
auto tileType = operand_->Type()->ToTile();
|
||||
if(!tileType)
|
||||
Error(this, "tile expected for transposition operator '^'");
|
||||
auto shape = tileType->Shape();
|
||||
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
|
||||
type_ = TileType::New(shape, tileType->Derived());
|
||||
}
|
||||
|
||||
void UnaryOp::UnaryArithmOpTypeChecking() {
|
||||
auto scalType = TryExtractScalarType(this, operand_);
|
||||
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
||||
@@ -769,6 +761,29 @@ void UnaryOp::CastOpTypeChecking() {
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Transposition Operator
|
||||
*/
|
||||
void TransOp::TypeChecking() {
|
||||
auto tileType = operand_->Type()->ToTile();
|
||||
if(!tileType)
|
||||
Error(this, "tile expected for transposition operator '^'");
|
||||
auto opShape = tileType->Shape();
|
||||
if(perm_.size() != opShape.size())
|
||||
Error(this, "invalid permutations");
|
||||
// permutate input shape
|
||||
TileType::ShapeInt resShape(opShape.size());
|
||||
for(int d = 0; d < opShape.size(); d++)
|
||||
resShape[d] = opShape[perm_[d]];
|
||||
type_ = TileType::New(resShape, tileType->Derived());
|
||||
}
|
||||
|
||||
TransOp* TransOp::New(const PermInt& perm, Expr* operand) {
|
||||
auto ret = new (transOpPool.Alloc()) TransOp(perm, operand);
|
||||
ret->pool_ = &transOpPool;
|
||||
ret->TypeChecking();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Conditional Operator
|
||||
|
Reference in New Issue
Block a user