[FRONTEND] Enhanced support for binary operators (#801)
Disabled modulo test (due to change in behavior for `frem` in nvptx between llvm-11 and llvm-14) and bfloat16 (will require some work to emulate in software similar to how it's done in `master`)
This commit is contained in:
@@ -642,8 +642,31 @@ void init_triton_ir(py::module &&m) {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::TruncFOp>(loc, dstType, src);
|
||||
})
|
||||
// .def("create_int_cast", &ir::builder::create_int_cast)
|
||||
// .def("create_downcast", &ir::builder::create_downcast)
|
||||
.def("create_int_cast",
|
||||
[](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType,
|
||||
bool isSigned) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
// get element type if necessary
|
||||
mlir::Type srcType = src.getType();
|
||||
mlir::Type srcEltType = srcType;
|
||||
mlir::Type dstEltType = dstType;
|
||||
if (dstType.isa<mlir::RankedTensorType>()) {
|
||||
dstEltType =
|
||||
dstType.cast<mlir::RankedTensorType>().getElementType();
|
||||
srcEltType =
|
||||
srcType.cast<mlir::RankedTensorType>().getElementType();
|
||||
}
|
||||
unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
|
||||
unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();
|
||||
if (srcWidth == dstWidth)
|
||||
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
|
||||
else if (srcWidth > dstWidth)
|
||||
return self.create<mlir::arith::TruncIOp>(loc, dstType, src);
|
||||
else if (isSigned)
|
||||
return self.create<mlir::arith::ExtSIOp>(loc, dstType, src);
|
||||
else
|
||||
return self.create<mlir::arith::ExtUIOp>(loc, dstType, src);
|
||||
})
|
||||
.def("create_to_index",
|
||||
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
1552
python/tests/test_core.py
Normal file
1552
python/tests/test_core.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user