From 57fd1864a78144446491edb3b384245566ba16a1 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Thu, 10 Nov 2022 15:53:06 +0800
Subject: [PATCH] [Triton-MLIR] Support FP8 (#864)

Co-authored-by: Superjomn
---
 .gitignore                                    |  14 +-
 .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h |  22 +-
 include/triton/Dialect/Triton/IR/TritonOps.td |  14 +-
 .../triton/Dialect/Triton/IR/TritonTypes.td   |   3 +-
 .../TritonGPUToLLVM/PtxAsmFormat.cpp          |  12 +-
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 378 ++++++++++++++++--
 .../TritonToTritonGPU/TritonToTritonGPU.cpp   |   1 +
 lib/Dialect/Triton/IR/Ops.cpp                 |  23 ++
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |   4 +-
 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp |   5 +-
 lib/Target/LLVMIR/LLVMIRTranslation.cpp       |   2 +-
 python/src/triton.cc                          |  15 +-
 python/tests/test_core.py                     | 140 +++----
 python/triton/language/core.py                |   8 +
 python/triton/language/semantic.py            |  59 +--
 python/triton/runtime/jit.py                  |   3 +
 test/Conversion/triton_ops.mlir               |  12 +-
 .../TritonGPUToLLVM/PtxAsmFormatTest.cpp      |  16 +-
 18 files changed, 571 insertions(+), 160 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6f66195eb..30a8cf52f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,12 +1,20 @@
+# Triton builds
 build/
-__pycache__
-.pytest_cache
-
+# Triton Python module builds
 python/build/
 python/triton.egg-info/
 python/triton/_C/libtriton.pyd
 python/triton/_C/libtriton.so
 
+# Python caches
+__pycache__
+.pytest_cache
+
+# VS Code project files
 .vscode
 .vs
+
+# JetBrains project files
+.idea
+cmake-build-*

diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h
index 2765f611d..e28797d66 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h
@@ -22,8 +22,8 @@ struct PTXInstrExecution;
 // PTXBuilder helps to manage a PTX asm program consisting of one or multiple
 // instructions.
 //
-// A helper for building a ASM program, the objective of PTXBuilder is to give a
-// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
+// A helper for building an ASM program, the objective of PTXBuilder is to give
+// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
 // Currently, several factors are introduced to reduce the need for mixing
 // string and C++ if-else code.
 //
@@ -147,7 +147,7 @@ struct PTXBuilder {
   Operand *newOperand(StringRef constraint);
 
   // Create a constant integer operand.
-  Operand *newConstantOperand(int v);
+  Operand *newConstantOperand(int64_t v);
 
   // Create a constant operand with explicit code specified.
   Operand *newConstantOperand(const std::string &v);
@@ -172,6 +172,22 @@ private:
     return argArchive.back().get();
   }
 
+  // Make the operands in argArchive follow the provided \param order.
+  void reorderArgArchive(ArrayRef<Operand *> order) {
+    assert(order.size() == argArchive.size());
+    // The order of argArchive does not matter when onlyAttachMLIRArgs=false,
+    // but it does matter when onlyAttachMLIRArgs is true, since then the
+    // $0,$1,... placeholders are determined by the PTX code snippet passed
+    // in from outside.
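+    // For example (hypothetical operands, for illustration only): if
+    // argArchive holds {i1, o, i0} and \param order is {o, i0, i1}, the
+    // sort below rearranges the archive to {o, i0, i1}, so that
+    // placeholder $k always binds to order[k].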
+    sort(argArchive.begin(), argArchive.end(),
+         [&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
+           auto ida = std::find(order.begin(), order.end(), a.get());
+           auto idb = std::find(order.begin(), order.end(), b.get());
+           assert(ida != order.end());
+           assert(idb != order.end());
+           return ida < idb;
+         });
+  }
+
   friend struct PTXInstr;
   friend struct PTXInstrCommon;

diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index f22e8b742..d0981ce8f 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -10,6 +10,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
 include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
 include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
+include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
 
 def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
 def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
@@ -72,17 +73,16 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
   // TODO: Add verifier
 }
 
-def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
-                                   SameOperandsAndResultEncoding,
-                                   NoSideEffect,
-                                   /*DeclareOpInterfaceMethods*/]> {
+def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
+                                     SameOperandsAndResultEncoding,
+                                     NoSideEffect,
+                                     DeclareOpInterfaceMethods<CastOpInterface>]> {
   let summary = "Floating point casting for custom types";
 
   let description = [{
-    Floating point casting for custom types (F8, BF8).
+    Floating point casting for custom types (F8).
 
-    F8 <-> BF8, FP16, FP32
-    BF8 <-> F8, FP16, FP32
+    F8 <-> FP16, BF16, FP32, FP64
   }];
 
   let arguments = (ins TT_FloatLike:$from);

diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td
index 81184c91c..66d2a7b9a 100644
--- a/include/triton/Dialect/Triton/IR/TritonTypes.td
+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -14,9 +14,8 @@ class TritonTypeDef
 
 // Floating-point Type
 def F8 : TritonTypeDef<"Float8", "f8">;
-def BF8 : TritonTypeDef<"BFloat8", "bf8">;
 
-def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
 def TT_FloatTensor : TensorOf<[TT_Float]>;
 def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp
index 55ac2b73a..93a1b0372 100644
--- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp
@@ -45,7 +45,7 @@ PTXBuilder::Operand *PTXBuilder::newConstantOperand(const std::string &v) {
   return argArchive.back().get();
 }
 
-PTXBuilder::Operand *PTXBuilder::newConstantOperand(int v) {
+PTXBuilder::Operand *PTXBuilder::newConstantOperand(int64_t v) {
   std::stringstream ss;
   ss << "0x" << std::hex << v;
   return newConstantOperand(ss.str());
@@ -130,8 +130,18 @@ std::string PTXBuilder::dump() const {
 
 PTXInstrExecution &PTXInstrCommon::call(ArrayRef<Operand *> oprs,
                                         bool onlyAttachMLIRArgs) {
+  if (onlyAttachMLIRArgs) {
+    // It is nearly impossible to make the $0,$1 placeholders in two PTX code
+    // snippets point to the same MLIR values in onlyAttachMLIRArgs mode.
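+    // E.g. (hypothetical snippets): "mov.u32 $0, 0;" and "add.u32 $0, $1, $2;"
+    // would each claim $0 for a different MLIR value, so this mode restricts
+    // the builder to the single execution asserted below.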
+    assert(builder->executions.empty() &&
+           "builder can only hold a single execution when onlyAttachMLIRArgs "
+           "is true.");
+    builder->reorderArgArchive(oprs);
+  }
+
   builder->executions.emplace_back(
       std::make_unique<PTXInstrExecution>(this, oprs, onlyAttachMLIRArgs));
+
   return *builder->executions.back();
 }

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 4df17be02..b34715c99 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -64,9 +64,8 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
                                          rewriter.getF32FloatAttr(v));
 }
 
-// Create a index type constant.
+// Create an index type constant.
 static Value createIndexConstant(OpBuilder &builder, Location loc,
-
                                  TypeConverter *converter, int64_t value) {
   Type ty = converter->convertType(builder.getIndexType());
   return builder.create<LLVM::ConstantOp>(loc, ty,
@@ -127,6 +126,7 @@ void llPrintf(StringRef msg, ValueRange args,
 #define i32_ty rewriter.getIntegerType(32)
 #define ui32_ty rewriter.getIntegerType(32, false)
 #define f16_ty rewriter.getF16Type()
+#define bf16_ty rewriter.getBF16Type()
 #define i8_ty rewriter.getIntegerType(8)
 #define f32_ty rewriter.getF32Type()
 #define f64_ty rewriter.getF64Type()
@@ -339,7 +339,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
   }
   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
-  for (auto v : llvm::enumerate(resultVals)) {
+  for (const auto& v : llvm::enumerate(resultVals)) {
+    assert(v.value() && "cannot insert null values");
     llvmStruct = insert_val(structType, llvmStruct, v.value(),
                             rewriter.getI64ArrayAttr(v.index()));
   }
@@ -699,7 +700,7 @@ public:
   // [elemsPerThread X rank] index matrix.
   // TODO: [goostavz] Double confirm the redundant indices calculations will
   //       be eliminated in the consequent MLIR/LLVM optimization. We might
-  //       implement a indiceCache if necessary.
+  //       implement an indexCache if necessary.
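  // Illustration (hypothetical layout, not from this patch): for a 32x32
  // tensor with sizePerThread=[1,1], threadsPerWarp=[8,4], warpsPerCTA=[1,1],
  // each thread owns 32 elements and receives a [32 x 2] index matrix;
  // thread 0 would get (0,0), (0,4), ..., (0,28), (8,0), ... across the
  // replicated tiles.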
SmallVector> emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter, const BlockedEncodingAttr &blockedLayout, @@ -953,7 +954,8 @@ struct LoadOpConversion // Determine the vectorization size Type valueTy = op.getResult().getType(); - Type valueElemTy = getElementTypeOrSelf(valueTy); + Type valueElemTy = typeConverter->convertType( + getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getElemsPerThread(ptr.getType()); if (llMask) @@ -1147,7 +1149,8 @@ struct StoreOpConversion MLIRContext *ctx = rewriter.getContext(); auto valueTy = value.getType(); - Type valueElemTy = getElementTypeOrSelf(valueTy); + Type valueElemTy = typeConverter->convertType( + getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getElemsPerThread(ptr.getType()); @@ -1734,8 +1737,8 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // We cannot directly - // rewriter.replaceOp(op, adaptor.src()); + // We cannot directly run + // `rewriter.replaceOp(op, adaptor.src())` // due to MLIR's restrictions Location loc = op->getLoc(); auto resultTy = op.getType().template cast(); @@ -2096,6 +2099,330 @@ struct ExtractSliceOpConversion } }; +struct FpToFpOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern; + + static SmallVector convertFp8x4ToFp16x4( + Location loc, ConversionPatternRewriter &rewriter, + const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + auto ctx = rewriter.getContext(); + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value fp8x4Vec = undef(fp8x4VecTy); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3)); + fp8x4Vec = bitcast(fp8x4Vec, i32_ty); + + PTXBuilder builder; + auto *ptxAsm = + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5040; \n" + "prmt.b32 a1, 0, $2, 0x7060; \n" + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" + "shr.b32 b0, b0, 1; \n" + "shr.b32 b1, b1, 1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" + "}"; + auto &call = *builder.create(ptxAsm); + + auto *o0 = builder.newOperand("=r"); + auto *o1 = builder.newOperand("=r"); + auto *i = builder.newOperand(fp8x4Vec, "r"); + call({o0, o1, i}, /* onlyAttachMLIRArgs */ true); + + auto fp16x2VecTy = vec_ty(f16_ty, 2); + auto fp16x2x2StructTy = + struct_ty(SmallVector{fp16x2VecTy, fp16x2VecTy}); + auto fp16x2x2Struct = + builder.launch(rewriter, loc, fp16x2x2StructTy, false); + auto fp16x2Vec0 = extract_val(fp16x2VecTy, fp16x2x2Struct, + rewriter.getI32ArrayAttr({0})); + auto fp16x2Vec1 = extract_val(fp16x2VecTy, fp16x2x2Struct, + rewriter.getI32ArrayAttr({1})); + return { + extract_element(f16_ty, fp16x2Vec0, i32_val(0)), + extract_element(f16_ty, fp16x2Vec0, i32_val(1)), + extract_element(f16_ty, fp16x2Vec1, i32_val(0)), + extract_element(f16_ty, fp16x2Vec1, i32_val(1)) + }; + } + + static SmallVector convertFp16x4ToFp8x4( + Location loc, ConversionPatternRewriter &rewriter, + const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + auto ctx = rewriter.getContext(); + auto 
fp16x2VecTy = vec_ty(f16_ty, 2); + Value fp16x2Vec0 = undef(fp16x2VecTy); + Value fp16x2Vec1 = undef(fp16x2VecTy); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0)); + fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0)); + fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1)); + fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty); + fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty); + + PTXBuilder builder; + auto *ptxAsm = + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "shl.b32 a0, $1, 1; \n" + "shl.b32 a1, $2, 1; \n" + "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" + "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" + "add.u32 a0, a0, 0x00800080; \n" + "add.u32 a1, a1, 0x00800080; \n" + "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" + "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" + "prmt.b32 $0, b0, b1, 0x7531; \n" + "}"; + auto &call = *builder.create(ptxAsm); + + auto *o = builder.newOperand("=r"); + auto *i0 = builder.newOperand(fp16x2Vec0, "r"); + auto *i1 = builder.newOperand(fp16x2Vec1, "r"); + call({o, i0, i1}, /* onlyAttachMLIRArgs */ true); + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false); + return { + extract_element(i8_ty, fp8x4Vec, i32_val(0)), + extract_element(i8_ty, fp8x4Vec, i32_val(1)), + extract_element(i8_ty, fp8x4Vec, i32_val(2)), + extract_element(i8_ty, fp8x4Vec, i32_val(3)) + }; + } + + static SmallVector convertFp8x4ToBf16x4( + Location loc, ConversionPatternRewriter &rewriter, + const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + auto ctx = rewriter.getContext(); + auto fp8x4VecTy = vec_ty(i8_ty, 4); + Value fp8x4Vec = undef(fp8x4VecTy); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2)); + fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3)); + fp8x4Vec = bitcast(fp8x4Vec, i32_ty); + + PTXBuilder builder; + auto *ptxAsm = + "{ \n" + ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5040; \n" + "prmt.b32 a1, 0, $2, 0x7060; \n" + "and.b32 sign0, a0, 0x80008000; \n" + "and.b32 sign1, a1, 0x80008000; \n" + "and.b32 nosign0, a0, 0x7fff7fff; \n" + "and.b32 nosign1, a1, 0x7fff7fff; \n" + "shr.b32 nosign0, nosign0, 4; \n" + "shr.b32 nosign1, nosign1, 4; \n" + "add.u32 nosign0, nosign0, 0x38003800; \n" + "add.u32 nosign1, nosign1, 0x38003800; \n" + "or.b32 $0, sign0, nosign0; \n" + "or.b32 $1, sign1, nosign1; \n" + "}"; + auto &call = *builder.create(ptxAsm); + + auto *o0 = builder.newOperand("=r"); + auto *o1 = builder.newOperand("=r"); + auto *i = builder.newOperand(fp8x4Vec, "r"); + call({o0, o1, i}, /* onlyAttachMLIRArgs */ true); + + auto bf16x2VecTy = vec_ty(bf16_ty, 2); + auto bf16x2x2StructTy = + struct_ty(SmallVector{bf16x2VecTy, bf16x2VecTy}); + auto bf16x2x2Struct = + builder.launch(rewriter, loc, bf16x2x2StructTy, false); + auto bf16x2Vec0 = extract_val(bf16x2VecTy, bf16x2x2Struct, + rewriter.getI32ArrayAttr({0})); + auto bf16x2Vec1 = extract_val(bf16x2VecTy, bf16x2x2Struct, + rewriter.getI32ArrayAttr({1})); + return { + extract_element(bf16_ty, bf16x2Vec0, i32_val(0)), + extract_element(bf16_ty, bf16x2Vec0, i32_val(1)), + extract_element(bf16_ty, bf16x2Vec1, i32_val(0)), + extract_element(bf16_ty, bf16x2Vec1, i32_val(1)) + }; + } + + static SmallVector convertBf16x4ToFp8x4( + Location loc, ConversionPatternRewriter 
&rewriter, + const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + auto ctx = rewriter.getContext(); + auto bf16x2VecTy = vec_ty(bf16_ty, 2); + Value bf16x2Vec0 = undef(bf16x2VecTy); + Value bf16x2Vec1 = undef(bf16x2VecTy); + bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0)); + bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1)); + bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0)); + bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1)); + bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty); + bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty); + + PTXBuilder builder; + auto *ptxAsm = + "{ \n" + ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n" + ".reg .u32 fp8_min, fp8_max, rn_, zero; \n" + "mov.u32 fp8_min, 0x38003800; \n" + "mov.u32 fp8_max, 0x3ff03ff0; \n" + "mov.u32 rn_, 0x80008; \n" + "mov.u32 zero, 0; \n" + "and.b32 sign0, $1, 0x80008000; \n" + "and.b32 sign1, $2, 0x80008000; \n" + "prmt.b32 sign, sign0, sign1, 0x7531; \n" + "and.b32 nosign0, $1, 0x7fff7fff; \n" + "and.b32 nosign1, $2, 0x7fff7fff; \n" + ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n" + "and.b32 nosign_0_0, nosign0, 0xffff0000; \n" + "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n" + "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n" + "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n" + "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n" + "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n" + "or.b32 nosign0, nosign_0_0, nosign_0_1; \n" + "and.b32 nosign_1_0, nosign1, 0xffff0000; \n" + "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n" + "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n" + "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n" + "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n" + "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n" + "or.b32 nosign1, nosign_1_0, nosign_1_1; \n" + "add.u32 nosign0, nosign0, rn_; \n" + "add.u32 nosign1, nosign1, rn_; \n" + "sub.u32 nosign0, nosign0, 0x38003800; \n" + "sub.u32 nosign1, nosign1, 0x38003800; \n" + "shr.u32 nosign0, nosign0, 4; \n" + "shr.u32 nosign1, nosign1, 4; \n" + "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n" + "or.b32 $0, nosign, sign; \n" + "}"; + auto &call = *builder.create(ptxAsm); + + auto *o = builder.newOperand("=r"); + auto *i0 = builder.newOperand(bf16x2Vec0, "r"); + auto *i1 = builder.newOperand(bf16x2Vec1, "r"); + call({o, i0, i1}, /* onlyAttachMLIRArgs */ true); + + auto fp8x4VecTy = vec_ty(i8_ty, 4); + auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false); + return { + extract_element(i8_ty, fp8x4Vec, i32_val(0)), + extract_element(i8_ty, fp8x4Vec, i32_val(1)), + extract_element(i8_ty, fp8x4Vec, i32_val(2)), + extract_element(i8_ty, fp8x4Vec, i32_val(3)) + }; + } + + static SmallVector convertFp8x4ToFp32x4( + Location loc, ConversionPatternRewriter &rewriter, + const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3); + return { + rewriter.create(loc, f32_ty, fp16Values[0]), + rewriter.create(loc, f32_ty, fp16Values[1]), + rewriter.create(loc, f32_ty, fp16Values[2]), + rewriter.create(loc, f32_ty, fp16Values[3]) + }; + } + + static SmallVector convertFp32x4ToFp8x4( + Location loc, ConversionPatternRewriter &rewriter, + const Value& v0, const Value& v1, const Value& v2, const Value& v3) { + auto c0 = rewriter.create(loc, f16_ty, v0); + auto c1 = rewriter.create(loc, f16_ty, v1); + auto c2 = rewriter.create(loc, f16_ty, v2); + auto c3 = rewriter.create(loc, f16_ty, v3); + return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, 
c3);
+  }
+
+  static SmallVector<Value> convertFp8x4ToFp64x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
+    return {
+      rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
+      rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
+      rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
+      rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])
+    };
+  }
+
+  static SmallVector<Value> convertFp64x4ToFp8x4(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const Value& v0, const Value& v1, const Value& v2, const Value& v3) {
+    auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
+    auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
+    auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
+    auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
+    return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
+  }
+
+  LogicalResult
+  matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTensorType = op.from().getType().cast<RankedTensorType>();
+    auto dstTensorType = op.result().getType().cast<RankedTensorType>();
+    auto srcEltType = srcTensorType.getElementType();
+    auto dstEltType = dstTensorType.getElementType();
+    assert(srcEltType.isa<triton::Float8Type>() ||
+           dstEltType.isa<triton::Float8Type>());
+    auto convertedDstTensorType =
+        this->getTypeConverter()->convertType(dstTensorType);
+    auto convertedDstEleType =
+        this->getTypeConverter()->convertType(dstEltType);
+
+    // Select the conversion function
+    std::function<SmallVector<Value>(Location, ConversionPatternRewriter&,
+                                     const Value&, const Value&,
+                                     const Value&, const Value&)> converter;
+    if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
+      converter = convertFp8x4ToFp16x4;
+    } else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
+      converter = convertFp16x4ToFp8x4;
+    } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
+      converter = convertFp8x4ToBf16x4;
+    } else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
+      converter = convertBf16x4ToFp8x4;
+    } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
+      converter = convertFp8x4ToFp32x4;
+    } else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
+      converter = convertFp32x4ToFp8x4;
+    } else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
+      converter = convertFp8x4ToFp64x4;
+    } else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
+      converter = convertFp64x4ToFp8x4;
+    } else {
+      assert(false && "unsupported type casting");
+    }
+
+    // Vectorized casting
+    auto loc = op->getLoc();
+    auto elems = getElemsPerThread(dstTensorType);
+    assert(elems % 4 == 0 &&
+           "FP8 casting only supports tensors with 4-aligned sizes");
+    auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
+    SmallVector<Value> resultVals;
+    for (size_t i = 0; i < elems; i += 4) {
+      auto converted = converter(loc, rewriter,
+                                 elements[i], elements[i + 1],
+                                 elements[i + 2], elements[i + 3]);
+      resultVals.append(converted);
+    }
+    assert(resultVals.size() == elems);
+    auto result = getStructFromElements(loc, resultVals, rewriter,
+                                        convertedDstTensorType);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 // A CRTP style of base class.
 template
 class ElementwiseOpConversionBase
@@ -2309,7 +2636,7 @@
 private:
                             Value smemBase) const;
 
   // blocked/mma -> blocked/mma.
-  // Data padding in shared memory to avoid bank confict.
+  // Data padding in shared memory to avoid bank conflict.
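  // E.g. (hypothetical numbers): with 4-byte elements, a 32-element row
  // spans all 32 banks, so an un-padded column access hits one bank 32
  // times; padding each row by one element skews successive rows across
  // the banks instead.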
LogicalResult lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, @@ -3265,7 +3592,7 @@ struct DotOpMmaV1ConversionHelper { int NK = shape[1]; unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]); - // NOTE We cound't get the vec from the shared layout. + // NOTE: We couldn't get the vec from the shared layout. // int vecA = sharedLayout.getVec(); // TODO[Superjomn]: Consider the case when vecA > 4 bool vecGt4 = false; @@ -3283,7 +3610,7 @@ struct DotOpMmaV1ConversionHelper { SmallVector fpw({2, 2, 1}); SmallVector rep({0, 2 * packSize1, 1}); // pad M with 0 SmallVector spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 - // NOTE We cound't get the vec from the shared layout. + // NOTE: We couldn't get the vec from the shared layout. // int vecB = sharedLayout.getVec(); // TODO[Superjomn]: Consider the case when vecA > 4 bool vecGt4 = false; @@ -3387,7 +3714,7 @@ struct DotOpMmaV2ConversionHelper { return Type{}; } - // The type of a matrix that loaded by either a ldmatrix or composed lds. + // The type of matrix that loaded by either a ldmatrix or composed lds. Type getMatType() const { Type fp32Ty = type::f32Ty(ctx); Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); @@ -3583,7 +3910,7 @@ private: "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, }; - // vector length per ldmatrix (16*8/elelment_size_in_bits) + // vector length per ldmatrix (16*8/element_size_in_bits) inline static const std::map mmaInstrVec = { {TensorCoreType::FP32_FP16_FP16_FP32, 8}, {TensorCoreType::FP32_BF16_BF16_FP32, 8}, @@ -3723,7 +4050,7 @@ struct MMA16816ConversionHelper { // load from smem loadFn = getLoadMatrixFn( tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, - 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, + 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/, {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); } else if (aTensorTy.getEncoding().isa()) { // load from registers, used in gemm fuse @@ -3754,7 +4081,7 @@ struct MMA16816ConversionHelper { auto loadFn = getLoadMatrixFn( tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, - 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, + 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { @@ -4713,14 +5040,15 @@ public: addConversion([&](RankedTensorType type) -> llvm::Optional { return convertTritonTensorType(type); }); - // internally store bfloat16 as int16 - addConversion([&](BFloat16Type type) -> llvm::Optional { - return IntegerType::get(type.getContext(), 16); + // Internally store float8 as int8 + addConversion([&](triton::Float8Type type) -> llvm::Optional { + return IntegerType::get(type.getContext(), 8); }); } Type convertTritonPointerType(triton::PointerType type) { - return LLVM::LLVMPointerType::get(type.getPointeeType(), + // Recursively translate pointee type + return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()), type.getAddressSpace()); } @@ -4753,7 +5081,7 @@ public: auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type); size_t fcSize = 4 * repM * repN; return LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(fcSize, type.getElementType())); + ctx, SmallVector(fcSize, convertType(type.getElementType()))); } if (mmaLayout.getVersion() == 1) { @@ -4762,7 +5090,7 @@ public: int repN = helper.getRepN(shape[1]); int elems = 8 * repM * repN; return LLVM::LLVMStructType::getLiteral( - ctx, 
SmallVector(elems, type.getElementType())); + ctx, SmallVector(elems, convertType(type.getElementType()))); } llvm::errs() @@ -4773,7 +5101,7 @@ public: layout.dyn_cast_or_null()) { auto mmaLayout = dot_op_layout.getParent().cast(); auto wpt = mmaLayout.getWarpsPerCTA(); - Type elemTy = type.getElementType(); + Type elemTy = convertType(type.getElementType()); auto vecSize = 1; if (elemTy.getIntOrFloatBitWidth() == 16) { vecSize = 2; @@ -5324,6 +5652,8 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) #undef POPULATE_UNARY_OP + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); @@ -5369,12 +5699,12 @@ public: int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); // step 1: Allocate shared memories and insert barriers - // setp 2: Convert SCF to CFG + // step 2: Convert SCF to CFG // step 3: Convert FuncOp to LLVMFuncOp via partial conversion // step 4: Convert the rest of ops via partial conversion // The reason for putting step 1 before step 2 is that the membar analysis // currently only supports SCF but not CFG. - // The reason for a seperation between 1/4 is that, step 3 is out of + // The reason for a separation between 1/4 is that, step 3 is out of // the scope of Dialect Conversion, thus we need to make sure the smem // is not revised during the conversion of step 4. Allocation allocation(mod); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 32448f46d..4f4d56847 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -371,6 +371,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, patterns.add< // TODO: view should have custom pattern that views the layout TritonGenericPattern, TritonGenericPattern, + TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, TritonGenericPattern, TritonBroadcastPattern, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 3982c6b58..5c85e6b86 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -124,6 +124,29 @@ void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) { namespace mlir { namespace triton { +//-- FpToFpOp -- +bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs, + ::mlir::TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + auto srcEltType = inputs.front(); + auto dstEltType = outputs.front(); + auto srcTensorType = srcEltType.dyn_cast(); + auto dstTensorType = dstEltType.dyn_cast(); + if (srcTensorType && dstTensorType) { + srcEltType = srcTensorType.getElementType(); + dstEltType = dstTensorType.getElementType(); + } + // Check whether fp8 <=> fp16, bf16, f32, f64 + // Make `srcEltType` always the fp8 side + if (dstEltType.dyn_cast()) + std::swap(srcEltType, dstEltType); + if (!srcEltType.dyn_cast()) + return false; + return dstEltType.isF16() || dstEltType.isBF16() || + dstEltType.isF32() || dstEltType.isF64(); +} + //-- StoreOp -- void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value) { diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 9d3a0bcd1..e91ed6d69 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -44,7 +44,9 @@ namespace gpu { // TODO: 
Inheritation of layout attributes unsigned getElemsPerThread(Type type) { - if (type.isIntOrIndexOrFloat() || type.isa()) + if (type.isIntOrIndexOrFloat() || + type.isa() || + type.isa()) return 1; auto tensorType = type.cast(); auto layout = tensorType.getEncoding(); diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 85c082c24..d1b62f75c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -32,7 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { // Thread tile size depends on memory alignment SmallVector sizePerThread(rank, 1); PointerType ptrType = origType.getElementType().cast(); - unsigned numBits = ptrType.getPointeeType().getIntOrFloatBitWidth(); + auto pointeeType = ptrType.getPointeeType(); + unsigned numBits = + pointeeType.isa() ? + 8 : pointeeType.getIntOrFloatBitWidth(); unsigned maxMultiple = info.getDivisibility(order[0]); unsigned maxContig = info.getContiguity(order[0]); unsigned alignment = std::min(maxMultiple, maxContig); diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 645aaff43..aa73b4f81 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -140,7 +140,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); pm.addPass(createConvertTritonGPUToLLVMPass()); - // Conanicalize to eliminate the remaining UnrealizedConversionCastOp + // Canonicalize to eliminate the remaining UnrealizedConversionCastOp pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability. pm.addPass(mlir::createSymbolDCEPass()); diff --git a/python/src/triton.cc b/python/src/triton.cc index b4b6b068d..c04cc9da6 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -493,10 +493,6 @@ void init_triton_ir(py::module &&m) { [](mlir::OpBuilder &self) -> mlir::Type { return self.getType(); }) - .def("get_bf8_ty", - [](mlir::OpBuilder &self) -> mlir::Type { - return self.getType(); - }) .def( "get_half_ty", [](mlir::OpBuilder &self) -> mlir::Type { return self.getF16Type(); }) @@ -616,14 +612,20 @@ void init_triton_ir(py::module &&m) { }) // Cast instructions + // Conversions for custom FP types (FP8) + .def("create_fp_to_fp", + [](mlir::OpBuilder &self, mlir::Value &src, + mlir::Type &dstType) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, dstType, src); + }) + // Conversions for standard LLVM builtin types .def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); return self.create(loc, dstType, src); }) - // .def("create_cast", &ir::builder::create_cast) - // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int) .def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { @@ -697,7 +699,6 @@ void init_triton_ir(py::module &&m) { return self.create(loc, input, self.getI32Type()); }) - .def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { diff --git a/python/tests/test_core.py b/python/tests/test_core.py index 77068fa44..7b994c3cb 100644 --- a/python/tests/test_core.py +++ b/python/tests/test_core.py @@ -780,88 +780,88 @@ def test_store_bool(): assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all() -# 
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# def test_f8_xf16_roundtrip(dtype): -# """Tests that converting an f8 to f16 and back to f8 doesn't change its value""" -# check_type_supported(dtype) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_f8_xf16_roundtrip(dtype): + """Tests that converting an f8 to f16 and back to f8 doesn't change its value""" + check_type_supported(dtype) -# @triton.jit -# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): -# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask = offsets < n_elements -# input = tl.load(input_ptr + offsets, mask=mask) -# output = input -# tl.store(output_ptr + offsets, output, mask=mask) + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) -# f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda') -# f8 = triton.reinterpret(f8_tensor, tl.float8) -# n_elements = f8_tensor.numel() -# xf16 = torch.empty_like(f8_tensor, dtype=dtype) -# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) -# copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024) + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda') + f8 = triton.reinterpret(f8_tensor, tl.float8) + n_elements = f8_tensor.numel() + xf16 = torch.empty_like(f8_tensor, dtype=dtype) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024) -# f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8) -# f8_output = triton.reinterpret(f8_output_tensor, tl.float8) -# copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024) + f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8) + f8_output = triton.reinterpret(f8_output_tensor, tl.float8) + copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024) -# assert torch.all(f8_tensor == f8_output_tensor) + assert torch.all(f8_tensor == f8_output_tensor) -# def test_f16_to_f8_rounding(): -# """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute -# error is the minimum over all float8. -# Or the same explanation a bit mathier: -# for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|""" -# @triton.jit -# def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): -# offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) -# mask = offsets < n_elements -# input = tl.load(input_ptr + offsets, mask=mask) -# output = input -# tl.store(output_ptr + offsets, output, mask=mask) +def test_f16_to_f8_rounding(): + """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute + error is the minimum over all float8. 
+ Or the same explanation a bit mathier: + for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|""" + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) -# # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view -# f16_input_np = ( -# np.array( -# range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16, -# ) -# .view(np.float16) -# ) -# f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda') -# n_elements = f16_input.numel() -# f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8) -# f8_output = triton.reinterpret(f8_output_tensor, tl.float8) -# grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) -# copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024) + # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view + f16_input_np = ( + np.array( + range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16, + ) + .view(np.float16) + ) + f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda') + n_elements = f16_input.numel() + f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8) + f8_output = triton.reinterpret(f8_output_tensor, tl.float8) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024) -# f16_output = torch.empty_like(f16_input, dtype=torch.float16) -# copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024) + f16_output = torch.empty_like(f16_input, dtype=torch.float16) + copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024) -# abs_error = torch.abs(f16_input - f16_output) + abs_error = torch.abs(f16_input - f16_output) -# all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda') -# all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8) -# all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16) -# copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024) + all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda') + all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8) + all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16) + copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024) -# all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[ -# torch.isfinite(all_f8_vals_in_f16) -# ] + all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[ + torch.isfinite(all_f8_vals_in_f16) + ] -# min_error = torch.min( -# torch.abs( -# f16_input.reshape((-1, 1)) -# - all_finite_f8_vals_in_f16.reshape((1, -1)) -# ), -# dim=1, -# )[0] -# # 1.9375 is float8 max -# mismatch = torch.logical_and( -# abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375) -# ) -# assert torch.all( -# torch.logical_not(mismatch) -# ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}" + min_error = torch.min( + torch.abs( + f16_input.reshape((-1, 1)) + - all_finite_f8_vals_in_f16.reshape((1, -1)) + ), + dim=1, + )[0] + # 1.9375 is float8 max + mismatch = 
torch.logical_and( + abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375) + ) + assert torch.all( + torch.logical_not(mismatch) + ), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}" # # --------------- diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 94da20f9d..de40ec37d 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -48,6 +48,8 @@ class dtype: SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64'] + CUSTOMIZED_FP_TYPES = ['fp8'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] OTHER_TYPES = ['void'] class SIGNEDNESS(Enum): @@ -129,6 +131,12 @@ class dtype: def is_floating(self): return self.name in dtype.FP_TYPES + def is_customized_floating(self): + return self.name in dtype.CUSTOMIZED_FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + def is_int_signed(self): return self.name in dtype.SINT_TYPES diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index b56e40c35..846d4ee03 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -613,39 +613,45 @@ def cast(input: tl.tensor, dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) if src_ty == dst_ty: return input + src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - # fp8 <=> bf16/fp16 - if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8(): - return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), + + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()): + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) - if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()): - return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), - dst_ty) - # bf16 <=> (not fp32) - if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ - (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): + + # Casting types of the same bit width: fp16 <=> bf16 + if (src_sca_ty.is_fp16() and dst_sca_ty.is_bf16()) or \ + (src_sca_ty.is_bf16() and dst_sca_ty.is_fp16()): return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) - # FP Truncation + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 truncate_fp = src_sca_ty.is_floating() and \ dst_sca_ty.is_floating() and \ - src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth if truncate_fp: return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) - # FP Extension + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 ext_fp = src_sca_ty.is_floating() and \ dst_sca_ty.is_floating() and \ - src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth if ext_fp: return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) - # Int cast + # Casting between integer types if 
src_sca_ty.is_int() and dst_sca_ty.is_int() and \ (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() @@ -658,8 +664,8 @@ def cast(input: tl.tensor, dst_ty.to_ir(builder), sign_extend), dst_ty) - # Float to Int - if src_sca_ty.is_floating() and dst_sca_ty.is_int(): + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): if dst_sca_ty.is_bool(): ty = input.dtype.to_ir(builder) _0 = tl.tensor(builder.get_null_value(ty), input.dtype) @@ -673,8 +679,8 @@ def cast(input: tl.tensor, dst_ty.to_ir(builder)), dst_ty) - # int => float - if src_sca_ty.is_int() and dst_sca_ty.is_floating(): + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), @@ -684,7 +690,7 @@ def cast(input: tl.tensor, dst_ty.to_ir(builder)), dst_ty) - # ptr => int + # Casting pointer types to integer types if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): bitwidth = dst_sca_ty.int_bitwidth if bitwidth == 64: @@ -695,19 +701,14 @@ def cast(input: tl.tensor, tl.tensor(builder.get_int64(0), tl.int64), builder) - if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) - # Ptr . Ptr + + # Casting pointer types to pointer types if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) - # * . 
Bool - if dst_sca_ty.is_bool(): - if src_sca_ty.is_ptr(): - input = cast(input, tl.int64, builder) - other = builder.get_int64(0) - if src_ty.is_bool(): - other = builder.create_splat(other, src_ty.get_block_shapes()) - return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty) + assert False, f'cannot cast {input} to {dst_ty}' # ===----------------------------------------------------------------------===// diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4539d3c08..e23b0279f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -176,6 +176,9 @@ class JITFunction(KernelInterface): triton.language.uint32: 'u32', triton.language.uint64: 'u64', triton.language.float8: 'fp8', + triton.language.float16: 'fp16', + triton.language.bfloat16: 'bf16', + triton.language.float32: 'fp32', }[key] return f'*{ty}' if key is None: diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index 55cb9049f..e9d484887 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -6,8 +6,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr // CHECK: !tt.ptr -> i64 %1 = tt.ptr_to_int %scalar_ptr : !tt.ptr -> i64 - // CHECK: f32 -> f16 - %2 = tt.fp_to_fp %scalar_f32 : f32 -> f16 + // CHECK: f32 to f16 + %2 = arith.truncf %scalar_f32 : f32 to f16 // 0D tensor -> 0D tensor %tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor> @@ -18,8 +18,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { %3 = tt.int_to_ptr %tensor_i64_0d : tensor -> tensor> // CHECK: tensor> -> tensor %4 = tt.ptr_to_int %tensor_ptr_0d : tensor> -> tensor - // CHECK: tensor -> tensor - %5 = tt.fp_to_fp %tensor_f32_0d : tensor -> tensor + // CHECK: tensor to tensor + %5 = arith.truncf %tensor_f32_0d : tensor to tensor // 1D tensor -> 1D tensor %tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr) -> tensor<16x!tt.ptr> @@ -30,8 +30,8 @@ func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { %6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr> // CHECK: tensor<16x!tt.ptr> -> tensor<16xi64> %7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr> -> tensor<16xi64> - // CHECK: tensor<16xf32> -> tensor<16xf16> - %8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16> + // CHECK: tensor<16xf32> to tensor<16xf16> + %8 = arith.truncf %tensor_f32_1d : tensor<16xf32> to tensor<16xf16> return } diff --git a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp index 1c3f3fb27..575c31ee4 100644 --- a/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/PtxAsmFormatTest.cpp @@ -125,15 +125,21 @@ TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) { PTXBuilder builder; const char *ptxCode = ".param .b64 param0;\n" // prepare param0 (format string) - "st.param.b64 [param0], %0;\n"; + "st.param.b64 [param0], %0;\n" + "st.param.b64 [param0], %1;\n" + "st.param.b64 [param0], %2;\n"; auto &ptxSnippet = *builder.create(ptxCode); - auto *opr = builder.newOperand(v[0], "r"); - ptxSnippet({opr}, true); + auto *opr0 = builder.newOperand(v[0], "r"); + auto *opr1 = builder.newOperand(v[1], "r"); + auto *opr2 = builder.newOperand(v[2], "r"); + ptxSnippet({opr1, opr2, opr0}, true); EXPECT_EQ(builder.dump(), ptxCode); - ASSERT_EQ(builder.getAllMLIRArgs()[0], v[0]); - ASSERT_EQ(builder.getAllMLIRArgs().size(), 1); + 
ASSERT_EQ(builder.getAllMLIRArgs()[0], v[1]);
+  ASSERT_EQ(builder.getAllMLIRArgs()[1], v[2]);
+  ASSERT_EQ(builder.getAllMLIRArgs()[2], v[0]);
+  ASSERT_EQ(builder.getAllMLIRArgs().size(), 3);
 }
 
 } // namespace triton
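For readers following the PTX snippets in TritonGPUToLLVM.cpp above, the core
fp8 <-> fp16 bit manipulation can be sketched in scalar form as below (a
minimal numpy sketch for illustration only, not part of the patch; it mirrors
the shift-and-round structure of the PTX but ignores the packed prmt/lop3
tricks and any saturation or special-value handling):

    import numpy as np

    def fp8_to_fp16_bits(x8: np.ndarray) -> np.ndarray:
        # The fp8 sign bit lands in fp16 bit 15; the remaining 7 payload
        # bits (exponent + mantissa) land in fp16 bits 13..7.
        x8 = x8.astype(np.uint16)
        sign = (x8 & 0x80) << 8
        payload = (x8 & 0x7F) << 7
        return sign | payload

    def fp16_to_fp8_bits(x16: np.ndarray) -> np.ndarray:
        # Reverse direction: drop the sign, shift the payload up by one,
        # add 0x80 to round to nearest at the discarded byte, then take
        # the high byte and restore the sign -- the same structure as the
        # shl/add/prmt sequence in convertFp16x4ToFp8x4.
        x16 = x16.astype(np.uint32)
        sign = (x16 >> 8) & 0x80
        payload = ((x16 & 0x7FFF) << 1) + 0x80
        return (sign | ((payload >> 8) & 0x7F)).astype(np.uint8)

    # Every fp8 bit pattern survives fp8 -> fp16 -> fp8, which is what
    # test_f8_xf16_roundtrip checks on the GPU.
    x = np.arange(256).astype(np.uint8)
    assert np.array_equal(fp16_to_fp8_bits(fp8_to_fp16_bits(x)), x)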