[LANG] Preliminary FP8 support (#96)

This commit is contained in:
Philippe Tillet
2021-05-01 14:34:33 -04:00
committed by Philippe Tillet
parent 4290be1ae8
commit 7355efa745
10 changed files with 182 additions and 40 deletions

View File

@@ -102,6 +102,10 @@ public:
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);

View File

@@ -22,9 +22,11 @@ public:
context_impl(context &ctx);
public:
// primitive types
type void_ty, label_ty, half_ty, float_ty, double_ty;
// derived types
// non-numeric types
type void_ty, label_ty;
// floating point types
type fp8_ty, half_ty, float_ty, double_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;

View File

@@ -29,21 +29,22 @@ public:
enum id_t {
// primitive types
VoidTyID = 0, ///< 0: type with no size
HalfTyID, ///< 1: 16-bit floating point type
FloatTyID, ///< 2: 32-bit floating point type
DoubleTyID, ///< 3: 64-bit floating point type
X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 7: Labels
MetadataTyID, ///< 8: Metadata
TokenTyID, ///< 9: Token
FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa)
HalfTyID, ///< 3: 16-bit floating point type
FloatTyID, ///< 4: 32-bit floating point type
DoubleTyID, ///< 5: 64-bit floating point type
X86_FP80TyID, ///< 6: 80-bit floating point type (X87)
FP128TyID, ///< 7: 128-bit floating point type (112-bit mantissa)
PPC_FP128TyID, ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
LabelTyID, ///< 9: Labels
MetadataTyID, ///< 10: Metadata
TokenTyID, ///< 11: Token
// derived types
IntegerTyID, ///< 10: Arbitrary bit width integers
FunctionTyID, ///< 11: Functions
PointerTyID, ///< 12: Pointers
StructTyID, ///< 13: Struct
BlockTyID, ///< 14: Tile
IntegerTyID, ///< 12: Arbitrary bit width integers
FunctionTyID, ///< 13: Functions
PointerTyID, ///< 14: Pointers
StructTyID, ///< 15: Struct
BlockTyID, ///< 16: Block
};
public:
@@ -72,6 +73,7 @@ public:
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
@@ -96,6 +98,7 @@ public:
static type *get_void_ty(context &ctx);
static type *get_label_ty(context &ctx);
// half
static type *get_fp8_ty(context &ctx);
static type *get_half_ty(context &ctx);
static type *get_float_ty(context &ctx);
static type *get_double_ty(context &ctx);
@@ -124,6 +127,7 @@ public:
std::string repr() const {
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case HalfTyID: return "f16";
case FloatTyID: return "f32";
case DoubleTyID: return "f64";

View File

@@ -27,6 +27,7 @@ using namespace llvm;
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -60,6 +61,7 @@ using namespace llvm;
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define mul(...) builder_->CreateMul(__VA_ARGS__)
@@ -69,6 +71,7 @@ using namespace llvm;
#define select(...) builder_->CreateSelect(__VA_ARGS__)
#define store(...) builder_->CreateStore(__VA_ARGS__)
#define sub(...) builder_->CreateSub(__VA_ARGS__)
#define shl(...) builder_->CreateShl(__VA_ARGS__)
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
#define urem(...) builder_->CreateURem(__VA_ARGS__)
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
@@ -101,6 +104,7 @@ Type *generator::cvt(ir::type *ty) {
// primitive types
switch(ty->get_type_id()){
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
@@ -316,10 +320,108 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
}
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false),
"{ \n\t"
".reg .b32 b<4>; \n\t"
"shl.b32 b0, $1, 4; \n\t" // shift into into upper byte
"shl.b32 b1, $2, 4; \n\t"
"shl.b32 b2, $3, 4; \n\t"
"shl.b32 b3, $4, 4; \n\t"
"lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign
"lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t"
"lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t"
"lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t"
"prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (62 unused here)
"prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (62 unused here)
"prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3
"}", "=r, r, r, r, r", false);
Value *packed_ret = call(ptx, {in0, in1, in2, in3});
Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4));
return std::make_tuple(extract_elt(ret, (int)0),
extract_elt(ret, (int)1),
extract_elt(ret, (int)2),
extract_elt(ret, (int)3));
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
Value *ret0, *ret1, *ret2, *ret3;
std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3);
ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty);
ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty);
ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty);
ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
"{"
".reg .b32 a<2>, b<2>; \n\t"
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
"shr.b32 b1, b1, 1; \n\t"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
"}", "=r,=r,r", false);
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
packed_in = insert_elt(packed_in, in0, (int)0);
packed_in = insert_elt(packed_in, in1, (int)1);
packed_in = insert_elt(packed_in, in2, (int)2);
packed_in = insert_elt(packed_in, in3, (int)3);
Value *in = bit_cast(packed_in, i32_ty);
Value *ret = call(ptx, {in});
Value *packed_ret0 = extract_val(ret, {0});
Value *packed_ret1 = extract_val(ret, {1});
Value *ret0 = extract_elt(packed_ret0, (int)0);
Value *ret1 = extract_elt(packed_ret0, (int)1);
Value *ret2 = extract_elt(packed_ret1, (int)0);
Value *ret3 = extract_elt(packed_ret1, (int)1);
return std::make_tuple(ret0, ret1, ret2, ret3);
}
/**
* \brief Code Generation for `cast`
*/
void generator::visit_cast_inst(ir::cast_inst* x) {
// <> FP8
ir::value *op = x->get_operand(0);
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
// ensure that conversions can be vectorized
int ld = layouts_->get(x)->get_order(0);
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
if(contiguous % 4 != 0)
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
auto x_idxs = idxs_.at(x);
auto op_idxs = idxs_.at(op);
// run the conversion
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
return fp8x4_to_fp16x4(a, b, c, d);
throw std::runtime_error("unsupported conversion");
};
for(size_t i = 0; i < x_idxs.size(); i+=4){
std::tie(vals_[x][x_idxs[i+0]],
vals_[x][x_idxs[i+1]],
vals_[x][x_idxs[i+2]],
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]],
vals_[op][op_idxs[i+1]],
vals_[op][op_idxs[i+2]],
vals_[op][op_idxs[i+3]]);
}
return;
}
Type *ty = cvt(x->get_type()->get_scalar_ty());
auto cvt = [](ir::cast_op_t op){
using ll = llvm::Instruction::CastOps;

View File

@@ -12,9 +12,12 @@ namespace ir{
context_impl::context_impl(context &ctx)
: void_ty(ctx, type::VoidTyID),
label_ty(ctx, type::LabelTyID),
// floating point
fp8_ty(ctx, type::FP8TyID),
half_ty(ctx, type::HalfTyID),
float_ty(ctx, type::FloatTyID),
double_ty(ctx, type::DoubleTyID),
// integers
int1_ty(ctx, 1),
int8_ty(ctx, 8),
int16_ty(ctx, 16),

View File

@@ -21,6 +21,7 @@ type *type::get_scalar_ty() const {
unsigned type::get_primitive_size_in_bits() const {
switch (id_) {
case FP8TyID: return 8;
case HalfTyID: return 16;
case FloatTyID: return 32;
case DoubleTyID: return 64;
@@ -42,6 +43,7 @@ unsigned type::get_tile_bitwidth() const
unsigned type::get_fp_mantissa_width() const {
id_t id = get_scalar_ty()->id_;
assert(is_floating_point_ty() && "Not a floating point type!");
if (id == FP8TyID) return 3;
if (id == HalfTyID) return 10;
if (id == FloatTyID) return 23;
if (id == DoubleTyID) return 53;
@@ -103,7 +105,7 @@ bool type::is_integer_ty(unsigned width) const
bool type::is_floating_point_ty() const
{ return is_half_ty() || is_float_ty() || is_double_ty(); }
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
bool type::is_sized() const {
// primitive types are sized
@@ -120,7 +122,8 @@ bool type::is_sized() const {
// primitive types
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
// half
// floating point
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }

View File

@@ -193,6 +193,7 @@ void init_triton_ir(py::module &&m) {
.def("make_function", &ir::function_type::get, ret::reference)
.def("make_block", &ir::block_type::get, ret::reference)
.def("get_void", &ir::type::get_void_ty, ret::reference)
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
@@ -203,6 +204,7 @@ void init_triton_ir(py::module &&m) {
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("is_void", &ir::type::is_void_ty)
.def("is_fp8", &ir::type::is_fp8_ty)
.def("is_fp16", &ir::type::is_half_ty)
.def("is_fp32", &ir::type::is_float_ty)
.def("is_fp64", &ir::type::is_double_ty)

View File

@@ -2,7 +2,7 @@
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch
# submodules
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner, reinterpret
from . import language
from . import code_gen

View File

@@ -436,20 +436,23 @@ class CompilationError(Exception):
class Kernel:
type_names = {
int: 'I',
float: 'f',
bool: 'B',
torch.float16: 'f16',
torch.float32: 'f32',
torch.float64: 'f64',
torch.bool: 'i1',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
}
@staticmethod
def _type_name(obj):
type_names = {
int: 'I',
float: 'f',
bool: 'B',
triton.language.float8: 'f8',
torch.float16: 'f16',
torch.float32: 'f32',
torch.float64: 'f64',
torch.bool: 'i1',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
}
return type_names[obj]
@staticmethod
def _to_triton_ir(context, obj):
@@ -457,6 +460,7 @@ class Kernel:
'I': _triton.ir.type.get_int32,
'f': _triton.ir.type.get_fp32,
'B': _triton.ir.type.get_int1,
'f8': _triton.ir.type.get_fp8,
'f16': _triton.ir.type.get_fp16,
'f32': _triton.ir.type.get_fp32,
'f64': _triton.ir.type.get_fp64,
@@ -467,12 +471,12 @@ class Kernel:
'i64': _triton.ir.type.get_int64,
}
# convert torch.Tensor to Triton IR pointers
if isinstance(obj, torch.Tensor):
name = Kernel.type_names[obj.dtype]
if hasattr(obj, 'data_ptr'):
name = Kernel._type_name(obj.dtype)
elt_ty = type_map[name](context)
return _triton.ir.type.make_ptr(elt_ty, 1)
# default path returns triton.ir.type directly
name = Kernel.type_names[obj.__class__]
name = Kernel._type_name(obj.__class__)
return type_map[name](context)
@staticmethod
@@ -481,7 +485,7 @@ class Kernel:
types_key = [None] * len(wargs)
for i, arg in enumerate(wargs):
prefix = 'P' if i in tensor_idxs else ''
suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
types_key[i] = prefix + suffix
return tuple(types_key)
@@ -523,7 +527,7 @@ class Kernel:
def __call__(self, *wargs, grid, num_warps=4, **meta):
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
raise ValueError("No Tensor argument found.")
device = wargs[tensor_idxs[0]].device
@@ -545,7 +549,7 @@ class Kernel:
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
params = struct.pack(fmt, *args)
# enqueue cached function into stream
binary = cache[key]
@@ -703,3 +707,19 @@ def jit(fn):
def cdiv(x, y):
return (x + y - 1) // y
######
class TensorWrapper:
def __init__(self, data_ptr, dtype):
self._data_ptr = data_ptr
self.dtype = dtype
def data_ptr(self):
return self._data_ptr
def reinterpret(tensor, dtype):
return TensorWrapper(tensor.data_ptr(), dtype)

View File

@@ -84,6 +84,7 @@ int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
float8 = dtype(ir.type.get_fp8)
float16 = dtype(ir.type.get_fp16)
float32 = dtype(ir.type.get_fp32)
float64 = dtype(ir.type.get_fp64)
@@ -98,6 +99,7 @@ class block:
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_fp8(): return float8
if ir_type.is_fp16(): return float16
if ir_type.is_fp32(): return float32
if ir_type.is_fp64(): return float64