[FRONTEND] Made more tests pass (#805)
This commit is contained in:
@@ -441,11 +441,22 @@ void init_triton_ir(py::module &&m) {
|
||||
loc, self.getF32FloatAttr(v));
|
||||
})
|
||||
.def("get_null_value",
|
||||
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
||||
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (type.isa<mlir::FloatType>())
|
||||
return self.create<mlir::arith::ConstantOp>(
|
||||
loc, self.getF32FloatAttr(0.0));
|
||||
if (auto floatTy = type.dyn_cast<mlir::FloatType>())
|
||||
return self.create<mlir::arith::ConstantFloatOp>(
|
||||
loc, mlir::APFloat(floatTy.getFloatSemantics(), 0), floatTy);
|
||||
else if (auto intTy = type.dyn_cast<mlir::IntegerType>())
|
||||
return self.create<mlir::arith::ConstantIntOp>(loc, 0, intTy);
|
||||
else
|
||||
throw std::runtime_error("Not implemented");
|
||||
})
|
||||
.def("get_all_ones_value",
|
||||
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
uint64_t val = 0xFFFFFFFFFFFFFFFF;
|
||||
if (auto intTy = type.dyn_cast<mlir::IntegerType>())
|
||||
return self.create<mlir::arith::ConstantIntOp>(loc, val, intTy);
|
||||
else
|
||||
throw std::runtime_error("Not implemented");
|
||||
})
|
||||
@@ -602,7 +613,7 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::Value &src,
|
||||
mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
|
||||
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
|
||||
})
|
||||
// .def("create_cast", &ir::builder::create_cast)
|
||||
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
||||
@@ -1143,6 +1154,18 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
|
||||
operand, axis);
|
||||
})
|
||||
.def("create_ptr_to_int",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val,
|
||||
mlir::Type &type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::PtrToIntOp>(loc, type, val);
|
||||
})
|
||||
.def("create_int_to_ptr",
|
||||
[](mlir::OpBuilder &self, mlir::Value &val,
|
||||
mlir::Type &type) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::IntToPtrOp>(loc, type, val);
|
||||
})
|
||||
.def("create_select",
|
||||
[](mlir::OpBuilder &self, mlir::Value &condition,
|
||||
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
||||
@@ -1231,7 +1254,6 @@ void init_triton_ir(py::module &&m) {
|
||||
}
|
||||
|
||||
void init_triton_translation(py::module &m) {
|
||||
|
||||
using ret = py::return_value_policy;
|
||||
|
||||
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
|
||||
|
Reference in New Issue
Block a user