[Triton-MLIR] Support FP8 (#864)
Co-authored-by: Superjomn <yanchunwei@outlook.com>
.gitignore
@@ -1,12 +1,20 @@
# Triton builds
build/

__pycache__
.pytest_cache

# Triton Python module builds
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so

# Python caches
__pycache__
.pytest_cache

# VS Code project files
.vscode
.vs

# JetBrains project files
.idea
cmake-build-*
@@ -22,8 +22,8 @@ struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building a ASM program, the objective of PTXBuilder is to give a
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// A helper for building an ASM program, the objective of PTXBuilder is to give
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
@@ -147,7 +147,7 @@ struct PTXBuilder {
Operand *newOperand(StringRef constraint);

// Create a constant integer operand.
Operand *newConstantOperand(int v);
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);

@@ -172,6 +172,22 @@ private:
return argArchive.back().get();
}

// Make the operands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive does not matter when onlyAttachMLIRArgs=false, but
// it is necessary when onlyAttachMLIRArgs is true, since the $0, $1, ... are
// determined by the PTX code snippet passed in from outside.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
auto ida = std::find(order.begin(), order.end(), a.get());
auto idb = std::find(order.begin(), order.end(), b.get());
assert(ida != order.end());
assert(idb != order.end());
return ida < idb;
});
}

friend struct PTXInstr;
friend struct PTXInstrCommon;

@@ -10,6 +10,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
@@ -72,17 +73,16 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
// TODO: Add verifier
}

def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";

let description = [{
Floating point casting for custom types (F8, BF8).
Floating point casting for custom types (F8).

F8 <-> BF8, FP16, FP32
BF8 <-> F8, FP16, FP32
F8 <-> FP16, BF16, FP32, FP64
}];

let arguments = (ins TT_FloatLike:$from);
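The op above is what fp8 casts in the Python frontend lower to. As a rough usage sketch (illustrative only, modeled on the roundtrip test further down; the kernel and tensor names are made up), a plain copy kernel run on a buffer reinterpreted as tl.float8 is enough to trigger the fp8 -> fp16 conversion:

import torch
import triton
import triton.language as tl

@triton.jit
def f8_to_f16_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)   # loaded as tl.float8 elements
    tl.store(dst_ptr + offsets, x, mask=mask)   # stored to fp16, which emits tt.fp_to_fp

raw = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(raw, tl.float8)         # view the int8 buffer as fp8
f16 = torch.empty_like(raw, dtype=torch.float16)
f8_to_f16_kernel[(1,)](f8, f16, raw.numel(), BLOCK_SIZE=1024)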
@@ -14,9 +14,8 @@ class TritonTypeDef<string name, string _mnemonic>

// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;

def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
@@ -45,7 +45,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
return argArchive.back().get();
}

PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
std::stringstream ss;
ss << "0x" << std::hex << v;
return newConstantOperand(ss.str());
@@ -130,8 +130,18 @@ std::string PTXBuilder::dump() const {

PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs) {
if (onlyAttachMLIRArgs) {
// Nearly impossible to make the $0, $1 in two PTX code snippets point to
// the same MLIR values in onlyAttachMLIRArgs mode.
assert(builder->executions.empty() &&
"builder can only hold a single execution when onlyAttachMLIRArgs "
"is true.");
builder->reorderArgArchive(oprs);
}

builder->executions.emplace_back(
std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));

return *builder->executions.back();
}

@@ -64,9 +64,8 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
rewriter.getF32FloatAttr(v));
}

// Create a index type constant.
// Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
@@ -127,6 +126,7 @@ void llPrintf(StringRef msg, ValueRange args,
#define i32_ty rewriter.getIntegerType(32)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
@@ -339,7 +339,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
}

Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
for (const auto& v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(),
rewriter.getI64ArrayAttr(v.index()));
}
@@ -699,7 +700,7 @@ public:
// [elemsPerThread X rank] index matrix.
// TODO: [goostavz] Double confirm the redundant indices calculations will
// be eliminated in the consequent MLIR/LLVM optimization. We might
// implement a indiceCache if necessary.
// implement an indexCache if necessary.
SmallVector<SmallVector<Value>>
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blockedLayout,
@@ -953,7 +954,8 @@ struct LoadOpConversion

// Determine the vectorization size
Type valueTy = op.getResult().getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
@@ -1147,7 +1149,8 @@ struct StoreOpConversion
MLIRContext *ctx = rewriter.getContext();

auto valueTy = value.getType();
Type valueElemTy = getElementTypeOrSelf(valueTy);
Type valueElemTy = typeConverter->convertType(
getElementTypeOrSelf(valueTy));

unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
@@ -1734,8 +1737,8 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly
// rewriter.replaceOp(op, adaptor.src());
// We cannot directly run
// `rewriter.replaceOp(op, adaptor.src())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
@@ -2096,6 +2099,330 @@ struct ExtractSliceOpConversion
}
};

struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;

static SmallVector<Value> convertFp8x4ToFp16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);

auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);

auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))
};
}

static SmallVector<Value> convertFp16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
auto &call = *builder.create(ptxAsm);

auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
}

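The two PTX snippets above are pure byte and bit shuffles. A scalar emulation of the same transforms may make the prmt/lop3 sequences easier to read; this is an illustrative numpy sketch written for this note (function names are made up, and, like the PTX it mirrors, the fp16 -> fp8 path rounds but does not clamp):

import numpy as np

def fp8_bits_to_fp16_bits(fp8):
    # Mirrors convertFp8x4ToFp16x4: put the fp8 byte into the high byte of a
    # 16-bit lane, shift the non-sign bits right by one, then re-attach the sign.
    x = np.asarray(fp8, dtype=np.uint16) << 8
    return ((x & 0x7FFF) >> 1) | (x & 0x8000)

def fp16_bits_to_fp8_bits(fp16):
    # Mirrors convertFp16x4ToFp8x4: shift the non-sign bits left by one,
    # round by adding 0x0080, re-attach the sign, and keep the high byte.
    x = np.asarray(fp16, dtype=np.uint32)
    t = (((x << 1) & 0x7FFF) + 0x0080) | (x & 0x8000)
    return ((t >> 8) & 0xFF).astype(np.uint8)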
static SmallVector<Value> convertFp8x4ToBf16x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);

auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);

auto bf16x2VecTy = vec_ty(bf16_ty, 2);
auto bf16x2x2StructTy =
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct,
rewriter.getI32ArrayAttr({1}));
return {
extract_element(bf16_ty, bf16x2Vec0, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec0, i32_val(1)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(0)),
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))
};
}

static SmallVector<Value> convertBf16x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);

PTXBuilder builder;
auto *ptxAsm =
"{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
"mov.u32 fp8_min, 0x38003800; \n"
"mov.u32 fp8_max, 0x3ff03ff0; \n"
"mov.u32 rn_, 0x80008; \n"
"mov.u32 zero, 0; \n"
"and.b32 sign0, $1, 0x80008000; \n"
"and.b32 sign1, $2, 0x80008000; \n"
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n"
"and.b32 nosign1, $2, 0x7fff7fff; \n"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n"
"add.u32 nosign1, nosign1, rn_; \n"
"sub.u32 nosign0, nosign0, 0x38003800; \n"
"sub.u32 nosign1, nosign1, 0x38003800; \n"
"shr.u32 nosign0, nosign0, 4; \n"
"shr.u32 nosign1, nosign1, 4; \n"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);

auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /* onlyAttachMLIRArgs */ true);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {
extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))
};
}

static SmallVector<Value> convertFp8x4ToFp32x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])
};
}

static SmallVector<Value> convertFp32x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}

static SmallVector<Value> convertFp8x4ToFp64x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
};
}

static SmallVector<Value> convertFp64x4ToFp8x4(
Location loc, ConversionPatternRewriter &rewriter,
const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}

LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
assert(srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>());
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto convertedDstEleType =
this->getTypeConverter()->convertType(dstEltType);

// Select convertor
std::function<SmallVector<Value>(Location, ConversionPatternRewriter&,
const Value&, const Value&,
const Value&, const Value&)> convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported type casting");
}

// Vectorized casting
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
SmallVector<Value> resultVals;
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter,
elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
assert(resultVals.size() == elems);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
rewriter.replaceOp(op, result);
return success();
}
};
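Note that the pattern always consumes the per-thread values four at a time; schematically, the loop in matchAndRewrite behaves like this small Python sketch (illustrative only, with made-up names):

def vectorized_cast(elements, convert4):
    # elements: the per-thread values unpacked from the LLVM struct;
    # convert4: one of the convertXx4ToYx4 helpers selected above.
    assert len(elements) % 4 == 0, "FP8 casting only supports 4-aligned sizes"
    results = []
    for i in range(0, len(elements), 4):
        results.extend(convert4(*elements[i:i + 4]))
    return results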

// A CRTP style of base class.
template <typename SourceOp, typename DestOp, typename ConcreteT>
class ElementwiseOpConversionBase
@@ -2309,7 +2636,7 @@ private:
Value smemBase) const;

// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank confict.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
@@ -3265,7 +3592,7 @@ struct DotOpMmaV1ConversionHelper {
int NK = shape[1];
unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);

// NOTE We cound't get the vec from the shared layout.
// NOTE: We couldn't get the vec from the shared layout.
// int vecA = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
@@ -3283,7 +3610,7 @@ struct DotOpMmaV1ConversionHelper {
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
// NOTE We cound't get the vec from the shared layout.
// NOTE: We couldn't get the vec from the shared layout.
// int vecB = sharedLayout.getVec();
// TODO[Superjomn]: Consider the case when vecA > 4
bool vecGt4 = false;
@@ -3387,7 +3714,7 @@ struct DotOpMmaV2ConversionHelper {
return Type{};
}

// The type of a matrix that loaded by either a ldmatrix or composed lds.
// The type of matrix that loaded by either a ldmatrix or composed lds.
Type getMatType() const {
Type fp32Ty = type::f32Ty(ctx);
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
@@ -3583,7 +3910,7 @@ private:
"mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};

// vector length per ldmatrix (16*8/elelment_size_in_bits)
// vector length per ldmatrix (16*8/element_size_in_bits)
inline static const std::map<TensorCoreType, uint8_t> mmaInstrVec = {
{TensorCoreType::FP32_FP16_FP16_FP32, 8},
{TensorCoreType::FP32_BF16_BF16_FP32, 8},
@@ -3723,7 +4050,7 @@ struct MMA16816ConversionHelper {
// load from smem
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
@@ -3754,7 +4081,7 @@ struct MMA16816ConversionHelper {

auto loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);

for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
@@ -4713,14 +5040,15 @@ public:
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
}

Type convertTritonPointerType(triton::PointerType type) {
return LLVM::LLVMPointerType::get(type.getPointeeType(),
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}

@@ -4753,7 +5081,7 @@ public:
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
size_t fcSize = 4 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fcSize, type.getElementType()));
ctx, SmallVector<Type>(fcSize, convertType(type.getElementType())));
}

if (mmaLayout.getVersion() == 1) {
@@ -4762,7 +5090,7 @@ public:
int repN = helper.getRepN(shape[1]);
int elems = 8 * repM * repN;
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, type.getElementType()));
ctx, SmallVector<Type>(elems, convertType(type.getElementType())));
}

llvm::errs()
@@ -4773,7 +5101,7 @@ public:
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = type.getElementType();
Type elemTy = convertType(type.getElementType());
auto vecSize = 1;
if (elemTy.getIntOrFloatBitWidth() == 16) {
vecSize = 2;
@@ -5324,6 +5652,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP

patterns.add<FpToFpOpConversion>(typeConverter, benefit);

patterns.add<FDivOpConversion>(typeConverter, benefit);

patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
@@ -5369,12 +5699,12 @@ public:
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

// step 1: Allocate shared memories and insert barriers
// setp 2: Convert SCF to CFG
// step 2: Convert SCF to CFG
// step 3: Convert FuncOp to LLVMFuncOp via partial conversion
// step 4: Convert the rest of ops via partial conversion
// The reason for putting step 1 before step 2 is that the membar analysis
// currently only supports SCF but not CFG.
// The reason for a seperation between 1/4 is that, step 3 is out of
// The reason for a separation between 1/4 is that, step 3 is out of
// the scope of Dialect Conversion, thus we need to make sure the smem
// is not revised during the conversion of step 4.
Allocation allocation(mod);
@@ -371,6 +371,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::BitcastOp>,
TritonGenericPattern<triton::FpToFpOp>,
TritonGenericPattern<triton::IntToPtrOp>,
TritonGenericPattern<triton::PtrToIntOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
@@ -124,6 +124,29 @@ void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
namespace mlir {
namespace triton {

//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
::mlir::TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
auto srcEltType = inputs.front();
auto dstEltType = outputs.front();
auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
if (srcTensorType && dstTensorType) {
srcEltType = srcTensorType.getElementType();
dstEltType = dstTensorType.getElementType();
}
// Check whether fp8 <=> fp16, bf16, f32, f64
// Make `srcEltType` always the fp8 side
if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
std::swap(srcEltType, dstEltType);
if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
return false;
return dstEltType.isF16() || dstEltType.isBF16() ||
dstEltType.isF32() || dstEltType.isF64();
}
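So the cast is considered compatible only when exactly one side is fp8 and the other side is one of the standard float types. A Python restatement of the rule (illustrative, not part of the commit):

def fp_to_fp_cast_compatible(src, dst):
    # src/dst are element type names such as 'fp8', 'fp16', 'bf16', 'fp32', 'fp64'.
    if 'fp8' not in (src, dst):
        return False
    other = dst if src == 'fp8' else src
    return other in ('fp16', 'bf16', 'fp32', 'fp64')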
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value) {
@@ -44,7 +44,9 @@ namespace gpu {

// TODO: Inheritation of layout attributes
unsigned getElemsPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
if (type.isIntOrIndexOrFloat() ||
type.isa<triton::Float8Type>() ||
type.isa<triton::PointerType>())
return 1;
auto tensorType = type.cast<RankedTensorType>();
auto layout = tensorType.getEncoding();
@@ -32,7 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
PointerType ptrType = origType.getElementType().cast<PointerType>();
unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth();
auto pointeeType = ptrType.getPointeeType();
unsigned numBits =
pointeeType.isa<triton::Float8Type>() ?
8 : pointeeType.getIntOrFloatBitWidth();
unsigned maxMultiple = info.getDivisibility(order[0]);
unsigned maxContig = info.getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
@@ -140,7 +140,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

pm.addPass(createConvertTritonGPUToLLVMPass());
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
pm.addPass(mlir::createSymbolDCEPass());
@@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::Float8Type>();
})
.def("get_bf8_ty",
[](mlir::OpBuilder &self) -> mlir::Type {
return self.getType<mlir::triton::BFloat8Type>();
})
.def(
"get_half_ty",
[](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); })
@@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) {
})

// Cast instructions
// Conversions for custom FP types (FP8)
.def("create_fp_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::FpToFpOp>(loc, dstType, src);
})
// Conversions for standard LLVM builtin types
.def("create_bitcast",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::BitcastOp>(loc, dstType, src);
})
// .def("create_cast", &ir::builder::create_cast)
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
.def("create_si_to_fp",
[](mlir::OpBuilder &self, mlir::Value &src,
mlir::Type &dstType) -> mlir::Value {
@@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) {
return self.create<mlir::arith::IndexCastOp>(loc, input,
self.getI32Type());
})

.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
mlir::Value &rhs) -> mlir::Value {
@@ -780,88 +780,88 @@ def test_store_bool():
assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# def test_f8_xf16_roundtrip(dtype):
# """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
# check_type_supported(dtype)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
check_type_supported(dtype)

# @triton.jit
# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements
# input = tl.load(input_ptr + offsets, mask=mask)
# output = input
# tl.store(output_ptr + offsets, output, mask=mask)
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)

# f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
# f8 = triton.reinterpret(f8_tensor, tl.float8)
# n_elements = f8_tensor.numel()
# xf16 = torch.empty_like(f8_tensor, dtype=dtype)
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
f8 = triton.reinterpret(f8_tensor, tl.float8)
n_elements = f8_tensor.numel()
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)

# f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
# f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
# copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)

# assert torch.all(f8_tensor == f8_output_tensor)
assert torch.all(f8_tensor == f8_output_tensor)


# def test_f16_to_f8_rounding():
# """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
# error is the minimum over all float8.
# Or the same explanation a bit mathier:
# for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
# @triton.jit
# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# mask = offsets < n_elements
# input = tl.load(input_ptr + offsets, mask=mask)
# output = input
# tl.store(output_ptr + offsets, output, mask=mask)
def test_f16_to_f8_rounding():
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
error is the minimum over all float8.
Or the same explanation a bit mathier:
for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|"""
@triton.jit
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
input = tl.load(input_ptr + offsets, mask=mask)
output = input
tl.store(output_ptr + offsets, output, mask=mask)

# # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
# f16_input_np = (
# np.array(
# range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
# )
# .view(np.float16)
# )
# f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
# n_elements = f16_input.numel()
# f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
# f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
# torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
f16_input_np = (
np.array(
range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16,
)
.view(np.float16)
)
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
n_elements = f16_input.numel()
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)

# f16_output = torch.empty_like(f16_input, dtype=torch.float16)
# copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)
f16_output = torch.empty_like(f16_input, dtype=torch.float16)
copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024)

# abs_error = torch.abs(f16_input - f16_output)
abs_error = torch.abs(f16_input - f16_output)

# all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
# all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
# all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
# copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)

# all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
# torch.isfinite(all_f8_vals_in_f16)
# ]
all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[
torch.isfinite(all_f8_vals_in_f16)
]

# min_error = torch.min(
# torch.abs(
# f16_input.reshape((-1, 1))
# - all_finite_f8_vals_in_f16.reshape((1, -1))
# ),
# dim=1,
# )[0]
# # 1.9375 is float8 max
# mismatch = torch.logical_and(
# abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
# )
# assert torch.all(
# torch.logical_not(mismatch)
# ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
min_error = torch.min(
torch.abs(
f16_input.reshape((-1, 1))
- all_finite_f8_vals_in_f16.reshape((1, -1))
),
dim=1,
)[0]
# 1.9375 is float8 max
mismatch = torch.logical_and(
abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375)
)
assert torch.all(
torch.logical_not(mismatch)
), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"


# # ---------------
@@ -48,6 +48,8 @@ class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
CUSTOMIZED_FP_TYPES = ['fp8']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']

class SIGNEDNESS(Enum):
@@ -129,6 +131,12 @@ class dtype:
def is_floating(self):
return self.name in dtype.FP_TYPES

def is_customized_floating(self):
return self.name in dtype.CUSTOMIZED_FP_TYPES

def is_standard_floating(self):
return self.name in dtype.STANDARD_FP_TYPES

def is_int_signed(self):
return self.name in dtype.SINT_TYPES

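The two new helpers split FP_TYPES into the customized family (currently just fp8) and the standard family. A quick illustration of how they behave (assuming the tl.float8/tl.float16 dtype singletons):

import triton.language as tl

assert tl.float8.is_customized_floating()       # 'fp8' is in CUSTOMIZED_FP_TYPES
assert not tl.float8.is_standard_floating()
assert tl.float16.is_standard_floating()        # fp16/bf16/fp32/fp64 are standard
assert tl.float16.is_floating()                 # is_floating() still covers both families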
@@ -613,39 +613,45 @@ def cast(input: tl.tensor,
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:
return input

src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
# fp8 <=> bf16/fp16
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),

# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):

# Casting types of the same bit width: fp16 <=> bf16
if (src_sca_ty.is_fp16() and dst_sca_ty.is_bf16()) or \
(src_sca_ty.is_bf16() and dst_sca_ty.is_fp16()):
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)

# FP Truncation
# Standard floating types' casting: truncation
# fp64 => fp32, fp16, bf16
# fp32 => fp16, bf16
truncate_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
if truncate_fp:
return tl.tensor(builder.create_fp_trunc(input.handle,
dst_ty.to_ir(builder)),
dst_ty)

# FP Extension
# Standard floating types' casting: extension
# fp32 => fp64
# fp16 => fp32, fp64
# bf16 => fp32, fp64
ext_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
if ext_fp:
return tl.tensor(builder.create_fp_ext(input.handle,
dst_ty.to_ir(builder)),
dst_ty)

# Int cast
# Casting between integer types
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
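The effect on dispatch: any floating-point cast that touches fp8 now goes through create_fp_to_fp, equal-width fp16 <=> bf16 goes through an intermediate fp32, and the remaining standard casts pick truncation or extension by bit width instead of mantissa width. A compressed sketch of that decision (illustrative only; assumes both scalar types are floating-point):

def pick_float_cast(src, dst):
    # src/dst are tl.dtype scalars; returns the ir.builder call cast() would use.
    if src.is_customized_floating() or dst.is_customized_floating():
        return 'create_fp_to_fp'                 # fp8 <=> bf16/fp16/fp32/fp64
    if src.primitive_bitwidth == dst.primitive_bitwidth:
        return 'cast via fp32'                   # fp16 <=> bf16 uses two casts
    if src.primitive_bitwidth > dst.primitive_bitwidth:
        return 'create_fp_trunc'                 # e.g. fp32 -> fp16
    return 'create_fp_ext'                       # e.g. fp16 -> fp32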
@@ -658,8 +664,8 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder), sign_extend),
dst_ty)

# Float to Int
if src_sca_ty.is_floating() and dst_sca_ty.is_int():
# Casting standard floating types to integer types
if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
if dst_sca_ty.is_bool():
ty = input.dtype.to_ir(builder)
_0 = tl.tensor(builder.get_null_value(ty), input.dtype)
@@ -673,8 +679,8 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder)),
dst_ty)

# int => float
if src_sca_ty.is_int() and dst_sca_ty.is_floating():
# Casting integer types to standard floating types
if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
return tl.tensor(builder.create_ui_to_fp(input.handle,
dst_ty.to_ir(builder)),
@@ -684,7 +690,7 @@ def cast(input: tl.tensor,
dst_ty.to_ir(builder)),
dst_ty)

# ptr => int
# Casting pointer types to integer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64:
@@ -695,19 +701,14 @@ def cast(input: tl.tensor,
tl.tensor(builder.get_int64(0), tl.int64),
builder)

if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
# Casting integer types to pointer types
if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
# Ptr . Ptr

# Casting pointer types to pointer types
if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
# * . Bool
if dst_sca_ty.is_bool():
if src_sca_ty.is_ptr():
input = cast(input, tl.int64, builder)
other = builder.get_int64(0)
if src_ty.is_bool():
other = builder.create_splat(other, src_ty.get_block_shapes())
return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)

assert False, f'cannot cast {input} to {dst_ty}'

# ===----------------------------------------------------------------------===//
@@ -176,6 +176,9 @@ class JITFunction(KernelInterface):
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
triton.language.float8: 'fp8',
triton.language.float16: 'fp16',
triton.language.bfloat16: 'bf16',
triton.language.float32: 'fp32',
}[key]
return f'*{ty}'
if key is None:
@@ -6,8 +6,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
// CHECK: !tt.ptr<f32> -> i64
%1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
// CHECK: f32 -> f16
%2 = tt.fp_to_fp %scalar_f32 : f32 -> f16
// CHECK: f32 to f16
%2 = arith.truncf %scalar_f32 : f32 to f16

// 0D tensor -> 0D tensor
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
@@ -18,8 +18,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
// CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
%4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
// CHECK: tensor<f32> -> tensor<f16>
%5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>
// CHECK: tensor<f32> to tensor<f16>
%5 = arith.truncf %tensor_f32_0d : tensor<f32> to tensor<f16>

// 1D tensor -> 1D tensor
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
@@ -30,8 +30,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
%6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
// CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
// CHECK: tensor<16xf32> -> tensor<16xf16>
%8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
// CHECK: tensor<16xf32> to tensor<16xf16>
%8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16>
return
}
@@ -125,15 +125,21 @@ TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) {
PTXBuilder builder;
const char *ptxCode =
".param .b64 param0;\n" // prepare param0 (format string)
"st.param.b64 [param0], %0;\n";
"st.param.b64 [param0], %0;\n"
"st.param.b64 [param0], %1;\n"
"st.param.b64 [param0], %2;\n";

auto &ptxSnippet = *builder.create(ptxCode);
auto *opr = builder.newOperand(v[0], "r");
ptxSnippet({opr}, true);
auto *opr0 = builder.newOperand(v[0], "r");
auto *opr1 = builder.newOperand(v[1], "r");
auto *opr2 = builder.newOperand(v[2], "r");
ptxSnippet({opr1, opr2, opr0}, true);

EXPECT_EQ(builder.dump(), ptxCode);
ASSERT_EQ(builder.getAllMLIRArgs()[0], v[0]);
ASSERT_EQ(builder.getAllMLIRArgs().size(), 1);
ASSERT_EQ(builder.getAllMLIRArgs()[0], v[1]);
ASSERT_EQ(builder.getAllMLIRArgs()[1], v[2]);
ASSERT_EQ(builder.getAllMLIRArgs()[2], v[0]);
ASSERT_EQ(builder.getAllMLIRArgs().size(), 3);
}

} // namespace triton