[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions

View File

@@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
})
.def("get_bf8_ty",
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::BFloat8Type>();
})
.def(
"get_half_ty",
[](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
@@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) {
})
// Cast instructions
// Conversions for custom FP types (FP8)
.def("create_fp_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
})
// Conversions for standard LLVM builtin types
.def("create_bitcast",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
@@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getI32Type());
})
.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {