[TritonMLIR] Disallow 0D tensor (#788)
This commit is contained in:
@@ -1104,8 +1104,11 @@ void init_triton_ir(py::module &&m) {
|
||||
operand.getType().dyn_cast<mlir::RankedTensorType>();
|
||||
std::vector<int64_t> shape = inputTensorType.getShape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
auto resType = mlir::RankedTensorType::get(
|
||||
shape, inputTensorType.getElementType());
|
||||
mlir::Type resType = inputTensorType.getElementType();
|
||||
if (!shape.empty()) {
|
||||
resType = mlir::RankedTensorType::get(
|
||||
shape, inputTensorType.getElementType());
|
||||
}
|
||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
|
||||
operand, axis);
|
||||
})
|
||||
|
Reference in New Issue
Block a user