[FRONTEND] Enhanced support for binary operators (#801)

Disabled modulo test (due to change in behavior for `frem` in nvptx
between llvm-11 and llvm-14) and bfloat16 (will require some work to
emulate in software similar to how it's done in `master`)
This commit is contained in:
Philippe Tillet
2022-10-24 19:47:01 -07:00
committed by GitHub
parent fcb228d1d4
commit a2cbe7af91
4 changed files with 1696 additions and 18 deletions

View File

@@ -642,8 +642,31 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
})
// .def("create_int_cast", &ir::builder::create_int_cast)
// .def("create_downcast", &ir::builder::create_downcast)
.def("create_int_cast",
[](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType,
bool isSigned) -> mlir::Value {
auto loc = self.getUnknownLoc();
// get element type if necessary
mlir::Type srcType = src.getType();
mlir::Type srcEltType = srcType;
mlir::Type dstEltType = dstType;
if (dstType.isa<mlir::RankedTensorType>()) {
dstEltType =
dstType.cast<mlir::RankedTensorType>().getElementType();
srcEltType =
srcType.cast<mlir::RankedTensorType>().getElementType();
}
unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();
if (srcWidth == dstWidth)
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
else if (srcWidth > dstWidth)
return self.create<mlir::arith::TruncIOp>(loc, dstType, src);
else if (isSigned)
return self.create<mlir::arith::ExtSIOp>(loc, dstType, src);
else
return self.create<mlir::arith::ExtUIOp>(loc, dstType, src);
})
.def("create_to_index",
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();

1552
python/tests/test_core.py Normal file

File diff suppressed because it is too large Load Diff