[LANG] Added support for flattening
This commit is contained in:
committed by
Philippe Tillet
parent
694bfbddf9
commit
bb2d98ce4b
@@ -749,10 +749,21 @@ void UnaryOp::CastOpTypeChecking() {
|
|||||||
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
|
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
|
||||||
Error(this, "tile with more than one element cannot be casted to scalar");
|
Error(this, "tile with more than one element cannot be casted to scalar");
|
||||||
if(type_->IsTile() && operandType->IsTile()){
|
if(type_->IsTile() && operandType->IsTile()){
|
||||||
auto shape = type_->ToTile()->Shape();
|
|
||||||
auto operandShape = operandType->ToTile()->Shape();
|
auto operandShape = operandType->ToTile()->Shape();
|
||||||
if(operandShape.size() > shape.size())
|
auto shape = type_->ToTile()->Shape();
|
||||||
Error(this, "cast cannot reduce operand rank");
|
// this is a shape downcast
|
||||||
|
if(operandShape.size() > shape.size()){
|
||||||
|
size_t operandNumel = 1;
|
||||||
|
size_t numel = 1;
|
||||||
|
for(auto x: operandShape)
|
||||||
|
operandNumel *= x;
|
||||||
|
for(auto x: shape)
|
||||||
|
numel *= x;
|
||||||
|
if(operandNumel != numel)
|
||||||
|
Error(this, "cast cannot change number of elements");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// this is a shape upcast
|
||||||
while(operandShape.size() < shape.size())
|
while(operandShape.size() < shape.size())
|
||||||
operandShape.insert(operandShape.begin(), 1);
|
operandShape.insert(operandShape.begin(), 1);
|
||||||
for(size_t i = 0; i < shape.size(); i++) {
|
for(size_t i = 0; i < shape.size(); i++) {
|
||||||
|
Reference in New Issue
Block a user