uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types. Also: - Add tests for the `//`, `<<`, and `>>` operators. - Change `TensorWrapper` to unwrap objects when the resulting object would be simpler. - Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
committed by
GitHub
parent
d8db0308cb
commit
0ab9d67bad
@@ -40,6 +40,8 @@ public:
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
value *get_uint32(uint32_t val);
|
||||
value *get_uint64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
@@ -50,6 +52,10 @@ public:
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_uint8_ty();
|
||||
type *get_uint16_ty();
|
||||
type *get_uint32_ty();
|
||||
type *get_uint64_ty();
|
||||
type *get_half_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
|
@@ -28,6 +28,7 @@ public:
|
||||
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||
// Block types
|
||||
|
@@ -15,6 +15,8 @@ class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
enum class signedness { SIGNED, UNSIGNED };
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
@@ -58,6 +60,8 @@ public:
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
signedness get_integer_signedness() const;
|
||||
bool is_integer_signed() const;
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
@@ -80,8 +84,9 @@ public:
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
|
||||
get_integer_bitwidth() == bitwidth;}
|
||||
bool is_integer_ty(unsigned bitwidth, signedness sn) {
|
||||
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
|
||||
}
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
@@ -109,6 +114,10 @@ public:
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
static integer_type *get_uint8_ty(context &ctx);
|
||||
static integer_type *get_uint16_ty(context &ctx);
|
||||
static integer_type *get_uint32_ty(context &ctx);
|
||||
static integer_type *get_uint64_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
@@ -135,7 +144,7 @@ public:
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
|
||||
case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
@@ -158,18 +167,21 @@ class integer_type: public type {
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
|
||||
integer_type(context &ctx, unsigned bitwidth, signedness sn)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
signedness get_signedness() const { return signedness_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
signedness signedness_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
|
Reference in New Issue
Block a user