uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types. Also: - Add tests for the `//`, `<<`, and `>>` operators. - Change `TensorWrapper` to unwrap objects when the resulting object would be simpler. - Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
committed by
GitHub
parent
d8db0308cb
commit
0ab9d67bad
@@ -51,9 +51,15 @@ value *builder::get_int1(bool val)
|
||||
value *builder::get_int32(int32_t val)
|
||||
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_uint32(uint32_t val)
|
||||
{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_int64(int64_t val)
|
||||
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_uint64(uint64_t val)
|
||||
{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_float16(float val)
|
||||
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
|
||||
|
||||
@@ -84,6 +90,18 @@ type *builder::get_int32_ty()
|
||||
type *builder::get_int64_ty()
|
||||
{ return type::get_int64_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint8_ty()
|
||||
{ return type::get_uint8_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint16_ty()
|
||||
{ return type::get_uint16_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint32_ty()
|
||||
{ return type::get_uint32_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint64_ty()
|
||||
{ return type::get_uint64_ty(ctx_); }
|
||||
|
||||
type *builder::get_half_ty()
|
||||
{ return type::get_fp16_ty(ctx_); }
|
||||
|
||||
|
@@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx)
|
||||
fp32_ty(ctx, type::FP32TyID),
|
||||
fp64_ty(ctx, type::FP64TyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
int32_ty(ctx, 32),
|
||||
int64_ty(ctx, 64),
|
||||
int128_ty(ctx, 128){
|
||||
int1_ty(ctx, 1, signedness::SIGNED),
|
||||
int8_ty(ctx, 8, signedness::SIGNED),
|
||||
int16_ty(ctx, 16, signedness::SIGNED),
|
||||
int32_ty(ctx, 32, signedness::SIGNED),
|
||||
int64_ty(ctx, 64, signedness::SIGNED),
|
||||
int128_ty(ctx, 128, signedness::SIGNED),
|
||||
uint8_ty(ctx, 8, signedness::UNSIGNED),
|
||||
uint16_ty(ctx, 16, signedness::UNSIGNED),
|
||||
uint32_ty(ctx, 32, signedness::UNSIGNED),
|
||||
uint64_ty(ctx, 64, signedness::UNSIGNED){
|
||||
|
||||
}
|
||||
|
||||
|
@@ -1,14 +1,12 @@
|
||||
#include "triton/ir/dispatch.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
|
||||
|
||||
ir::value* throw_unreachable(std::string key) {
|
||||
[[ noreturn ]] void throw_unreachable(std::string key) {
|
||||
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
|
||||
"This is likely a bug on our side.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
|
||||
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
|
||||
int a_rank = a_ty->get_integer_bitwidth();
|
||||
int b_rank = b_ty->get_integer_bitwidth();
|
||||
return a_rank > b_rank ? a_ty : b_ty;
|
||||
auto a_sn = a_ty->get_integer_signedness();
|
||||
auto b_sn = b_ty->get_integer_signedness();
|
||||
// Rules for signedness taken from "Usual arithmetic conversions" on
|
||||
// https://en.cppreference.com/w/c/language/conversion.
|
||||
if (a_sn == b_sn) {
|
||||
return a_rank > b_rank ? a_ty : b_ty;
|
||||
} else if (a_sn == signedness::UNSIGNED) {
|
||||
return a_rank >= b_rank ? a_ty : b_ty;
|
||||
} else if (b_sn == signedness::UNSIGNED) {
|
||||
return b_rank >= a_rank ? b_ty : a_ty;
|
||||
} else {
|
||||
throw_unreachable("integer_promote");
|
||||
}
|
||||
}
|
||||
|
||||
enum class DivOrMod { NO, YES };
|
||||
@@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod)
|
||||
throw_unreachable("computation_type");
|
||||
// 4 ) both operands are integer and undergo
|
||||
// integer promotion
|
||||
if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
|
||||
throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
|
||||
}
|
||||
return integer_promote(a_ty, b_ty);
|
||||
}
|
||||
|
||||
@@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde
|
||||
// int + int
|
||||
else if (input_scalar_ty->is_integer_ty())
|
||||
return builder->create_add(input, other);
|
||||
return throw_unreachable("add");
|
||||
throw_unreachable("add");
|
||||
}
|
||||
|
||||
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builde
|
||||
// int + int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_sub(input, other);
|
||||
return throw_unreachable("sub");
|
||||
throw_unreachable("sub");
|
||||
}
|
||||
|
||||
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
|
||||
// int * int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_mul(input, other);
|
||||
return throw_unreachable("mul");
|
||||
throw_unreachable("mul");
|
||||
}
|
||||
|
||||
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
|
||||
}
|
||||
// unreachable
|
||||
else
|
||||
return throw_unreachable("div");
|
||||
throw_unreachable("div");
|
||||
return builder->create_fdiv(input, other);
|
||||
}
|
||||
|
||||
@@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
|
||||
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
|
||||
input = dispatch::cast(input, ret_ty, builder);
|
||||
other = dispatch::cast(other, ret_ty, builder);
|
||||
return builder->create_sdiv(input, other);
|
||||
if (ret_ty->is_integer_signed()) {
|
||||
return builder->create_sdiv(input, other);
|
||||
} else {
|
||||
return builder->create_udiv(input, other);
|
||||
}
|
||||
}
|
||||
return throw_unreachable("floordiv");
|
||||
throw_unreachable("floordiv");
|
||||
}
|
||||
|
||||
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
// float % int
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_frem(input, other);
|
||||
// int % int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_srem(input, other);
|
||||
return throw_unreachable("mod");
|
||||
else if (scalar_ty->is_integer_ty()) {
|
||||
if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
|
||||
throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
|
||||
}
|
||||
if (scalar_ty->is_integer_signed()) {
|
||||
return builder->create_srem(input, other);
|
||||
} else {
|
||||
return builder->create_urem(input, other);
|
||||
}
|
||||
}
|
||||
throw_unreachable("mod");
|
||||
}
|
||||
|
||||
|
||||
@@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder
|
||||
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
|
||||
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
|
||||
throw_incompatible_types(input_sca_ty, other_sca_ty);
|
||||
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
|
||||
input = dispatch::cast(input, other_sca_ty, builder);
|
||||
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
|
||||
other = dispatch::cast(other, input_sca_ty, builder);
|
||||
ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
|
||||
if (ret_sca_ty != input_sca_ty)
|
||||
input = dispatch::cast(input, ret_sca_ty, builder);
|
||||
if (ret_sca_ty != other_sca_ty)
|
||||
other = dispatch::cast(other, ret_sca_ty, builder);
|
||||
}
|
||||
|
||||
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOGT(input, other);
|
||||
// int > int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSGT(input, other);
|
||||
return throw_unreachable("greater_than");
|
||||
else if (scalar_ty->is_integer_ty()) {
|
||||
if (scalar_ty->is_integer_signed()) {
|
||||
return builder->create_icmpSGT(input, other);
|
||||
} else {
|
||||
return builder->create_icmpUGT(input, other);
|
||||
}
|
||||
}
|
||||
throw_unreachable("greater_than");
|
||||
}
|
||||
|
||||
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOGE(input, other);
|
||||
// int >= int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSGE(input, other);
|
||||
return throw_unreachable("greater_equal");
|
||||
else if (scalar_ty->is_integer_ty()) {
|
||||
if (scalar_ty->is_integer_signed()) {
|
||||
return builder->create_icmpSGE(input, other);
|
||||
} else {
|
||||
return builder->create_icmpUGE(input, other);
|
||||
}
|
||||
}
|
||||
throw_unreachable("greater_equal");
|
||||
}
|
||||
|
||||
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOLT(input, other);
|
||||
// int < int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSLT(input, other);
|
||||
return throw_unreachable("less_than");
|
||||
else if (scalar_ty->is_integer_ty()) {
|
||||
if (scalar_ty->is_integer_signed()) {
|
||||
return builder->create_icmpSLT(input, other);
|
||||
} else {
|
||||
return builder->create_icmpULT(input, other);
|
||||
}
|
||||
}
|
||||
throw_unreachable("less_than");
|
||||
}
|
||||
|
||||
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fcmpOLE(input, other);
|
||||
// int < int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpSLE(input, other);
|
||||
return throw_unreachable("less_equal");
|
||||
else if (scalar_ty->is_integer_ty()) {
|
||||
if (scalar_ty->is_integer_signed()) {
|
||||
return builder->create_icmpSLE(input, other);
|
||||
} else {
|
||||
return builder->create_icmpULE(input, other);
|
||||
}
|
||||
}
|
||||
throw_unreachable("less_equal");
|
||||
}
|
||||
|
||||
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
|
||||
// int == int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpEQ(input, other);
|
||||
return throw_unreachable("equal");
|
||||
throw_unreachable("equal");
|
||||
}
|
||||
|
||||
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
@@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
|
||||
// int == int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_icmpNE(input, other);
|
||||
return throw_unreachable("equal");
|
||||
throw_unreachable("equal");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
return builder->create_fp_ext(input, dst_ty);
|
||||
// Int cast
|
||||
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
|
||||
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
|
||||
return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
|
||||
(src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
|
||||
src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
|
||||
bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
|
||||
return builder->create_int_cast(input, dst_ty, sign_extend);
|
||||
}
|
||||
// Float -> Int
|
||||
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
|
||||
if(dst_sca_ty->is_bool_ty())
|
||||
@@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
}
|
||||
// int -> Float
|
||||
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
|
||||
if(src_sca_ty->is_bool_ty())
|
||||
if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
|
||||
return builder->create_ui_to_fp(input, dst_ty);
|
||||
else
|
||||
return builder->create_si_to_fp(input, dst_ty);
|
||||
@@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
|
||||
other = builder->create_splat(other, src_ty->get_block_shapes());
|
||||
return builder->create_icmpNE(input, other);
|
||||
}
|
||||
return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
|
||||
throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
// direct call to atomic_max for integers
|
||||
if(sca_ty->is_integer_ty())
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
|
||||
if(sca_ty->is_integer_ty()) {
|
||||
if (sca_ty->is_integer_signed()) {
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
|
||||
} else {
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
|
||||
}
|
||||
}
|
||||
// for float
|
||||
// return atomic_smax(i_ptr, i_val) if val >= 0
|
||||
// return atomic_umin(i_ptr, i_val) if val < 0
|
||||
@@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
|
||||
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
|
||||
atom_red_typechecking(ptr, val, mask, builder);
|
||||
ir::type* sca_ty = val->get_type()->get_scalar_ty();
|
||||
// direct call to atomic_max for integers
|
||||
if(sca_ty->is_integer_ty())
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
|
||||
// direct call to atomic_min for integers
|
||||
if(sca_ty->is_integer_ty()) {
|
||||
if (sca_ty->is_integer_signed()) {
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
|
||||
} else {
|
||||
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
|
||||
}
|
||||
}
|
||||
// for float
|
||||
// return atomic_smin(i_ptr, i_val) if val >= 0
|
||||
// return atomic_umax(i_ptr, i_val) if val < 0
|
||||
@@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
|
||||
return builder->create_reduce(input, FLOAT_OP, axis);
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_reduce(input, INT_OP, axis);
|
||||
return throw_unreachable(name);
|
||||
throw_unreachable(name);
|
||||
}
|
||||
|
||||
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
|
||||
|
@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
|
||||
unsigned type::get_integer_bitwidth() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
signedness type::get_integer_signedness() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
|
||||
|
||||
bool type::is_integer_signed() const {
|
||||
if (id_ != IntegerTyID) {
|
||||
throw std::logic_error("type is " + repr() + ", not integer");
|
||||
}
|
||||
return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
|
||||
}
|
||||
|
||||
unsigned type::get_tile_bitwidth() const
|
||||
{ return ((block_type*)(this))->get_bitwidth(); }
|
||||
|
||||
@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
|
||||
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
|
||||
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
|
||||
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
|
||||
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
|
||||
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
|
||||
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
|
||||
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user