[LANG] Preliminary FP8 support (#96)
This commit is contained in:
committed by
Philippe Tillet
parent
4290be1ae8
commit
7355efa745
@@ -102,6 +102,10 @@ public:
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
void visit_icmp_inst(ir::icmp_inst*);
|
||||
void visit_fcmp_inst(ir::fcmp_inst*);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
void visit_cast_inst(ir::cast_inst*);
|
||||
void visit_return_inst(ir::return_inst*);
|
||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||
|
@@ -22,9 +22,11 @@ public:
|
||||
context_impl(context &ctx);
|
||||
|
||||
public:
|
||||
// primitive types
|
||||
type void_ty, label_ty, half_ty, float_ty, double_ty;
|
||||
// derived types
|
||||
// non-numeric types
|
||||
type void_ty, label_ty;
|
||||
// floating point types
|
||||
type fp8_ty, half_ty, float_ty, double_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||
|
@@ -29,21 +29,22 @@ public:
|
||||
enum id_t {
|
||||
// primitive types
|
||||
VoidTyID = 0, ///< 0: type with no size
|
||||
HalfTyID, ///< 1: 16-bit floating point type
|
||||
FloatTyID, ///< 2: 32-bit floating point type
|
||||
DoubleTyID, ///< 3: 64-bit floating point type
|
||||
X86_FP80TyID, ///< 4: 80-bit floating point type (X87)
|
||||
FP128TyID, ///< 5: 128-bit floating point type (112-bit mantissa)
|
||||
PPC_FP128TyID, ///< 6: 128-bit floating point type (two 64-bits, PowerPC)
|
||||
LabelTyID, ///< 7: Labels
|
||||
MetadataTyID, ///< 8: Metadata
|
||||
TokenTyID, ///< 9: Token
|
||||
FP8TyID, ///< 1: 8-bit floating point type (3 bits mantissa)
|
||||
HalfTyID, ///< 3: 16-bit floating point type
|
||||
FloatTyID, ///< 4: 32-bit floating point type
|
||||
DoubleTyID, ///< 5: 64-bit floating point type
|
||||
X86_FP80TyID, ///< 6: 80-bit floating point type (X87)
|
||||
FP128TyID, ///< 7: 128-bit floating point type (112-bit mantissa)
|
||||
PPC_FP128TyID, ///< 8: 128-bit floating point type (two 64-bits, PowerPC)
|
||||
LabelTyID, ///< 9: Labels
|
||||
MetadataTyID, ///< 10: Metadata
|
||||
TokenTyID, ///< 11: Token
|
||||
// derived types
|
||||
IntegerTyID, ///< 10: Arbitrary bit width integers
|
||||
FunctionTyID, ///< 11: Functions
|
||||
PointerTyID, ///< 12: Pointers
|
||||
StructTyID, ///< 13: Struct
|
||||
BlockTyID, ///< 14: Tile
|
||||
IntegerTyID, ///< 12: Arbitrary bit width integers
|
||||
FunctionTyID, ///< 13: Functions
|
||||
PointerTyID, ///< 14: Pointers
|
||||
StructTyID, ///< 15: Struct
|
||||
BlockTyID, ///< 16: Block
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -72,6 +73,7 @@ public:
|
||||
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||
bool is_half_ty() const { return id_ == HalfTyID; }
|
||||
bool is_float_ty() const { return id_ == FloatTyID; }
|
||||
bool is_double_ty() const { return id_ == DoubleTyID; }
|
||||
@@ -96,6 +98,7 @@ public:
|
||||
static type *get_void_ty(context &ctx);
|
||||
static type *get_label_ty(context &ctx);
|
||||
// half
|
||||
static type *get_fp8_ty(context &ctx);
|
||||
static type *get_half_ty(context &ctx);
|
||||
static type *get_float_ty(context &ctx);
|
||||
static type *get_double_ty(context &ctx);
|
||||
@@ -124,6 +127,7 @@ public:
|
||||
std::string repr() const {
|
||||
switch(id_) {
|
||||
case VoidTyID: return "void";
|
||||
case FP8TyID: return "fp8";
|
||||
case HalfTyID: return "f16";
|
||||
case FloatTyID: return "f32";
|
||||
case DoubleTyID: return "f64";
|
||||
|
@@ -27,6 +27,7 @@ using namespace llvm;
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||
@@ -60,6 +61,7 @@ using namespace llvm;
|
||||
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
||||
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
||||
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
#define mul(...) builder_->CreateMul(__VA_ARGS__)
|
||||
@@ -69,6 +71,7 @@ using namespace llvm;
|
||||
#define select(...) builder_->CreateSelect(__VA_ARGS__)
|
||||
#define store(...) builder_->CreateStore(__VA_ARGS__)
|
||||
#define sub(...) builder_->CreateSub(__VA_ARGS__)
|
||||
#define shl(...) builder_->CreateShl(__VA_ARGS__)
|
||||
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
|
||||
#define urem(...) builder_->CreateURem(__VA_ARGS__)
|
||||
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
|
||||
@@ -101,6 +104,7 @@ Type *generator::cvt(ir::type *ty) {
|
||||
// primitive types
|
||||
switch(ty->get_type_id()){
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
|
||||
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
|
||||
case ir::type::HalfTyID: return Type::getHalfTy(*ctx_);
|
||||
case ir::type::FloatTyID: return Type::getFloatTy(*ctx_);
|
||||
case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_);
|
||||
@@ -316,10 +320,108 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* x) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(i32_ty, {f32_ty, f32_ty, f32_ty, f32_ty}, false),
|
||||
"{ \n\t"
|
||||
".reg .b32 b<4>; \n\t"
|
||||
"shl.b32 b0, $1, 4; \n\t" // shift into into upper byte
|
||||
"shl.b32 b1, $2, 4; \n\t"
|
||||
"shl.b32 b2, $3, 4; \n\t"
|
||||
"shl.b32 b3, $4, 4; \n\t"
|
||||
"lop3.b32 b0, b0, 0x80000000, $1, 0xb8; \n\t" // restore sign
|
||||
"lop3.b32 b1, b1, 0x80000000, $2, 0xb8; \n\t"
|
||||
"lop3.b32 b2, b2, 0x80000000, $3, 0xb8; \n\t"
|
||||
"lop3.b32 b3, b3, 0x80000000, $4, 0xb8; \n\t"
|
||||
"prmt.b32 b0, b0, b1, 0x6273; \n\t" // pack lower half b0, b1 (62 unused here)
|
||||
"prmt.b32 b2, b2, b3, 0x6273; \n\t" // pack lower half b2, b3 (62 unused here)
|
||||
"prmt.b32 $0, b0, b2, 0x5410; \n\t" // pack full b0, b1, b2, b3
|
||||
"}", "=r, r, r, r, r", false);
|
||||
Value *packed_ret = call(ptx, {in0, in1, in2, in3});
|
||||
Value* ret = bit_cast(packed_ret, vec_ty(i8_ty, 4));
|
||||
return std::make_tuple(extract_elt(ret, (int)0),
|
||||
extract_elt(ret, (int)1),
|
||||
extract_elt(ret, (int)2),
|
||||
extract_elt(ret, (int)3));
|
||||
}
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
Value *ret0, *ret1, *ret2, *ret3;
|
||||
std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3);
|
||||
ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty);
|
||||
ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty);
|
||||
ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty);
|
||||
ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty);
|
||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||
}
|
||||
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){
|
||||
Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)});
|
||||
InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false),
|
||||
"{"
|
||||
".reg .b32 a<2>, b<2>; \n\t"
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n\t"
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n\t"
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // strip sign
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t"
|
||||
"shr.b32 b0, b0, 1; \n\t" // shift into fp16 poistion
|
||||
"shr.b32 b1, b1, 1; \n\t"
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // restore sign
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t"
|
||||
"}", "=r,=r,r", false);
|
||||
Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4));
|
||||
packed_in = insert_elt(packed_in, in0, (int)0);
|
||||
packed_in = insert_elt(packed_in, in1, (int)1);
|
||||
packed_in = insert_elt(packed_in, in2, (int)2);
|
||||
packed_in = insert_elt(packed_in, in3, (int)3);
|
||||
Value *in = bit_cast(packed_in, i32_ty);
|
||||
Value *ret = call(ptx, {in});
|
||||
Value *packed_ret0 = extract_val(ret, {0});
|
||||
Value *packed_ret1 = extract_val(ret, {1});
|
||||
Value *ret0 = extract_elt(packed_ret0, (int)0);
|
||||
Value *ret1 = extract_elt(packed_ret0, (int)1);
|
||||
Value *ret2 = extract_elt(packed_ret1, (int)0);
|
||||
Value *ret3 = extract_elt(packed_ret1, (int)1);
|
||||
return std::make_tuple(ret0, ret1, ret2, ret3);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Code Generation for `cast`
|
||||
*/
|
||||
void generator::visit_cast_inst(ir::cast_inst* x) {
|
||||
// <> FP8
|
||||
ir::value *op = x->get_operand(0);
|
||||
ir::type* ret_sca_ty = x->get_type()->get_scalar_ty();
|
||||
ir::type* op_sca_ty = op->get_type()->get_scalar_ty();
|
||||
if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){
|
||||
// ensure that conversions can be vectorized
|
||||
int ld = layouts_->get(x)->get_order(0);
|
||||
int contiguous = layouts_->get(x)->to_scanline()->nts(ld);
|
||||
if(contiguous % 4 != 0)
|
||||
throw std::runtime_error("unsupported fp32 -> fp8 conversion");
|
||||
auto x_idxs = idxs_.at(x);
|
||||
auto op_idxs = idxs_.at(op);
|
||||
// run the conversion
|
||||
auto cvt = [&](Value* a, Value* b, Value* c, Value* d){
|
||||
if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_half_ty())
|
||||
return fp8x4_to_fp16x4(a, b, c, d);
|
||||
throw std::runtime_error("unsupported conversion");
|
||||
};
|
||||
for(size_t i = 0; i < x_idxs.size(); i+=4){
|
||||
std::tie(vals_[x][x_idxs[i+0]],
|
||||
vals_[x][x_idxs[i+1]],
|
||||
vals_[x][x_idxs[i+2]],
|
||||
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]],
|
||||
vals_[op][op_idxs[i+1]],
|
||||
vals_[op][op_idxs[i+2]],
|
||||
vals_[op][op_idxs[i+3]]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Type *ty = cvt(x->get_type()->get_scalar_ty());
|
||||
auto cvt = [](ir::cast_op_t op){
|
||||
using ll = llvm::Instruction::CastOps;
|
||||
|
@@ -12,9 +12,12 @@ namespace ir{
|
||||
context_impl::context_impl(context &ctx)
|
||||
: void_ty(ctx, type::VoidTyID),
|
||||
label_ty(ctx, type::LabelTyID),
|
||||
// floating point
|
||||
fp8_ty(ctx, type::FP8TyID),
|
||||
half_ty(ctx, type::HalfTyID),
|
||||
float_ty(ctx, type::FloatTyID),
|
||||
double_ty(ctx, type::DoubleTyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
|
@@ -21,6 +21,7 @@ type *type::get_scalar_ty() const {
|
||||
|
||||
unsigned type::get_primitive_size_in_bits() const {
|
||||
switch (id_) {
|
||||
case FP8TyID: return 8;
|
||||
case HalfTyID: return 16;
|
||||
case FloatTyID: return 32;
|
||||
case DoubleTyID: return 64;
|
||||
@@ -42,6 +43,7 @@ unsigned type::get_tile_bitwidth() const
|
||||
unsigned type::get_fp_mantissa_width() const {
|
||||
id_t id = get_scalar_ty()->id_;
|
||||
assert(is_floating_point_ty() && "Not a floating point type!");
|
||||
if (id == FP8TyID) return 3;
|
||||
if (id == HalfTyID) return 10;
|
||||
if (id == FloatTyID) return 23;
|
||||
if (id == DoubleTyID) return 53;
|
||||
@@ -103,7 +105,7 @@ bool type::is_integer_ty(unsigned width) const
|
||||
|
||||
|
||||
bool type::is_floating_point_ty() const
|
||||
{ return is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
{ return is_fp8_ty() || is_half_ty() || is_float_ty() || is_double_ty(); }
|
||||
|
||||
bool type::is_sized() const {
|
||||
// primitive types are sized
|
||||
@@ -120,7 +122,8 @@ bool type::is_sized() const {
|
||||
// primitive types
|
||||
type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
|
||||
type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
|
||||
// half
|
||||
// floating point
|
||||
type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
|
||||
type *type::get_half_ty(context &ctx) { return &ctx.p_impl->half_ty; }
|
||||
type *type::get_float_ty(context &ctx) { return &ctx.p_impl->float_ty; }
|
||||
type *type::get_double_ty(context &ctx) { return &ctx.p_impl->double_ty; }
|
||||
|
@@ -193,6 +193,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("make_function", &ir::function_type::get, ret::reference)
|
||||
.def("make_block", &ir::block_type::get, ret::reference)
|
||||
.def("get_void", &ir::type::get_void_ty, ret::reference)
|
||||
.def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
|
||||
.def("get_fp16", &ir::type::get_half_ty, ret::reference)
|
||||
.def("get_fp32", &ir::type::get_float_ty, ret::reference)
|
||||
.def("get_fp64", &ir::type::get_double_ty, ret::reference)
|
||||
@@ -203,6 +204,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
|
||||
|
||||
.def("is_void", &ir::type::is_void_ty)
|
||||
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||
.def("is_fp16", &ir::type::is_half_ty)
|
||||
.def("is_fp32", &ir::type::is_float_ty)
|
||||
.def("is_fp64", &ir::type::is_double_ty)
|
||||
|
@@ -2,7 +2,7 @@
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
# submodules
|
||||
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner
|
||||
from .code_gen import cdiv, jit, autotune, heuristics, Config, Autotuner, reinterpret
|
||||
|
||||
from . import language
|
||||
from . import code_gen
|
||||
|
@@ -436,20 +436,23 @@ class CompilationError(Exception):
|
||||
|
||||
|
||||
class Kernel:
|
||||
|
||||
type_names = {
|
||||
int: 'I',
|
||||
float: 'f',
|
||||
bool: 'B',
|
||||
torch.float16: 'f16',
|
||||
torch.float32: 'f32',
|
||||
torch.float64: 'f64',
|
||||
torch.bool: 'i1',
|
||||
torch.int8: 'i8',
|
||||
torch.int16: 'i16',
|
||||
torch.int32: 'i32',
|
||||
torch.int64: 'i64',
|
||||
}
|
||||
@staticmethod
|
||||
def _type_name(obj):
|
||||
type_names = {
|
||||
int: 'I',
|
||||
float: 'f',
|
||||
bool: 'B',
|
||||
triton.language.float8: 'f8',
|
||||
torch.float16: 'f16',
|
||||
torch.float32: 'f32',
|
||||
torch.float64: 'f64',
|
||||
torch.bool: 'i1',
|
||||
torch.int8: 'i8',
|
||||
torch.int16: 'i16',
|
||||
torch.int32: 'i32',
|
||||
torch.int64: 'i64',
|
||||
}
|
||||
return type_names[obj]
|
||||
|
||||
@staticmethod
|
||||
def _to_triton_ir(context, obj):
|
||||
@@ -457,6 +460,7 @@ class Kernel:
|
||||
'I': _triton.ir.type.get_int32,
|
||||
'f': _triton.ir.type.get_fp32,
|
||||
'B': _triton.ir.type.get_int1,
|
||||
'f8': _triton.ir.type.get_fp8,
|
||||
'f16': _triton.ir.type.get_fp16,
|
||||
'f32': _triton.ir.type.get_fp32,
|
||||
'f64': _triton.ir.type.get_fp64,
|
||||
@@ -467,12 +471,12 @@ class Kernel:
|
||||
'i64': _triton.ir.type.get_int64,
|
||||
}
|
||||
# convert torch.Tensor to Triton IR pointers
|
||||
if isinstance(obj, torch.Tensor):
|
||||
name = Kernel.type_names[obj.dtype]
|
||||
if hasattr(obj, 'data_ptr'):
|
||||
name = Kernel._type_name(obj.dtype)
|
||||
elt_ty = type_map[name](context)
|
||||
return _triton.ir.type.make_ptr(elt_ty, 1)
|
||||
# default path returns triton.ir.type directly
|
||||
name = Kernel.type_names[obj.__class__]
|
||||
name = Kernel._type_name(obj.__class__)
|
||||
return type_map[name](context)
|
||||
|
||||
@staticmethod
|
||||
@@ -481,7 +485,7 @@ class Kernel:
|
||||
types_key = [None] * len(wargs)
|
||||
for i, arg in enumerate(wargs):
|
||||
prefix = 'P' if i in tensor_idxs else ''
|
||||
suffix = Kernel.type_names[arg.dtype] if i in tensor_idxs else Kernel.type_names[arg.__class__]
|
||||
suffix = Kernel._type_name(arg.dtype) if i in tensor_idxs else Kernel._type_name(arg.__class__)
|
||||
types_key[i] = prefix + suffix
|
||||
return tuple(types_key)
|
||||
|
||||
@@ -523,7 +527,7 @@ class Kernel:
|
||||
|
||||
def __call__(self, *wargs, grid, num_warps=4, **meta):
|
||||
# device inference
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if isinstance(arg, torch.Tensor)]
|
||||
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
|
||||
if len(tensor_idxs) == 0:
|
||||
raise ValueError("No Tensor argument found.")
|
||||
device = wargs[tensor_idxs[0]].device
|
||||
@@ -545,7 +549,7 @@ class Kernel:
|
||||
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
|
||||
)
|
||||
# pack arguments
|
||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel.type_names[arg.__class__] for i, arg in enumerate(wargs)])
|
||||
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
|
||||
params = struct.pack(fmt, *args)
|
||||
# enqueue cached function into stream
|
||||
binary = cache[key]
|
||||
@@ -703,3 +707,19 @@ def jit(fn):
|
||||
|
||||
def cdiv(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
######
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, data_ptr, dtype):
|
||||
self._data_ptr = data_ptr
|
||||
self.dtype = dtype
|
||||
|
||||
def data_ptr(self):
|
||||
return self._data_ptr
|
||||
|
||||
|
||||
def reinterpret(tensor, dtype):
|
||||
return TensorWrapper(tensor.data_ptr(), dtype)
|
@@ -84,6 +84,7 @@ int8 = dtype(ir.type.get_int8)
|
||||
int16 = dtype(ir.type.get_int16)
|
||||
int32 = dtype(ir.type.get_int32)
|
||||
int64 = dtype(ir.type.get_int64)
|
||||
float8 = dtype(ir.type.get_fp8)
|
||||
float16 = dtype(ir.type.get_fp16)
|
||||
float32 = dtype(ir.type.get_fp32)
|
||||
float64 = dtype(ir.type.get_fp64)
|
||||
@@ -98,6 +99,7 @@ class block:
|
||||
if ir_type.is_int16(): return int16
|
||||
if ir_type.is_int32(): return int32
|
||||
if ir_type.is_int64(): return int64
|
||||
if ir_type.is_fp8(): return float8
|
||||
if ir_type.is_fp16(): return float16
|
||||
if ir_type.is_fp32(): return float32
|
||||
if ir_type.is_fp64(): return float64
|
||||
|
Reference in New Issue
Block a user