From 7355efa745d6308a185ed91d22afa0bc72ed7334 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 1 May 2021 14:34:33 -0400 Subject: [PATCH] [LANG] Preliminary FP8 support (#96) --- include/triton/codegen/selection/generator.h | 4 + include/triton/ir/context_impl.h | 8 +- include/triton/ir/type.h | 32 +++--- lib/codegen/selection/generator.cc | 102 +++++++++++++++++++ lib/ir/context.cc | 3 + lib/ir/type.cc | 7 +- python/src/triton.cc | 2 + python/triton/__init__.py | 2 +- python/triton/code_gen.py | 60 +++++++---- python/triton/language.py | 2 + 10 files changed, 182 insertions(+), 40 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 60c3933ab..7c3b51ae5 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -102,6 +102,10 @@ public: void visit_getelementptr_inst(ir::getelementptr_inst*); void visit_icmp_inst(ir::icmp_inst*); void visit_fcmp_inst(ir::fcmp_inst*); + std::tuple fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3); + std::tuple fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); + std::tuple fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3); + std::tuple fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); void visit_cast_inst(ir::cast_inst*); void visit_return_inst(ir::return_inst*); void visit_cond_branch_inst(ir::cond_branch_inst*); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 3db225a37..ba175a434 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -22,9 +22,11 @@ public: context_impl(context &ctx); public: - // primitive types - type void_ty, label_ty, half_ty, float_ty, double_ty; - // derived types + // non-numeric types + type void_ty, label_ty; + // floating point types + type fp8_ty, half_ty, float_ty, double_ty; + // integer types integer_type int1_ty, int8_ty, int16_ty, int32_ty, 
int64_ty, int128_ty; // Pointer types std::map, pointer_type*> ptr_tys; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 27e6acef0..804566362 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -29,21 +29,22 @@ public: enum id_t { // primitive types VoidTyID = 0, ///< 0: type with no size - HalfTyID, ///< 1: 16-bit floating point type - FloatTyID, ///< 2: 32-bit floating point type - DoubleTyID, ///< 3: 64-bit floating point type - X86_FP80TyID, ///< 4: 80-bit floating point type (X87) - FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa) - PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC) - LabelTyID, ///< 7: Labels - MetadataTyID, ///< 8: Metadata - TokenTyID, ///< 9: Token + FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa) + HalfTyID, ///< 2: 16-bit floating point type + FloatTyID, ///< 3: 32-bit floating point type + DoubleTyID, ///< 4: 64-bit floating point type + X86_FP80TyID, ///< 5: 80-bit floating point type (X87) + FP128TyID, ///< 6: 128-bit floating point type (112-bit mantissa) + PPC_FP128TyID, ///< 7: 128-bit floating point type (two 64-bits, PowerPC) + LabelTyID, ///< 8: Labels + MetadataTyID, ///< 9: Metadata + TokenTyID, ///< 10: Token // derived types - IntegerTyID, ///< 10: Arbitrary bit width integers - FunctionTyID, ///< 11: Functions - PointerTyID, ///< 12: Pointers - StructTyID, ///< 13: Struct - BlockTyID, ///< 14: Tile + IntegerTyID, ///< 11: Arbitrary bit width integers + FunctionTyID, ///< 12: Functions + PointerTyID, ///< 13: Pointers + StructTyID, ///< 14: Struct + BlockTyID, ///< 15: Block }; public: @@ -72,6 +73,7 @@ public: // primitive predicates bool is_void_ty() const { return id_ == VoidTyID; } + bool is_fp8_ty() const { return id_ == FP8TyID; } bool is_half_ty() const { return id_ == HalfTyID; } bool is_float_ty() const { return id_ == FloatTyID; } bool is_double_ty() const { return id_ == DoubleTyID; } @@ -96,6 +98,7 @@ public: static 
type *get_void_ty(context &ctx); static type *get_label_ty(context &ctx); // half + static type *get_fp8_ty(context &ctx); static type *get_half_ty(context &ctx); static type *get_float_ty(context &ctx); static type *get_double_ty(context &ctx); @@ -124,6 +127,7 @@ public: std::string repr() const { switch(id_) { case VoidTyID: return "void"; + case FP8TyID: return "fp8"; case HalfTyID: return "f16"; case FloatTyID: return "f32"; case DoubleTyID: return "f64"; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 01bf51a28..f19892cd3 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -27,6 +27,7 @@ using namespace llvm; #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() #define f32_ty builder_->getFloatTy() +#define i8_ty builder_->getInt8Ty() #define i32_ty builder_->getInt32Ty() #define vec_ty(type, num_el) VectorType::get(type, num_el, false) #define ptr_ty(...) PointerType::get(__VA_ARGS__) @@ -60,6 +61,7 @@ using namespace llvm; #define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) #define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) #define load(...) builder_->CreateLoad(__VA_ARGS__) +#define lshr(...) builder_->CreateLShr(__VA_ARGS__) #define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) #define min_num(...) builder_->CreateMinNum(__VA_ARGS__) #define mul(...) builder_->CreateMul(__VA_ARGS__) @@ -69,6 +71,7 @@ using namespace llvm; #define select(...) builder_->CreateSelect(__VA_ARGS__) #define store(...) builder_->CreateStore(__VA_ARGS__) #define sub(...) builder_->CreateSub(__VA_ARGS__) +#define shl(...) builder_->CreateShl(__VA_ARGS__) #define udiv(...) builder_->CreateUDiv(__VA_ARGS__) #define urem(...) builder_->CreateURem(__VA_ARGS__) #define splat(...) 
builder_->CreateVectorSplat(__VA_ARGS__) @@ -101,6 +104,7 @@ Type *generator::cvt(ir::type *ty) { // primitive types switch(ty->get_type_id()){ case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); + case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); case ir::type::HalfTyID: return Type::getHalfTy(*ctx_); case ir::type::FloatTyID: return Type::getFloatTy(*ctx_); case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_); @@ -316,10 +320,108 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) { } } + +std::tuple generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){ + InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false), + "{ \n\t" + ".reg .b32 b<4>; \n\t" + "shl.b32 b0, $1, 4; \n\t" // shift into upper byte + "shl.b32 b1, $2, 4; \n\t" + "shl.b32 b2, $3, 4; \n\t" + "shl.b32 b3, $4, 4; \n\t" + "lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign + "lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t" + "lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t" + "lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t" + "prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (62 unused here) + "prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (62 unused here) + "prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3 + "}", "=r, r, r, r, r", false); + Value *packed_ret = call(ptx, {in0, in1, in2, in3}); + Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4)); + return std::make_tuple(extract_elt(ret, (int)0), + extract_elt(ret, (int)1), + extract_elt(ret, (int)2), + extract_elt(ret, (int)3)); +} + +std::tuple generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){ + Value *ret0, *ret1, *ret2, *ret3; + std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3); + ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty); + ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty); + ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty); + ret3 = 
cast(llvm::Instruction::FPExt, ret3, f32_ty); + return std::make_tuple(ret0, ret1, ret2, ret3); +} + + +std::tuple generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){ + Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); + InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), + "{" + ".reg .b32 a<2>, b<2>; \n\t" + "prmt.b32 a0, 0, $2, 0x5140; \n\t" + "prmt.b32 a1, 0, $2, 0x7362; \n\t" + "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign + "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" + "shr.b32 b0, b0, 1; \n\t" // shift into fp16 position + "shr.b32 b1, b1, 1; \n\t" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign + "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" + "}", "=r,=r,r", false); + Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); + packed_in = insert_elt(packed_in, in0, (int)0); + packed_in = insert_elt(packed_in, in1, (int)1); + packed_in = insert_elt(packed_in, in2, (int)2); + packed_in = insert_elt(packed_in, in3, (int)3); + Value *in = bit_cast(packed_in, i32_ty); + Value *ret = call(ptx, {in}); + Value *packed_ret0 = extract_val(ret, {0}); + Value *packed_ret1 = extract_val(ret, {1}); + Value *ret0 = extract_elt(packed_ret0, (int)0); + Value *ret1 = extract_elt(packed_ret0, (int)1); + Value *ret2 = extract_elt(packed_ret1, (int)0); + Value *ret3 = extract_elt(packed_ret1, (int)1); + return std::make_tuple(ret0, ret1, ret2, ret3); +} + + /** * \brief Code Generation for `cast` */ void generator::visit_cast_inst(ir::cast_inst* x) { + // <> FP8 + ir::value *op = x->get_operand(0); + ir::type* ret_sca_ty = x->get_type()->get_scalar_ty(); + ir::type* op_sca_ty = op->get_type()->get_scalar_ty(); + if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){ + // ensure that conversions can be vectorized + int ld = layouts_->get(x)->get_order(0); + int contiguous = layouts_->get(x)->to_scanline()->nts(ld); + if(contiguous % 4 != 0) + throw 
std::runtime_error("unsupported fp32 -> fp8 conversion"); + auto x_idxs = idxs_.at(x); + auto op_idxs = idxs_.at(op); + // run the conversion + auto cvt = [&](Value* a, Value* b, Value* c, Value* d){ + if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty()) + return fp8x4_to_fp16x4(a, b, c, d); + throw std::runtime_error("unsupported conversion"); + }; + for(size_t i = 0; i < x_idxs.size(); i+=4){ + std::tie(vals_[x][x_idxs[i+0]], + vals_[x][x_idxs[i+1]], + vals_[x][x_idxs[i+2]], + vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]], + vals_[op][op_idxs[i+1]], + vals_[op][op_idxs[i+2]], + vals_[op][op_idxs[i+3]]); + } + return; + } + + Type *ty = cvt(x->get_type()->get_scalar_ty()); auto cvt = [](ir::cast_op_t op){ using ll = llvm::Instruction::CastOps; diff --git a/lib/ir/context.cc b/lib/ir/context.cc index e0a6976e0..7aa79dde4 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -12,9 +12,12 @@ namespace ir{ context_impl::context_impl(context &ctx) : void_ty(ctx, type::VoidTyID), label_ty(ctx, type::LabelTyID), + // floating point + fp8_ty(ctx, type::FP8TyID), half_ty(ctx, type::HalfTyID), float_ty(ctx, type::FloatTyID), double_ty(ctx, type::DoubleTyID), + // integers int1_ty(ctx, 1), int8_ty(ctx, 8), int16_ty(ctx, 16), diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 9d985dc25..9607e7db2 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -21,6 +21,7 @@ type *type::get_scalar_ty() const { unsigned type::get_primitive_size_in_bits() const { switch (id_) { + case FP8TyID: return 8; case HalfTyID: return 16; case FloatTyID: return 32; case DoubleTyID: return 64; @@ -42,6 +43,7 @@ unsigned type::get_tile_bitwidth() const unsigned type::get_fp_mantissa_width() const { id_t id = get_scalar_ty()->id_; assert(is_floating_point_ty() && "Not a floating point type!"); + if (id == FP8TyID) return 3; if (id == HalfTyID) return 10; if (id == FloatTyID) return 23; if (id == DoubleTyID) return 53; @@ -103,7 +105,7 @@ bool type::is_integer_ty(unsigned width) const bool 
type::is_floating_point_ty() const -{ return is_half_ty() || is_float_ty() || is_double_ty(); } +{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); } bool type::is_sized() const { // primitive types are sized @@ -120,7 +122,8 @@ bool type::is_sized() const { // primitive types type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; } type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; } -// half +// floating point +type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; } type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; } type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; } type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index c102f4a35..aa551a94b 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -193,6 +193,7 @@ void init_triton_ir(py::module &&m) { .def("make_function", &ir::function_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference) .def("get_void", &ir::type::get_void_ty, ret::reference) + .def("get_fp8", &ir::type::get_fp8_ty, ret::reference) .def("get_fp16", &ir::type::get_half_ty, ret::reference) .def("get_fp32", &ir::type::get_float_ty, ret::reference) .def("get_fp64", &ir::type::get_double_ty, ret::reference) @@ -203,6 +204,7 @@ void init_triton_ir(py::module &&m) { .def("get_int64", &ir::type::get_int64_ty, ret::reference) .def("is_void", &ir::type::is_void_ty) + .def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp16", &ir::type::is_half_ty) .def("is_fp32", &ir::type::is_float_ty) .def("is_fp64", &ir::type::is_double_ty) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 9c1df2839..7694b9ec9 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -2,7 +2,7 @@ # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, jit, autotune, 
heuristics, Config, Autotuner +from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner, reinterpret from . import language from . import code_gen diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 3a43acd16..897ec31e4 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -436,20 +436,23 @@ class CompilationError(Exception): class Kernel: - - type_names = { - int: 'I', - float: 'f', - bool: 'B', - torch.float16: 'f16', - torch.float32: 'f32', - torch.float64: 'f64', - torch.bool: 'i1', - torch.int8: 'i8', - torch.int16: 'i16', - torch.int32: 'i32', - torch.int64: 'i64', - } + @staticmethod + def _type_name(obj): + type_names = { + int: 'I', + float: 'f', + bool: 'B', + triton.language.float8: 'f8', + torch.float16: 'f16', + torch.float32: 'f32', + torch.float64: 'f64', + torch.bool: 'i1', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + } + return type_names[obj] @staticmethod def _to_triton_ir(context, obj): @@ -457,6 +460,7 @@ class Kernel: 'I': _triton.ir.type.get_int32, 'f': _triton.ir.type.get_fp32, 'B': _triton.ir.type.get_int1, + 'f8': _triton.ir.type.get_fp8, 'f16': _triton.ir.type.get_fp16, 'f32': _triton.ir.type.get_fp32, 'f64': _triton.ir.type.get_fp64, @@ -467,12 +471,12 @@ class Kernel: 'i64': _triton.ir.type.get_int64, } # convert torch.Tensor to Triton IR pointers - if isinstance(obj, torch.Tensor): - name = Kernel.type_names[obj.dtype] + if hasattr(obj, 'data_ptr'): + name = Kernel._type_name(obj.dtype) elt_ty = type_map[name](context) return _triton.ir.type.make_ptr(elt_ty, 1) # default path returns triton.ir.type directly - name = Kernel.type_names[obj.__class__] + name = Kernel._type_name(obj.__class__) return type_map[name](context) @staticmethod @@ -481,7 +485,7 @@ class Kernel: types_key = [None] * len(wargs) for i, arg in enumerate(wargs): prefix = 'P' if i in tensor_idxs else '' - suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else 
Kernel.type_names[arg.__class__] + suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__) types_key[i] = prefix + suffix return tuple(types_key) @@ -523,7 +527,7 @@ class Kernel: def __call__(self, *wargs, grid, num_warps=4, **meta): # device inference - tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)] + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] if len(tensor_idxs) == 0: raise ValueError("No Tensor argument found.") device = wargs[tensor_idxs[0]].device @@ -545,7 +549,7 @@ class Kernel: *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta ) # pack arguments - fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)]) + fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) params = struct.pack(fmt, *args) # enqueue cached function into stream binary = cache[key] @@ -703,3 +707,19 @@ def jit(fn): def cdiv(x, y): return (x + y - 1) // y + + +###### + + +class TensorWrapper: + def __init__(self, data_ptr, dtype): + self._data_ptr = data_ptr + self.dtype = dtype + + def data_ptr(self): + return self._data_ptr + + +def reinterpret(tensor, dtype): + return TensorWrapper(tensor.data_ptr(), dtype) \ No newline at end of file diff --git a/python/triton/language.py b/python/triton/language.py index ccfda885a..87fc30a59 100644 --- a/python/triton/language.py +++ b/python/triton/language.py @@ -84,6 +84,7 @@ int8 = dtype(ir.type.get_int8) int16 = dtype(ir.type.get_int16) int32 = dtype(ir.type.get_int32) int64 = dtype(ir.type.get_int64) +float8 = dtype(ir.type.get_fp8) float16 = dtype(ir.type.get_fp16) float32 = dtype(ir.type.get_fp32) float64 = dtype(ir.type.get_fp64) @@ -98,6 +99,7 @@ class block: if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 + if 
ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_fp32(): return float32 if ir_type.is_fp64(): return float64