More progress on TritonGPU conversion

This commit is contained in:
Yan Da
2022-05-04 14:54:31 +08:00
parent 3ad7bee35e
commit b9279d2e3b
4 changed files with 48 additions and 26 deletions

View File

@@ -13,14 +13,14 @@ namespace triton {
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
return Type();
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type);
return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding());
return Type();
}
@@ -29,7 +29,7 @@ static Type getPointerTypeFromTensor(Type type) {
Type elementType = tensorType.getElementType();
auto shape = tensorType.getShape();
PointerType ptrType = PointerType::get(elementType, 1);
return RankedTensorType::get(shape, ptrType);
return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
}
return Type();
}