[FRONTEND] Made more tests pass (#805)

This commit is contained in:
Philippe Tillet
2022-10-26 17:47:33 -07:00
committed by GitHub
parent bb7008651a
commit 3e6cc6d66c
9 changed files with 303 additions and 166 deletions

View File

@@ -441,11 +441,22 @@ void init_triton_ir(py::module &&m) {
loc, self.getF32FloatAttr(v));
})
.def("get_null_value",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
auto loc = self.getUnknownLoc();
if (type.isa<mlir::FloatType>())
return self.create<mlir::arith::ConstantOp>(
loc, self.getF32FloatAttr(0.0));
if (auto floatTy = type.dyn_cast<mlir::FloatType>())
return self.create<mlir::arith::ConstantFloatOp>(
loc, mlir::APFloat(floatTy.getFloatSemantics(), 0), floatTy);
else if (auto intTy = type.dyn_cast<mlir::IntegerType>())
return self.create<mlir::arith::ConstantIntOp>(loc, 0, intTy);
else
throw std::runtime_error("Not implemented");
})
.def("get_all_ones_value",
[](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
auto loc = self.getUnknownLoc();
uint64_t val = 0xFFFFFFFFFFFFFFFF;
if (auto intTy = type.dyn_cast<mlir::IntegerType>())
return self.create<mlir::arith::ConstantIntOp>(loc, val, intTy);
else
throw std::runtime_error("Not implemented");
})
@@ -602,7 +613,7 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::BitcastOp>(loc, dstType, src);
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
@@ -1143,6 +1154,18 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::triton::ReduceOp>(loc, resType, redOp,
operand, axis);
})
.def("create_ptr_to_int",
[](mlir::OpBuilder &self, mlir::Value &val,
mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::PtrToIntOp>(loc, type, val);
})
.def("create_int_to_ptr",
[](mlir::OpBuilder &self, mlir::Value &val,
mlir::Type &type) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::IntToPtrOp>(loc, type, val);
})
.def("create_select",
[](mlir::OpBuilder &self, mlir::Value &condition,
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
@@ -1231,7 +1254,6 @@ void init_triton_ir(py::module &&m) {
}
void init_triton_translation(py::module &m) {
using ret = py::return_value_policy;
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {