[Triton-MLIR] Support FP8 (#864)
Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getType<mlir::triton::Float8Type>();
|
||||
})
|
||||
.def("get_bf8_ty",
|
||||
[](mlir::OpBuilder &self) -> mlir::Type {
|
||||
return self.getType<mlir::triton::BFloat8Type>();
|
||||
})
|
||||
.def(
|
||||
"get_half_ty",
|
||||
[](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
|
||||
@@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
|
||||
// Cast instructions
|
||||
// Conversions for custom FP types (FP8)
|
||||
.def("create_fp_to_fp",
|
||||
[](mlir::OpBuilder &self, mlir::Value &src,
|
||||
mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
|
||||
})
|
||||
// Conversions for standard LLVM builtin types
|
||||
.def("create_bitcast",
|
||||
[](mlir::OpBuilder &self, mlir::Value &src,
|
||||
mlir::Type &dstType) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
|
||||
})
|
||||
// .def("create_cast", &ir::builder::create_cast)
|
||||
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
||||
.def("create_si_to_fp",
|
||||
[](mlir::OpBuilder &self, mlir::Value &src,
|
||||
mlir::Type &dstType) -> mlir::Value {
|
||||
@@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::arith::IndexCastOp>(loc, input,
|
||||
self.getI32Type());
|
||||
})
|
||||
|
||||
.def("create_fmul",
|
||||
[](mlir::OpBuilder &self, mlir::Value &lhs,
|
||||
mlir::Value &rhs) -> mlir::Value {
|
||||
|
Reference in New Issue
Block a user