[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions

View File

@@ -124,6 +124,29 @@ void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
namespace mlir {
namespace triton {
//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
::mlir::TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
auto srcEltType = inputs.front();
auto dstEltType = outputs.front();
auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
if (srcTensorType && dstTensorType) {
srcEltType = srcTensorType.getElementType();
dstEltType = dstTensorType.getElementType();
}
// Check whether fp8 <=> fp16, bf16, f32, f64
// Make `srcEltType` always the fp8 side
if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
std::swap(srcEltType, dstEltType);
if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
return false;
return dstEltType.isF16() || dstEltType.isBF16() ||
dstEltType.isF32() || dstEltType.isF64();
}
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {