[LANG] Added support for flattening

This commit is contained in:
Philippe Tillet
2020-05-13 20:44:09 -04:00
committed by Philippe Tillet
parent 694bfbddf9
commit bb2d98ce4b

View File

@@ -749,10 +749,21 @@ void UnaryOp::CastOpTypeChecking() {
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
Error(this, "tile with more than one element cannot be casted to scalar");
if(type_->IsTile() && operandType->IsTile()){
auto shape = type_->ToTile()->Shape();
auto operandShape = operandType->ToTile()->Shape();
if(operandShape.size() > shape.size())
Error(this, "cast cannot reduce operand rank");
auto shape = type_->ToTile()->Shape();
// this is a shape downcast
if(operandShape.size() > shape.size()){
size_t operandNumel = 1;
size_t numel = 1;
for(auto x: operandShape)
operandNumel *= x;
for(auto x: shape)
numel *= x;
if(operandNumel != numel)
Error(this, "cast cannot change number of elements");
return;
}
// this is a shape upcast
while(operandShape.size() < shape.size())
operandShape.insert(operandShape.begin(), 1);
for(size_t i = 0; i < shape.size(); i++) {