[LANG] Added support for flattening
This commit is contained in:
committed by
Philippe Tillet
parent
694bfbddf9
commit
bb2d98ce4b
@@ -749,10 +749,21 @@ void UnaryOp::CastOpTypeChecking() {
|
||||
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
|
||||
Error(this, "tile with more than one element cannot be casted to scalar");
|
||||
if(type_->IsTile() && operandType->IsTile()){
|
||||
auto shape = type_->ToTile()->Shape();
|
||||
auto operandShape = operandType->ToTile()->Shape();
|
||||
if(operandShape.size() > shape.size())
|
||||
Error(this, "cast cannot reduce operand rank");
|
||||
auto shape = type_->ToTile()->Shape();
|
||||
// this is a shape downcast
|
||||
if(operandShape.size() > shape.size()){
|
||||
size_t operandNumel = 1;
|
||||
size_t numel = 1;
|
||||
for(auto x: operandShape)
|
||||
operandNumel *= x;
|
||||
for(auto x: shape)
|
||||
numel *= x;
|
||||
if(operandNumel != numel)
|
||||
Error(this, "cast cannot change number of elements");
|
||||
return;
|
||||
}
|
||||
// this is a shape upcast
|
||||
while(operandShape.size() < shape.size())
|
||||
operandShape.insert(operandShape.begin(), 1);
|
||||
for(size_t i = 0; i < shape.size(); i++) {
|
||||
|
Reference in New Issue
Block a user