diff --git a/include/triton/lang/type.h b/include/triton/lang/type.h index 59ea8eb3f..8b63b401c 100644 --- a/include/triton/lang/type.h +++ b/include/triton/lang/type.h @@ -331,7 +331,6 @@ public: using ShapeInt = std::vector; public: - static TileType* New(const ShapeExpr& expr, QualType eleType); static TileType* New(const ShapeInt& shape, QualType eleType); virtual ~TileType() { } @@ -345,6 +344,7 @@ public: } ShapeInt Shape() { return shape_; } + int NumEle() const { int ret = 1; for(int s: shape_) @@ -352,24 +352,13 @@ public: return ret; } -protected: - TileType(MemPool* pool, const ShapeExpr& expr, QualType derived) - : DerivedType(pool, derived), - shapeExpr_(expr) { - bool isComplete = true; - for(Expr* s: shapeExpr_) - isComplete = isComplete && !s; - SetComplete(isComplete); + bool CheckPow2NumEl() const { + int n = NumEle(); + return n && !(n & (n - 1)); } - TileType(MemPool* pool, const ShapeInt& shape, QualType derived) - : DerivedType(pool, derived), - shape_(shape) { - bool isComplete = true; - for(int s: shape_) - isComplete = isComplete && (s>=0); - SetComplete(isComplete); - } +protected: + TileType(MemPool* pool, const ShapeInt& shape, QualType derived); protected: ShapeExpr shapeExpr_; diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index 6aaf13408..62c6a0b6c 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -448,7 +448,10 @@ void BinaryOp::RangeOpTypeChecking() { int len = static_cast(end - begin); if(len < 0) Error(this, "range cannot be negative"); - type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type()); + TileType* ret = TileType::New(TileType::ShapeInt{len}, lhs_->Type()); + if(!ret->CheckPow2NumEl()) + Error(this, "range must have power of 2 number of elements"); + type_ = ret; } void BinaryOp::MaskedDerefOpTypeChecking() { @@ -751,6 +754,8 @@ void UnaryOp::CastOpTypeChecking() { if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1) Error(this, "tile with more than one element cannot be casted to scalar"); if(type_->IsTile() && operandType->IsTile()){ + if(!type_->ToTile()->CheckPow2NumEl()) + Error(this, "tile must have power of 2 number of elements"); auto operandShape = operandType->ToTile()->Shape(); auto shape = type_->ToTile()->Shape(); // this is a shape downcast diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index 2ab65d150..8f6ad617f 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -1787,8 +1787,10 @@ QualType Parser::ParseArrayFuncDeclarator(const Token* ident, QualType base) { Error(ident, "'%s' has incomplete element type", ident->str_.c_str()); } // return a pointer for tiles in constant memory: - return TileType::New(shape, base); - + TileType* ret = TileType::New(shape, base); + if(!ret->CheckPow2NumEl()) + Error(ts_.Peek(), "tile must have power of 2 number of elements"); + return ret; } else if (ts_.Try('(')) { // Function declaration if (base->ToFunc()) { Error(ts_.Peek(), diff --git a/lib/lang/type.cc b/lib/lang/type.cc index 13d09cf89..dc0b65125 100644 --- a/lib/lang/type.cc +++ b/lib/lang/type.cc @@ -110,11 +110,6 @@ ArrayType* ArrayType::New(Expr* expr, QualType eleType) { ArrayType(&arrayTypePool, expr, eleType); } -TileType* TileType::New(const ShapeExpr &expr, QualType eleType) { - return new (tileTypePool.Alloc()) - TileType(&tileTypePool, expr, eleType); -} - TileType* TileType::New(const ShapeInt &shape, QualType eleType) { return new (tileTypePool.Alloc()) TileType(&tileTypePool, shape, eleType); @@ -316,6 +311,15 @@ bool ArrayType::Compatible(const Type& other) const { return true; } +TileType::TileType(MemPool* pool, const ShapeInt& shape, QualType derived) + : DerivedType(pool, derived), + shape_(shape) { + bool isComplete = true; + for(int s: shape_) + isComplete = isComplete && (s>=0); + SetComplete(isComplete); +} + bool TileType::Compatible(const Type& other) const { // For two tile type to be compatible, // the element types must be compatible