[Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
Chenggang Zhao
2022-11-10 15:53:06 +08:00
committed by GitHub
parent 4946167241
commit 57fd1864a7
18 changed files with 571 additions and 160 deletions

View File

@@ -22,8 +22,8 @@ struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building a ASM program, the objective of PTXBuilder is to give a
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// A helper for building an ASM program, the objective of PTXBuilder is to give
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
@@ -147,7 +147,7 @@ struct PTXBuilder {
Operand *newOperand(StringRef constraint);
// Create a constant integer operand.
Operand *newConstantOperand(int v);
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);
@@ -172,6 +172,22 @@ private:
return argArchive.back().get();
}
// Make the oprands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
// it do necessary when onlyAttachMLIRArgs is true for the $0,$1.. are
// determined by PTX code snippet passed from external.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
auto ida = std::find(order.begin(), order.end(), a.get());
auto idb = std::find(order.begin(), order.end(), b.get());
assert(ida != order.end());
assert(idb != order.end());
return ida < idb;
});
}
friend struct PTXInstr;
friend struct PTXInstrCommon;

View File

@@ -10,6 +10,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
@@ -72,17 +73,16 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
// TODO: Add verifier
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";
let description = [{
Floating point casting for custom types (F8, BF8).
Floating point casting for custom types (F8).
F8 <-> BF8, FP16, FP32
BF8 <-> F8, FP16, FP32
F8 <-> FP16, BF16, FP32, FP64
}];
let arguments = (ins TT_FloatLike:$from);

View File

@@ -14,9 +14,8 @@ class TritonTypeDef<string name, string _mnemonic>
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;