[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -15,7 +15,7 @@ static Type getI1SameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
@@ -23,7 +23,7 @@ static Type getI32SameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
return i32Type;
|
||||
}
|
||||
|
||||
static Type getPointerTypeFromTensor(Type type) {
|
||||
|
Reference in New Issue
Block a user