uint8, uint16, uint32, and uint64 in kernels (#413)

A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
Madeleine Thompson
2022-01-05 15:27:17 -08:00
committed by GitHub
parent d8db0308cb
commit 0ab9d67bad
12 changed files with 444 additions and 110 deletions

View File

@@ -40,6 +40,8 @@ public:
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_uint32(uint32_t val);
value *get_uint64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
@@ -50,6 +52,10 @@ public:
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_uint8_ty();
type *get_uint16_ty();
type *get_uint32_ty();
type *get_uint64_ty();
type *get_half_ty();
type *get_float_ty();
type *get_double_ty();

View File

@@ -28,6 +28,7 @@ public:
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
// Block types

View File

@@ -15,6 +15,8 @@ class value;
class integer_type;
class constant_int;
enum class signedness { SIGNED, UNSIGNED };
/* Type */
class type {
public:
@@ -58,6 +60,8 @@ public:
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const;
signedness get_integer_signedness() const;
bool is_integer_signed() const;
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
@@ -80,8 +84,9 @@ public:
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_integer_ty(unsigned bitwidth, signedness sn) {
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
@@ -109,6 +114,10 @@ public:
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
static integer_type *get_uint8_ty(context &ctx);
static integer_type *get_uint16_ty(context &ctx);
static integer_type *get_uint32_ty(context &ctx);
static integer_type *get_uint64_ty(context &ctx);
// repr
std::string tile_repr() const {
@@ -135,7 +144,7 @@ public:
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
@@ -158,18 +167,21 @@ class integer_type: public type {
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
integer_type(context &ctx, unsigned bitwidth, signedness sn)
: type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
signedness get_signedness() const { return signedness_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
signedness signedness_;
};
class composite_type: public type{