[PYTHON][EXAMPLES] Tentative support for einsum with transpositions

This commit is contained in:
Philippe Tillet
2019-10-25 19:01:21 -04:00
parent 8bd87fa19d
commit 76adcb755a
9 changed files with 137 additions and 46 deletions

View File

@@ -7,6 +7,7 @@
static MemPoolImp<BinaryOp> binaryOpPool;
static MemPoolImp<TransOp> transOpPool;
static MemPoolImp<ConditionalOp> conditionalOpPool;
static MemPoolImp<FuncCall> funcCallPool;
static MemPoolImp<Declaration> initializationPool;
@@ -78,6 +79,9 @@ void UnaryOp::Accept(Visitor* v) {
v->VisitUnaryOp(this);
}
void TransOp::Accept(Visitor* v) {
v->VisitTransOp(this);
}
void ConditionalOp::Accept(Visitor* v) {
v->VisitConditionalOp(this);
@@ -645,9 +649,6 @@ void UnaryOp::TypeChecking() {
case Token::CAST:
return CastOpTypeChecking();
case '^':
return TransOpTypeChecking();
case Token::REDUCE:
return ReduceOpTypeChecking();
@@ -702,15 +703,6 @@ void UnaryOp::ReduceOpTypeChecking() {
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::TransOpTypeChecking() {
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "tile expected for transposition operator '^'");
auto shape = tileType->Shape();
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::UnaryArithmOpTypeChecking() {
auto scalType = TryExtractScalarType(this, operand_);
if (Token::PLUS == op_ || Token::MINUS == op_) {
@@ -769,6 +761,29 @@ void UnaryOp::CastOpTypeChecking() {
}
}
/*
* Transposition Operator
*/
void TransOp::TypeChecking() {
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "tile expected for transposition operator '^'");
auto opShape = tileType->Shape();
if(perm_.size() != opShape.size())
Error(this, "invalid permutations");
// permutate input shape
TileType::ShapeInt resShape(opShape.size());
for(int d = 0; d < opShape.size(); d++)
resShape[d] = opShape[perm_[d]];
type_ = TileType::New(resShape, tileType->Derived());
}
TransOp* TransOp::New(const PermInt& perm, Expr* operand) {
auto ret = new (transOpPool.Alloc()) TransOp(perm, operand);
ret->pool_ = &transOpPool;
ret->TypeChecking();
return ret;
}
/*
* Conditional Operator