[TritonMLIR] Disallow 0D tensor (#788)

This commit is contained in:
Shintaro Iwasaki
2022-10-19 10:34:32 -07:00
committed by GitHub
parent 4464646efb
commit 0d22d2bc03
8 changed files with 203 additions and 87 deletions

View File

@@ -1104,8 +1104,11 @@ void init_triton_ir(py::module &&m) {
operand.getType().dyn_cast<mlir::RankedTensorType>();
std::vector<int64_t> shape = inputTensorType.getShape();
shape.erase(shape.begin() + axis);
auto resType = mlir::RankedTensorType::get(
shape, inputTensorType.getElementType());
mlir::Type resType = inputTensorType.getElementType();
if (!shape.empty()) {
resType = mlir::RankedTensorType::get(
shape, inputTensorType.getElementType());
}
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis);
})