[FRONTEND][BACKEND] Fixed various bugs (#819)

- Fixed bugs on layout conversions for int1 data (we should use int8
internally for int1 data to prevent llvm from using vec<i1> which has
different semantics)
- Fixed semantics of some casts to bool in the frontend
This commit is contained in:
Philippe Tillet
2022-10-28 23:34:14 -07:00
committed by GitHub
parent 82834d34f9
commit 7dfab26a39
5 changed files with 74 additions and 52 deletions

View File

@@ -170,9 +170,8 @@ void init_triton_ir(py::module &&m) {
.def("replace_all_uses_with",
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
})
});
;
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
py::class_<mlir::Region>(m, "region")
@@ -660,13 +659,13 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
// get element type if necessary
mlir::Type srcType = src.getType();
auto srcTensorType = srcType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstType.dyn_cast<mlir::RankedTensorType>();
mlir::Type srcEltType = srcType;
mlir::Type dstEltType = dstType;
if (dstType.isa<mlir::RankedTensorType>()) {
dstEltType =
dstType.cast<mlir::RankedTensorType>().getElementType();
srcEltType =
srcType.cast<mlir::RankedTensorType>().getElementType();
if (dstTensorType && srcTensorType) {
dstEltType = dstTensorType.getElementType();
srcEltType = srcTensorType.getElementType();
}
unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();