diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 357fffc6a..3a4094123 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -40,6 +40,8 @@ public: value *get_int1(bool val); value *get_int32(int32_t val); value *get_int64(int64_t val); + value *get_uint32(uint32_t val); + value *get_uint64(uint64_t val); value *get_float16(float val); value *get_float32(float val); value *get_range(int32_t lo, int32_t hi); @@ -50,6 +52,10 @@ public: type *get_int16_ty(); type *get_int32_ty(); type *get_int64_ty(); + type *get_uint8_ty(); + type *get_uint16_ty(); + type *get_uint32_ty(); + type *get_uint64_ty(); type *get_half_ty(); type *get_float_ty(); type *get_double_ty(); diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h index 7d18a3b4c..e43b5ad57 100644 --- a/include/triton/ir/context_impl.h +++ b/include/triton/ir/context_impl.h @@ -28,6 +28,7 @@ public: type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; // integer types integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; + integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty; // Pointer types std::map, pointer_type*> ptr_tys; // Block types diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index c9c07c4f1..c27ce48cf 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -15,6 +15,8 @@ class value; class integer_type; class constant_int; +enum class signedness { SIGNED, UNSIGNED }; + /* Type */ class type { public: @@ -58,6 +60,8 @@ public: // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bitwidth() const; + signedness get_integer_signedness() const; + bool is_integer_signed() const; unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; @@ -80,8 +84,9 @@ public: bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; } - bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() && - get_integer_bitwidth() == bitwidth;} + bool is_integer_ty(unsigned bitwidth, signedness sn) { + return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn; + } bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } @@ -109,6 +114,10 @@ public: static integer_type *get_int32_ty(context &ctx); static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); + static integer_type *get_uint8_ty(context &ctx); + static integer_type *get_uint16_ty(context &ctx); + static integer_type *get_uint32_ty(context &ctx); + static integer_type *get_uint64_ty(context &ctx); // repr std::string tile_repr() const { @@ -135,7 +144,7 @@ public: case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; - case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth()); + case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth()); case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; @@ -158,18 +167,21 @@ class integer_type: public type { private: // constructors - integer_type(context &ctx, unsigned bitwidth) - : type(ctx, IntegerTyID), bitwidth_(bitwidth){ } + integer_type(context &ctx, unsigned bitwidth, signedness sn) + : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ } public: // accessors unsigned get_bitwidth() const { return bitwidth_; } + signedness get_signedness() const { return signedness_; } + // factory methods static integer_type* get(context &ctx, unsigned width); private: unsigned bitwidth_; + signedness signedness_; }; class composite_type: public type{ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index feac3c6b6..a8ba68d1c 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -51,9 +51,15 @@ value *builder::get_int1(bool val) value *builder::get_int32(int32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} +value *builder::get_uint32(uint32_t val) +{ return constant_int::get(type::get_uint32_ty(ctx_), val);} + value *builder::get_int64(int64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} +value *builder::get_uint64(uint64_t val) +{ return constant_int::get(type::get_uint64_ty(ctx_), val);} + value *builder::get_float16(float val) { return constant_fp::get(type::get_fp16_ty(ctx_), val); } @@ -84,6 +90,18 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } +type *builder::get_uint8_ty() +{ return type::get_uint8_ty(ctx_); } + +type *builder::get_uint16_ty() +{ return type::get_uint16_ty(ctx_); } + +type *builder::get_uint32_ty() +{ return type::get_uint32_ty(ctx_); } + +type *builder::get_uint64_ty() +{ return type::get_uint64_ty(ctx_); } + type *builder::get_half_ty() { return type::get_fp16_ty(ctx_); } diff --git a/lib/ir/context.cc b/lib/ir/context.cc index 9bd66ec9a..90b109b9b 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx) fp32_ty(ctx, type::FP32TyID), fp64_ty(ctx, type::FP64TyID), // integers - int1_ty(ctx, 1), - int8_ty(ctx, 8), - int16_ty(ctx, 16), - int32_ty(ctx, 32), - int64_ty(ctx, 64), - int128_ty(ctx, 128){ + int1_ty(ctx, 1, signedness::SIGNED), + int8_ty(ctx, 8, signedness::SIGNED), + int16_ty(ctx, 16, signedness::SIGNED), + int32_ty(ctx, 32, signedness::SIGNED), + int64_ty(ctx, 64, signedness::SIGNED), + int128_ty(ctx, 128, signedness::SIGNED), + uint8_ty(ctx, 8, signedness::UNSIGNED), + uint16_ty(ctx, 16, signedness::UNSIGNED), + uint32_ty(ctx, 32, signedness::UNSIGNED), + uint64_ty(ctx, 64, signedness::UNSIGNED){ } diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index 2b8ae0259..a46c34b4f 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -1,14 +1,12 @@ #include "triton/ir/dispatch.h" -#include -namespace triton{ -namespace ir{ +namespace triton { +namespace ir { -ir::value* throw_unreachable(std::string key) { +[[ noreturn ]] void throw_unreachable(std::string key) { throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. " "This is likely a bug on our side."); - return 0; } //===----------------------------------------------------------------------===// @@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) { ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ int a_rank = a_ty->get_integer_bitwidth(); int b_rank = b_ty->get_integer_bitwidth(); - return a_rank > b_rank ? a_ty : b_ty; + auto a_sn = a_ty->get_integer_signedness(); + auto b_sn = b_ty->get_integer_signedness(); + // Rules for signedness taken from "Usual arithmetic conversions" on + // https://en.cppreference.com/w/c/language/conversion. + if (a_sn == b_sn) { + return a_rank > b_rank ? a_ty : b_ty; + } else if (a_sn == signedness::UNSIGNED) { + return a_rank >= b_rank ? a_ty : b_ty; + } else if (b_sn == signedness::UNSIGNED) { + return b_rank >= a_rank ? b_ty : a_ty; + } else { + throw_unreachable("integer_promote"); + } } enum class DivOrMod { NO, YES }; @@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) throw_unreachable("computation_type"); // 4 ) both operands are integer and undergo // integer promotion + if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) { + throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); + } return integer_promote(a_ty, b_ty); } @@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde // int + int else if (input_scalar_ty->is_integer_ty()) return builder->create_add(input, other); - return throw_unreachable("add"); + throw_unreachable("add"); } ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { @@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builde // int + int else if (scalar_ty->is_integer_ty()) return builder->create_sub(input, other); - return throw_unreachable("sub"); + throw_unreachable("sub"); } ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { @@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde // int * int else if (scalar_ty->is_integer_ty()) return builder->create_mul(input, other); - return throw_unreachable("mul"); + throw_unreachable("mul"); } ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { @@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu } // unreachable else - return throw_unreachable("div"); + throw_unreachable("div"); return builder->create_fdiv(input, other); } @@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); input = dispatch::cast(input, ret_ty, builder); other = dispatch::cast(other, ret_ty, builder); - return builder->create_sdiv(input, other); + if (ret_ty->is_integer_signed()) { + return builder->create_sdiv(input, other); + } else { + return builder->create_udiv(input, other); + } } - return throw_unreachable("floordiv"); + throw_unreachable("floordiv"); } ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); ir::type *scalar_ty = input->get_type()->get_scalar_ty(); + ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); // float % int if (scalar_ty->is_floating_point_ty()) return builder->create_frem(input, other); // int % int - else if (scalar_ty->is_integer_ty()) - return builder->create_srem(input, other); - return throw_unreachable("mod"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) { + throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness."); + } + if (scalar_ty->is_integer_signed()) { + return builder->create_srem(input, other); + } else { + return builder->create_urem(input, other); + } + } + throw_unreachable("mod"); } @@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) throw_incompatible_types(input_sca_ty, other_sca_ty); - if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth()) - input = dispatch::cast(input, other_sca_ty, builder); - else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth()) - other = dispatch::cast(other, input_sca_ty, builder); + ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty); + if (ret_sca_ty != input_sca_ty) + input = dispatch::cast(input, ret_sca_ty, builder); + if (ret_sca_ty != other_sca_ty) + other = dispatch::cast(other, ret_sca_ty, builder); } ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { @@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOGT(input, other); // int > int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSGT(input, other); - return throw_unreachable("greater_than"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSGT(input, other); + } else { + return builder->create_icmpUGT(input, other); + } + } + throw_unreachable("greater_than"); } ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOGE(input, other); // int >= int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSGE(input, other); - return throw_unreachable("greater_equal"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSGE(input, other); + } else { + return builder->create_icmpUGE(input, other); + } + } + throw_unreachable("greater_equal"); } ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { @@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder * if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOLT(input, other); // int < int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSLT(input, other); - return throw_unreachable("less_than"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSLT(input, other); + } else { + return builder->create_icmpULT(input, other); + } + } + throw_unreachable("less_than"); } ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder if (scalar_ty->is_floating_point_ty()) return builder->create_fcmpOLE(input, other); // int < int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSLE(input, other); - return throw_unreachable("less_equal"); + else if (scalar_ty->is_integer_ty()) { + if (scalar_ty->is_integer_signed()) { + return builder->create_icmpSLE(input, other); + } else { + return builder->create_icmpULE(input, other); + } + } + throw_unreachable("less_equal"); } ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil // int == int else if (scalar_ty->is_integer_ty()) return builder->create_icmpEQ(input, other); - return throw_unreachable("equal"); + throw_unreachable("equal"); } ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { @@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder * // int == int else if (scalar_ty->is_integer_ty()) return builder->create_icmpNE(input, other); - return throw_unreachable("equal"); + throw_unreachable("equal"); } //===----------------------------------------------------------------------===// @@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build return builder->create_fp_ext(input, dst_ty); // Int cast if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && - src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth()) - return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty()); + (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() || + src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) { + bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty(); + return builder->create_int_cast(input, dst_ty, sign_extend); + } // Float -> Int if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ if(dst_sca_ty->is_bool_ty()) @@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build } // int -> Float if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ - if(src_sca_ty->is_bool_ty()) + if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed()) return builder->create_ui_to_fp(input, dst_ty); else return builder->create_si_to_fp(input, dst_ty); @@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build other = builder->create_splat(other, src_ty->get_block_shapes()); return builder->create_icmpNE(input, other); } - return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); + throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); } //===----------------------------------------------------------------------===// @@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, atom_red_typechecking(ptr, val, mask, builder); ir::type* sca_ty = val->get_type()->get_scalar_ty(); // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); + if(sca_ty->is_integer_ty()) { + if (sca_ty->is_integer_signed()) { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); + } else { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask); + } + } // for float // return atomic_smax(i_ptr, i_val) if val >= 0 // return atomic_umin(i_ptr, i_val) if val < 0 @@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ atom_red_typechecking(ptr, val, mask, builder); ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); + // direct call to atomic_min for integers + if(sca_ty->is_integer_ty()) { + if (sca_ty->is_integer_signed()) { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); + } else { + return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask); + } + } // for float // return atomic_smin(i_ptr, i_val) if val >= 0 // return atomic_umax(i_ptr, i_val) if val < 0 @@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder return builder->create_reduce(input, FLOAT_OP, axis); else if (scalar_ty->is_integer_ty()) return builder->create_reduce(input, INT_OP, axis); - return throw_unreachable(name); + throw_unreachable(name); } ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { diff --git a/lib/ir/type.cc b/lib/ir/type.cc index ab8acb24b..74066a65a 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const { unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } +signedness type::get_integer_signedness() const +{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); } + +bool type::is_integer_signed() const { + if (id_ != IntegerTyID) { + throw std::logic_error("type is " + repr() + ", not integer"); + } + return ((integer_type*)(this))->get_signedness() == signedness::SIGNED; +} + unsigned type::get_tile_bitwidth() const { return ((block_type*)(this))->get_bitwidth(); } @@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } +integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; } +integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; } +integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; } +integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index 783b0406a..4d7df76ff 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -109,6 +109,24 @@ std::string pow2_divisor(long N){ return "1"; } +// Returns something like "int16", whether dtype is a torch.dtype or +// triton.language.dtype. +std::string dtype_cache_key_part(const py::object& dtype) { + if (py::hasattr(dtype, "cache_key_part")) { + // Presumed to be a triton.language.dtype. + return std::string(py::str(py::getattr(dtype, "cache_key_part"))); + } else { + // Remove 'torch.' prefix from repr of torch.dtype. + py::object repr = py::repr(dtype); + size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr()); + const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()); + if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) { + throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len)); + } + return std::string(repr_ptr + 6, repr_len - 6); + } +} + // Launch void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, @@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f cache_key += "1"; continue; } - // long and int have different kernels - if(!overflow & (std::abs(value) <= 0xffffffff)){ + // int32, uint32, int64, and uint64 have different kernels + if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) { cache_key += "int32"; params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); std::memcpy(params_ptr, &value, 4); params_ptr += 4; - } - else{ + } else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) { + cache_key += "uint32"; + params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); + std::memcpy(params_ptr, &value, 4); + params_ptr += 4; + } else if (!overflow) { cache_key += "int64"; params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); - if(overflow){ - unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr); - std::memcpy(&value, &uvalue, 8); - } std::memcpy(params_ptr, &value, 8); params_ptr += 8; + } else { + if (PyErr_Occurred()) { + throw std::logic_error("An error occurred?"); + } + unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr); + if (PyErr_Occurred()) { + throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg))); + } + cache_key += "uint64"; + params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); + std::memcpy(params_ptr, &unsigned_value, 8); + params_ptr += 8; } if(!specialize) continue; @@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); std::memcpy(params_ptr, &value, 8); params_ptr += 8; - py::object dtype = arg.attr("dtype"); - py::object repr = py::repr(dtype); - assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6)); - const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.' - size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6; - cache_key += std::string(start, len); + cache_key += dtype_cache_key_part(arg.attr("dtype")); cache_key += "*"; cache_key += "[multipleof("; cache_key += pow2_divisor(value); @@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) { .def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference) + .def("get_uint8", &ir::type::get_uint8_ty, ret::reference) + .def("get_uint16", &ir::type::get_uint16_ty, ret::reference) + .def("get_uint32", &ir::type::get_uint32_ty, ret::reference) + .def("get_uint64", &ir::type::get_uint64_ty, ret::reference) .def("is_void", &ir::type::is_void_ty) .def("is_fp8", &ir::type::is_fp8_ty) @@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) { .def("is_bf16", &ir::type::is_bf16_ty) .def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) + .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) + .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) + .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) + .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) @@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) { .def("get_int1", &ir::builder::get_int1, ret::reference) .def("get_int32", &ir::builder::get_int32, ret::reference) .def("get_int64", &ir::builder::get_int64, ret::reference) + .def("get_uint32", &ir::builder::get_uint32, ret::reference) + .def("get_uint64", &ir::builder::get_uint64, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference) .def("get_range", &ir::builder::get_range, ret::reference); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index fe33c9c6a..41c9e9236 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,7 +1,7 @@ import copy import itertools import re -from typing import Optional +from typing import Optional, Union import numpy as np import pytest @@ -10,17 +10,20 @@ from numpy.random import RandomState import triton import triton.language as tl +from triton.code_gen import TensorWrapper, reinterpret int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] float_dtypes = ['float16', 'float32', 'float64'] -dtypes = int_dtypes + float_dtypes +dtypes = int_dtypes + uint_dtypes + float_dtypes + def _bitwidth(dtype: str) -> int: # ex.: "int64" -> 64 return int(re.search(r'(\d+)$', dtype).group(1)) -def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): """ Override `rs` if you're calling this function twice and don't want the same result for both calls. @@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): if rs is None: rs = RandomState(seed=17) dtype = getattr(np, dtype_str) - if dtype_str in int_dtypes: + if dtype_str in int_dtypes + uint_dtypes: iinfo = np.iinfo(getattr(np, dtype_str)) - x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + x = rs.randint(low, high, shape, dtype=dtype) x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. return x elif dtype_str in float_dtypes: @@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None): raise RuntimeError(f'Unknown dtype {dtype_str}') -def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor: - # For now, this always converts to a torch tensor, but when we add unsigned - # integers, it will also support TensorWrapper, since torch doesn't have - # unsigned support. - return torch.tensor(x, device=device) +def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]: + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') def to_numpy(x): - if isinstance(x, torch.Tensor): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): return x.cpu().numpy() else: raise ValueError(f"Not a triton-compatible tensor: {x}") @@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: Given two dtype strings, returns the numpy dtype Triton thinks binary operations on the two types should return. Returns None if the return value matches numpy. This is generally needed because Triton and pytorch return - narrower floating point types than numpy in mixed operations. + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. """ overrides = { ('float16', 'int16'): np.float16, ('float16', 'int32'): np.float16, ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, } key = (a, b) if a < b else (b, a) return overrides.get(key) -def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'): +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): SIZE = 128 # define the kernel / launch-grid @triton.jit @@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y= # inputs rs = RandomState(17) x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) - y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) if mode_x == 'nan': x[:] = float('nan') if mode_y == 'nan': @@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ('int64', 'float16'), ('int64', 'float32'), ('int64', 'float64'), + ('uint16', 'float16'), + ('uint16', 'float32'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), ] # --------------- @@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ]) def test_bin_op(dtype_x, dtype_y, op, device='cuda'): expr = f' x {op} y' - if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes: + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): @@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' else: numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): with pytest.raises(AssertionError, match='Not equal to tolerance'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + elif (op in ('%', '/') and + ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.code_gen.CompilationError) as exc_info: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) else: _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) +@pytest.mark.parametrize("dtype_x, dtype_y", + [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] +) +def test_floordiv(dtype_x, dtype_y, device='cuda'): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + + # --------------- # test bitwise ops # --------------- @@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'): ]) def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None if 'float' in dtype_x + dtype_y: with pytest.raises(triton.code_gen.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) # The CompilationError must have been caused by a C++ exception with this text. assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: - _test_binary(dtype_x, dtype_y, expr, device=device) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) + + +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ + (dtype_x, dtype_y, op) + for op in ['<<', '>>'] + for dtype_x in int_dtypes + uint_dtypes + for dtype_y in int_dtypes + uint_dtypes +]) +def test_shift_op(dtype_x, dtype_y, op, device='cuda'): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65) # --------------- @@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<='] ]) def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): expr = f'x {op} y' - _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device) + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) # --------------- @@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): # --------------- @pytest.mark.parametrize("dtype_x, expr", [ (dtype_x, ' -x') for dtype_x in dtypes -] + [\ +] + [ (dtype_x, ' ~x') for dtype_x in int_dtypes - ]) +]) def test_unary_op(dtype_x, expr, device='cuda'): _test_unary(dtype_x, expr, device=device) @@ -275,8 +367,9 @@ def make_ptr_str(name, shape): @pytest.mark.parametrize("expr, dtype_str", [ - (f'x[{s}]', 'int32') + (f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16'] ]) def test_index1d(expr, dtype_str, device='cuda'): rank_x = expr.count(':') @@ -364,9 +457,9 @@ def test_tuples(): @pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([ [ ('add', 'float16', mode), - ('add', 'int32', mode), ('add', 'float32', mode), - ('max', 'int32', mode), ('max', 'float32', mode), - ('min', 'int32', mode), ('min', 'float32', mode), + ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode), + ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode), + ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode), ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): @@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): if exact: assert z_ref.item() == to_numpy(z_tri).item() else: - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) # --------------- @@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): ('float32', 'bfloat16', False), ('bfloat16', 'float32', False), ('float32', 'int32', True), -] -) +] + [ + (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] +] + [ + (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] +]) def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. x0 = 43 if dtype_x in int_dtypes else 43.5 @@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'): @pytest.mark.parametrize("dtype_str, shape, axis", [ - ('float32', (1, 1024), 1) + (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32'] ]) def test_reduce2d(dtype_str, shape, axis, device='cuda'): # triton kernel @@ -762,3 +858,43 @@ def test_noop(device='cuda'): pass x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [ + (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'), + (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64') +]) +def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device='cuda') + pgm = kernel[(1, )](value, x) + + # Parse out the type of the 'VALUE' parameter from the Triton IR. + triton_ir = pgm.asm['ttir'] + ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) + ir_value_type = None if ir_value_match is None else ir_value_match.group(1) + assert ir_value_type == value_type + + +@pytest.mark.parametrize( + "value, overflow", + [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] +) +def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device='cuda') + + if overflow: + with pytest.raises(RuntimeError, match='integer overflow'): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index 4d4501556..67173adfb 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'): N = x.numel() grid = (triton.cdiv(N, BLOCK),) kernel[grid](x, N, seed) + assert all((x >= 0) & (x <= 1)) assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 # test normal PRNG diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 8393f2b87..eec36f052 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor): return triton.language.constexpr(not op) if isinstance(op, triton.language.core.constexpr): op = op.value - # print(op) fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -503,6 +502,7 @@ class Binary: self.shared_mem = shared_mem self.num_warps = num_warps + class LoadedBinary: def __init__(self, device: int, bin: Binary): module, kernel = _triton.code_gen.load_binary(bin.backend, @@ -571,24 +571,33 @@ class Kernel: torch.int16: 'i16', torch.int32: 'i32', torch.int64: 'i64', + triton.language.uint8: 'u8', + triton.language.uint16: 'u16', + triton.language.uint32: 'u32', + triton.language.uint64: 'u64', } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] if isinstance(obj, triton.language.core.constexpr): obj = obj.value if isinstance(obj, int): - if abs(obj) <= 0xffffffff: - return 'I' - return 'L' + if -2**31 <= obj < 2**31: + return 'i32' + elif 2**31 <= obj < 2**32: + return 'u32' + elif -2**63 <= obj < 2**63: + return 'i64' + elif 2**63 <= obj < 2**64: + return 'u64' + else: + raise ValueError(f'integer overflow representing {obj}') if isinstance(obj, float): return 'f' if isinstance(obj, bool): return 'B' if isinstance(obj, str): return 'str' - assert False - - + raise NotImplementedError(f'could not compute type name for {obj}') @staticmethod def _to_triton_ir(context, obj): @@ -607,6 +616,10 @@ class Kernel: 'i16': _triton.ir.type.get_int16, 'i32': _triton.ir.type.get_int32, 'i64': _triton.ir.type.get_int64, + 'u8': _triton.ir.type.get_uint8, + 'u16': _triton.ir.type.get_uint16, + 'u32': _triton.ir.type.get_uint32, + 'u64': _triton.ir.type.get_uint64, } # convert torch.Tensor to Triton IR pointers if hasattr(obj, 'data_ptr'): @@ -1165,4 +1178,15 @@ class TensorWrapper: def reinterpret(tensor, dtype): - return TensorWrapper(tensor, dtype) + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d5d3313e5..210a72a30 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -9,9 +9,16 @@ def _to_ir(x, builder): if isinstance(x, bool): return builder.get_int1(x) elif isinstance(x, int): - if x.__abs__() <= 2**31: + if -2**31 <= x < 2**31: return builder.get_int32(x) - return builder.get_int64(x) + elif 2**31 <= x < 2**32: + return builder.get_uint32(x) + elif -2**63 <= x < 2**63: + return builder.get_int64(x) + elif 2**63 <= x < 2**64: + return builder.get_uint64(x) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): return builder.get_float32(x) elif isinstance(x, constexpr): @@ -83,6 +90,14 @@ class dtype: def __str__(self): return self.name + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + return f'triton.language.{self.name}' + class pointer_dtype: def __init__(self, element_ty): @@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8) int16 = dtype(ir.type.get_int16) int32 = dtype(ir.type.get_int32) int64 = dtype(ir.type.get_int64) +uint8 = dtype(ir.type.get_uint8) +uint16 = dtype(ir.type.get_uint16) +uint32 = dtype(ir.type.get_uint32) +uint64 = dtype(ir.type.get_uint64) float8 = dtype(ir.type.get_fp8) float16 = dtype(ir.type.get_fp16) bfloat16 = dtype(ir.type.get_bf16) @@ -120,6 +139,10 @@ class block: if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 + if ir_type.is_uint8(): return uint8 + if ir_type.is_uint16(): return uint16 + if ir_type.is_uint32(): return uint32 + if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16