From bb2d98ce4baa87e77889389605a754620844f593 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 13 May 2020 20:44:09 -0400 Subject: [PATCH] [LANG] Added support for flattening --- lib/lang/ast.cc | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index 7c43c78f4..18bfbec71 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -749,10 +749,21 @@ void UnaryOp::CastOpTypeChecking() { if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1) Error(this, "tile with more than one element cannot be casted to scalar"); if(type_->IsTile() && operandType->IsTile()){ - auto shape = type_->ToTile()->Shape(); auto operandShape = operandType->ToTile()->Shape(); - if(operandShape.size() > shape.size()) - Error(this, "cast cannot reduce operand rank"); + auto shape = type_->ToTile()->Shape(); + // this is a shape downcast + if(operandShape.size() > shape.size()){ + size_t operandNumel = 1; + size_t numel = 1; + for(auto x: operandShape) + operandNumel *= x; + for(auto x: shape) + numel *= x; + if(operandNumel != numel) + Error(this, "cast cannot change number of elements"); + return; + } + // this is a shape upcast while(operandShape.size() < shape.size()) operandShape.insert(operandShape.begin(), 1); for(size_t i = 0; i < shape.size(); i++) {