[TritonMLIR] Disallow 0D tensor (#788)
This commit is contained in:
@@ -26,14 +26,15 @@ static Type getI32SameShape(Type type) {
|
||||
return i32Type;
|
||||
}
|
||||
|
||||
static Type getPointerTypeFromTensor(Type type) {
|
||||
// Returns the pointer-typed counterpart of `type`.
//
// - If `type` is a RankedTensorType, the result is a RankedTensorType with the
//   same shape and the same encoding, whose element type is a PointerType to
//   the original element type.
// - Otherwise the result is a PointerType to `type` itself.
//
// Both PointerType instances are created with a second argument of 1
// (presumably the address space -- confirm against PointerType::get).
static Type getPointerTypeSameShape(Type type) {
  // Scalar (non-tensor) case: point directly at the given type.
  auto tensorType = type.dyn_cast<RankedTensorType>();
  if (!tensorType)
    return PointerType::get(type, 1);

  // Tensor case: keep shape and encoding, swap the element type for a pointer.
  PointerType ptrType = PointerType::get(tensorType.getElementType(), 1);
  return RankedTensorType::get(tensorType.getShape(), ptrType,
                               tensorType.getEncoding());
}
|
||||
|
||||
// Parser & printer for assembly forms
|
||||
@@ -49,7 +50,7 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||
result.addTypes(resultTypes);
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
|
||||
operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
|
||||
int hasMask = 0, hasOther = 0;
|
||||
if (allOperands.size() >= 2) {
|
||||
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
|
||||
@@ -92,8 +93,8 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> operandTypes;
|
||||
operandTypes.push_back(getPointerTypeFromTensor(valueType)); // ptr
|
||||
operandTypes.push_back(valueType); // value
|
||||
operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
|
||||
operandTypes.push_back(valueType); // value
|
||||
if (allOperands.size() >= 3)
|
||||
operandTypes.push_back(getI1SameShape(valueType)); // mask
|
||||
|
||||
@@ -194,26 +195,33 @@ mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
|
||||
// infer shape
|
||||
Value arg = operands[0];
|
||||
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||
auto argEltTy = argTy.getElementType();
|
||||
auto retShape = argTy.getShape().vec();
|
||||
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
|
||||
retShape.erase(retShape.begin() + axis);
|
||||
// infer encoding
|
||||
Attribute argEncoding = argTy.getEncoding();
|
||||
Attribute retEncoding;
|
||||
if (argEncoding) {
|
||||
Dialect &dialect = argEncoding.getDialect();
|
||||
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
|
||||
if (inferLayoutInterface
|
||||
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
|
||||
.failed()) {
|
||||
llvm::report_fatal_error("failed to infer layout for ReduceOp");
|
||||
return mlir::failure();
|
||||
if (retShape.empty()) {
|
||||
// 0d-tensor -> scalar
|
||||
inferredReturnTypes.push_back(argEltTy);
|
||||
} else {
|
||||
// nd-tensor where n >= 1
|
||||
// infer encoding
|
||||
Attribute argEncoding = argTy.getEncoding();
|
||||
Attribute retEncoding;
|
||||
if (argEncoding) {
|
||||
Dialect &dialect = argEncoding.getDialect();
|
||||
auto inferLayoutInterface =
|
||||
dyn_cast<DialectInferLayoutInterface>(&dialect);
|
||||
if (inferLayoutInterface
|
||||
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
|
||||
.failed()) {
|
||||
llvm::report_fatal_error("failed to infer layout for ReduceOp");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// create type
|
||||
inferredReturnTypes.push_back(
|
||||
RankedTensorType::get(retShape, argEltTy, retEncoding));
|
||||
}
|
||||
// create type
|
||||
auto argEltTy = argTy.getElementType();
|
||||
inferredReturnTypes.push_back(
|
||||
RankedTensorType::get(retShape, argEltTy, retEncoding));
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user