uint8, uint16, uint32, and uint64 in kernels (#413)

A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
Madeleine Thompson
2022-01-05 15:27:17 -08:00
committed by GitHub
parent d8db0308cb
commit 0ab9d67bad
12 changed files with 444 additions and 110 deletions

View File

@@ -51,9 +51,15 @@ value *builder::get_int1(bool val)
// Creates an integer constant of the context's int32 type holding `val`.
value *builder::get_int32(int32_t val) {
  auto *int32_ty = type::get_int32_ty(ctx_);
  return constant_int::get(int32_ty, val);
}
// Creates an integer constant of the context's uint32 type holding `val`.
value *builder::get_uint32(uint32_t val) {
  auto *uint32_ty = type::get_uint32_ty(ctx_);
  return constant_int::get(uint32_ty, val);
}
// Creates an integer constant of the context's int64 type holding `val`.
value *builder::get_int64(int64_t val) {
  auto *int64_ty = type::get_int64_ty(ctx_);
  return constant_int::get(int64_ty, val);
}
// Creates an integer constant of the context's uint64 type holding `val`.
value *builder::get_uint64(uint64_t val) {
  auto *uint64_ty = type::get_uint64_ty(ctx_);
  return constant_int::get(uint64_ty, val);
}
// Creates a floating-point constant of the context's fp16 type from `val`.
value *builder::get_float16(float val) {
  auto *fp16_ty = type::get_fp16_ty(ctx_);
  return constant_fp::get(fp16_ty, val);
}
@@ -84,6 +90,18 @@ type *builder::get_int32_ty()
// Returns the int64 type belonging to this builder's context.
type *builder::get_int64_ty() {
  type *int64_ty = type::get_int64_ty(ctx_);
  return int64_ty;
}
// Returns the uint8 type belonging to this builder's context.
type *builder::get_uint8_ty() {
  type *uint8_ty = type::get_uint8_ty(ctx_);
  return uint8_ty;
}
// Returns the uint16 type belonging to this builder's context.
type *builder::get_uint16_ty() {
  type *uint16_ty = type::get_uint16_ty(ctx_);
  return uint16_ty;
}
// Returns the uint32 type belonging to this builder's context.
type *builder::get_uint32_ty() {
  type *uint32_ty = type::get_uint32_ty(ctx_);
  return uint32_ty;
}
// Returns the uint64 type belonging to this builder's context.
type *builder::get_uint64_ty() {
  type *uint64_ty = type::get_uint64_ty(ctx_);
  return uint64_ty;
}
// Returns the half-precision type of this builder's context; this is the
// type that type::get_fp16_ty() hands out.
type *builder::get_half_ty() {
  type *fp16_ty = type::get_fp16_ty(ctx_);
  return fp16_ty;
}

View File

@@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx)
fp32_ty(ctx, type::FP32TyID),
fp64_ty(ctx, type::FP64TyID),
// integers
int1_ty(ctx, 1),
int8_ty(ctx, 8),
int16_ty(ctx, 16),
int32_ty(ctx, 32),
int64_ty(ctx, 64),
int128_ty(ctx, 128){
int1_ty(ctx, 1, signedness::SIGNED),
int8_ty(ctx, 8, signedness::SIGNED),
int16_ty(ctx, 16, signedness::SIGNED),
int32_ty(ctx, 32, signedness::SIGNED),
int64_ty(ctx, 64, signedness::SIGNED),
int128_ty(ctx, 128, signedness::SIGNED),
uint8_ty(ctx, 8, signedness::UNSIGNED),
uint16_ty(ctx, 16, signedness::UNSIGNED),
uint32_ty(ctx, 32, signedness::UNSIGNED),
uint64_ty(ctx, 64, signedness::UNSIGNED){
}

View File

@@ -1,14 +1,12 @@
#include "triton/ir/dispatch.h"
#include <iostream>
namespace triton{
namespace ir{
namespace triton {
namespace ir {
// Reports that a code path believed to be unreachable was hit.
//
// Always throws std::runtime_error; `key` names the offending code path and
// is embedded in the message. The function never returns, which is why it is
// declared [[noreturn]] with a void result: callers invoke it as a plain
// statement at the end of a function without a dummy `return`, and the
// compiler does not warn about a missing return value.
[[noreturn]] void throw_unreachable(std::string key) {
  throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
                           "This is likely a bug on our side.");
}
//===----------------------------------------------------------------------===//
@@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
// Chooses the common type for two integer operand types.
//
// Rules for signedness taken from "Usual arithmetic conversions" on
// https://en.cppreference.com/w/c/language/conversion:
//   - same signedness: the wider type wins (b_ty when the widths are equal);
//   - mixed signedness: the unsigned type wins unless the signed type is
//     strictly wider, in which case the signed type wins.
//
// Both a_ty and b_ty must be integer types: the bitwidth/signedness accessors
// used here are only meaningful for integers.
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
  const unsigned a_rank = a_ty->get_integer_bitwidth();
  const unsigned b_rank = b_ty->get_integer_bitwidth();
  const auto a_sn = a_ty->get_integer_signedness();
  const auto b_sn = b_ty->get_integer_signedness();
  if (a_sn == b_sn) {
    return a_rank > b_rank ? a_ty : b_ty;
  } else if (a_sn == signedness::UNSIGNED) {
    // a is unsigned, b is signed: b only wins when it is strictly wider.
    return a_rank >= b_rank ? a_ty : b_ty;
  } else if (b_sn == signedness::UNSIGNED) {
    // b is unsigned, a is signed: a only wins when it is strictly wider.
    return b_rank >= a_rank ? b_ty : a_ty;
  } else {
    // The signedness values differ, yet neither is UNSIGNED; not expected.
    throw_unreachable("integer_promote");
  }
}
// Tells the type-checking helpers whether the operator being checked is a
// division or modulo (YES) or some other binary operator (NO).
enum class DivOrMod {
  NO,
  YES
};
@@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod)
throw_unreachable("computation_type");
// 4 ) both operands are integer and undergo
// integer promotion
if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
return integer_promote(a_ty, b_ty);
}
@@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde
// int + int
else if (input_scalar_ty->is_integer_ty())
return builder->create_add(input, other);
return throw_unreachable("add");
throw_unreachable("add");
}
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builde
// int - int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
return throw_unreachable("sub");
throw_unreachable("sub");
}
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(input, other);
return throw_unreachable("mul");
throw_unreachable("mul");
}
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
}
// unreachable
else
return throw_unreachable("div");
throw_unreachable("div");
return builder->create_fdiv(input, other);
}
@@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
input = dispatch::cast(input, ret_ty, builder);
other = dispatch::cast(other, ret_ty, builder);
return builder->create_sdiv(input, other);
if (ret_ty->is_integer_signed()) {
return builder->create_sdiv(input, other);
} else {
return builder->create_udiv(input, other);
}
}
return throw_unreachable("floordiv");
throw_unreachable("floordiv");
}
// Emits the remainder operation `input % other`.
//
// The operands first go through binary_op_type_checking in div-or-mod mode
// (DivOrMod::YES); it receives `input` and `other` and may replace them
// (presumably with values promoted to a common type -- confirm against its
// definition). The instruction is then chosen from the scalar type of `input`:
//   - floating point   -> frem
//   - signed integer   -> srem
//   - unsigned integer -> urem
// Throws semantic_error when two integer operands differ in signedness, and
// reports an unreachable code path for any other scalar type.
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  // floating-point remainder
  if (scalar_ty->is_floating_point_ty())
    return builder->create_frem(input, other);
  // int % int
  else if (scalar_ty->is_integer_ty()) {
    if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
      throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
    }
    // Signed and unsigned remainders are distinct instructions.
    if (scalar_ty->is_integer_signed()) {
      return builder->create_srem(input, other);
    } else {
      return builder->create_urem(input, other);
    }
  }
  throw_unreachable("mod");
}
@@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty);
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
input = dispatch::cast(input, other_sca_ty, builder);
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
if (ret_sca_ty != input_sca_ty)
input = dispatch::cast(input, ret_sca_ty, builder);
if (ret_sca_ty != other_sca_ty)
other = dispatch::cast(other, ret_sca_ty, builder);
}
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(input, other);
return throw_unreachable("greater_than");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGT(input, other);
} else {
return builder->create_icmpUGT(input, other);
}
}
throw_unreachable("greater_than");
}
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(input, other);
return throw_unreachable("greater_equal");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGE(input, other);
} else {
return builder->create_icmpUGE(input, other);
}
}
throw_unreachable("greater_equal");
}
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(input, other);
return throw_unreachable("less_than");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLT(input, other);
} else {
return builder->create_icmpULT(input, other);
}
}
throw_unreachable("less_than");
}
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int <= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(input, other);
return throw_unreachable("less_equal");
else if (scalar_ty->is_integer_ty()) {
// Signed operands use the signed predicate (SLE), unsigned ones the
// unsigned predicate (ULE).
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLE(input, other);
} else {
return builder->create_icmpULE(input, other);
}
}
throw_unreachable("less_equal");
}
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
return throw_unreachable("equal");
throw_unreachable("equal");
}
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
// int != int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
// NOTE(review): the key passed below is "equal" although this is
// dispatch::not_equal -- looks like a copy-paste slip; confirm and change
// it to "not_equal" so the error message names the right code path.
return throw_unreachable("equal");
throw_unreachable("equal");
}
//===----------------------------------------------------------------------===//
@@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
(src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
return builder->create_int_cast(input, dst_ty, sign_extend);
}
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
@@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if(src_sca_ty->is_bool_ty())
if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
@@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
//===----------------------------------------------------------------------===//
@@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty())
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
}
}
// for float
// return atomic_smax(i_ptr, i_val) if val >= 0
// return atomic_umin(i_ptr, i_val) if val < 0
@@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty())
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
// direct call to atomic_min for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
}
}
// for float
// return atomic_smin(i_ptr, i_val) if val >= 0
// return atomic_umax(i_ptr, i_val) if val < 0
@@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
return throw_unreachable(name);
throw_unreachable(name);
}
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {

View File

@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
// Bit width of this integer type. Asserts that the type really is an
// integer; with assertions compiled out no check is made.
unsigned type::get_integer_bitwidth() const {
  assert(id_ == IntegerTyID);
  auto *as_integer = (integer_type*)(this);
  return as_integer->get_bitwidth();
}
// Signedness of this integer type. Asserts that the type really is an
// integer; with assertions compiled out no check is made.
signedness type::get_integer_signedness() const {
  assert(id_ == IntegerTyID);
  auto *as_integer = (integer_type*)(this);
  return as_integer->get_signedness();
}
// Reports whether this integer type is signed.
// Throws std::logic_error (naming the actual type via repr()) when called on
// a type that is not an integer.
bool type::is_integer_signed() const {
  const bool is_integer = (id_ == IntegerTyID);
  if (!is_integer)
    throw std::logic_error("type is " + repr() + ", not integer");
  auto *as_integer = (integer_type*)(this);
  return as_integer->get_signedness() == signedness::SIGNED;
}
// Bit width reported by this type viewed as a block_type. No type-id check
// is made here, so presumably the caller must ensure this is a block type.
unsigned type::get_tile_bitwidth() const {
  auto *as_block = (block_type*)(this);
  return as_block->get_bitwidth();
}
@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
// Returns the int32 type object held by `ctx`'s implementation.
integer_type *type::get_int32_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.int32_ty;
}
// Returns the int64 type object held by `ctx`'s implementation.
integer_type *type::get_int64_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.int64_ty;
}
// Returns the int128 type object held by `ctx`'s implementation.
integer_type *type::get_int128_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.int128_ty;
}
// Returns the uint8 type object held by `ctx`'s implementation.
integer_type *type::get_uint8_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.uint8_ty;
}
// Returns the uint16 type object held by `ctx`'s implementation.
integer_type *type::get_uint16_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.uint16_ty;
}
// Returns the uint32 type object held by `ctx`'s implementation.
integer_type *type::get_uint32_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.uint32_ty;
}
// Returns the uint64 type object held by `ctx`'s implementation.
integer_type *type::get_uint64_ty(context &ctx) {
  auto &impl = *ctx.p_impl;
  return &impl.uint64_ty;
}