uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
commit 0ab9d67bad (parent d8db0308cb), committed via GitHub
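For context on how the new unsigned types are meant to be used from Python, the diff below routes them through `triton.language` dtypes and the `TensorWrapper`/`reinterpret` helpers, since torch itself has no uint32/uint64 tensors. A minimal usage sketch, assuming a CUDA device; the kernel and variable names are illustrative and not part of this commit:

```python
import torch
import triton
import triton.language as tl
from triton.code_gen import reinterpret

@triton.jit
def copy_u32(X, Z):
    # Fixed 16-element block; loads and stores see the values as uint32.
    offs = tl.arange(0, 16)
    tl.store(Z + offs, tl.load(X + offs))

# Store the bit pattern in int32 tensors and reinterpret them as tl.uint32
# when handing them to the kernel, mirroring to_triton() in the tests below.
x = torch.arange(16, dtype=torch.int32, device='cuda')
z = torch.empty_like(x)
copy_u32[(1,)](reinterpret(x, tl.uint32), reinterpret(z, tl.uint32))
```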
@@ -40,6 +40,8 @@ public:
  value *get_int1(bool val);
  value *get_int32(int32_t val);
  value *get_int64(int64_t val);
  value *get_uint32(uint32_t val);
  value *get_uint64(uint64_t val);
  value *get_float16(float val);
  value *get_float32(float val);
  value *get_range(int32_t lo, int32_t hi);
@@ -50,6 +52,10 @@ public:
  type *get_int16_ty();
  type *get_int32_ty();
  type *get_int64_ty();
  type *get_uint8_ty();
  type *get_uint16_ty();
  type *get_uint32_ty();
  type *get_uint64_ty();
  type *get_half_ty();
  type *get_float_ty();
  type *get_double_ty();
@@ -28,6 +28,7 @@ public:
  type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
  // integer types
  integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
  integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
  // Pointer types
  std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
  // Block types
@@ -15,6 +15,8 @@ class value;
class integer_type;
class constant_int;

enum class signedness { SIGNED, UNSIGNED };

/* Type */
class type {
public:
@@ -58,6 +60,8 @@ public:
  // type attributes
  unsigned get_fp_mantissa_width() const;
  unsigned get_integer_bitwidth() const;
  signedness get_integer_signedness() const;
  bool is_integer_signed() const;
  unsigned get_tile_bitwidth() const;
  unsigned get_primitive_size_in_bits() const;
  type *get_scalar_ty() const;
@@ -80,8 +84,9 @@ public:
  bool is_metadata_ty() const { return id_ == MetadataTyID; }
  bool is_token_ty() const { return id_ == TokenTyID; }
  bool is_integer_ty() const { return id_ == IntegerTyID; }
  bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
                                                 get_integer_bitwidth() == bitwidth;}
  bool is_integer_ty(unsigned bitwidth, signedness sn) {
    return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
  }
  bool is_bool_ty() const { return is_integer_ty(1); }
  bool is_pointer_ty() const { return id_ == PointerTyID; }
  bool is_block_ty() const { return id_ == BlockTyID; }
@@ -109,6 +114,10 @@ public:
  static integer_type *get_int32_ty(context &ctx);
  static integer_type *get_int64_ty(context &ctx);
  static integer_type *get_int128_ty(context &ctx);
  static integer_type *get_uint8_ty(context &ctx);
  static integer_type *get_uint16_ty(context &ctx);
  static integer_type *get_uint32_ty(context &ctx);
  static integer_type *get_uint64_ty(context &ctx);

  // repr
  std::string tile_repr() const {
@@ -135,7 +144,7 @@ public:
    case LabelTyID: return "label";
    case MetadataTyID: return "md";
    case TokenTyID: return "tok";
    case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
    case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
    case FunctionTyID: return "fn";
    case PointerTyID: return get_pointer_element_ty()->repr() + "*";
    case StructTyID: return "struct";
@@ -158,18 +167,21 @@ class integer_type: public type {

private:
  // constructors
  integer_type(context &ctx, unsigned bitwidth)
    : type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
  integer_type(context &ctx, unsigned bitwidth, signedness sn)
    : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }

public:
  // accessors
  unsigned get_bitwidth() const { return bitwidth_; }

  signedness get_signedness() const { return signedness_; }

  // factory methods
  static integer_type* get(context &ctx, unsigned width);

private:
  unsigned bitwidth_;
  signedness signedness_;
};

class composite_type: public type{
@@ -51,9 +51,15 @@ value *builder::get_int1(bool val)
value *builder::get_int32(int32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);}

value *builder::get_uint32(uint32_t val)
{ return constant_int::get(type::get_uint32_ty(ctx_), val);}

value *builder::get_int64(int64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);}

value *builder::get_uint64(uint64_t val)
{ return constant_int::get(type::get_uint64_ty(ctx_), val);}

value *builder::get_float16(float val)
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }

@@ -84,6 +90,18 @@ type *builder::get_int32_ty()
type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); }

type *builder::get_uint8_ty()
{ return type::get_uint8_ty(ctx_); }

type *builder::get_uint16_ty()
{ return type::get_uint16_ty(ctx_); }

type *builder::get_uint32_ty()
{ return type::get_uint32_ty(ctx_); }

type *builder::get_uint64_ty()
{ return type::get_uint64_ty(ctx_); }

type *builder::get_half_ty()
{ return type::get_fp16_ty(ctx_); }
@@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx)
    fp32_ty(ctx, type::FP32TyID),
    fp64_ty(ctx, type::FP64TyID),
    // integers
    int1_ty(ctx, 1),
    int8_ty(ctx, 8),
    int16_ty(ctx, 16),
    int32_ty(ctx, 32),
    int64_ty(ctx, 64),
    int128_ty(ctx, 128){
    int1_ty(ctx, 1, signedness::SIGNED),
    int8_ty(ctx, 8, signedness::SIGNED),
    int16_ty(ctx, 16, signedness::SIGNED),
    int32_ty(ctx, 32, signedness::SIGNED),
    int64_ty(ctx, 64, signedness::SIGNED),
    int128_ty(ctx, 128, signedness::SIGNED),
    uint8_ty(ctx, 8, signedness::UNSIGNED),
    uint16_ty(ctx, 16, signedness::UNSIGNED),
    uint32_ty(ctx, 32, signedness::UNSIGNED),
    uint64_ty(ctx, 64, signedness::UNSIGNED){

}
@@ -1,14 +1,12 @@
#include "triton/ir/dispatch.h"
#include <iostream>

namespace triton{
namespace ir{
namespace triton {
namespace ir {


ir::value* throw_unreachable(std::string key) {
[[ noreturn ]] void throw_unreachable(std::string key) {
  throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
                           "This is likely a bug on our side.");
  return 0;
}

//===----------------------------------------------------------------------===//
@@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
  int a_rank = a_ty->get_integer_bitwidth();
  int b_rank = b_ty->get_integer_bitwidth();
  return a_rank > b_rank ? a_ty : b_ty;
  auto a_sn = a_ty->get_integer_signedness();
  auto b_sn = b_ty->get_integer_signedness();
  // Rules for signedness taken from "Usual arithmetic conversions" on
  // https://en.cppreference.com/w/c/language/conversion.
  if (a_sn == b_sn) {
    return a_rank > b_rank ? a_ty : b_ty;
  } else if (a_sn == signedness::UNSIGNED) {
    return a_rank >= b_rank ? a_ty : b_ty;
  } else if (b_sn == signedness::UNSIGNED) {
    return b_rank >= a_rank ? b_ty : a_ty;
  } else {
    throw_unreachable("integer_promote");
  }
}

enum class DivOrMod { NO, YES };
@@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod)
    throw_unreachable("computation_type");
  // 4 ) both operands are integer and undergo
  //     integer promotion
  if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
    throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
  }
  return integer_promote(a_ty, b_ty);
}
@@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde
  // int + int
  else if (input_scalar_ty->is_integer_ty())
    return builder->create_add(input, other);
  return throw_unreachable("add");
  throw_unreachable("add");
}

ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builde
  // int + int
  else if (scalar_ty->is_integer_ty())
    return builder->create_sub(input, other);
  return throw_unreachable("sub");
  throw_unreachable("sub");
}

ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builde
  // int * int
  else if (scalar_ty->is_integer_ty())
    return builder->create_mul(input, other);
  return throw_unreachable("mul");
  throw_unreachable("mul");
}

ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
  }
  // unreachable
  else
    return throw_unreachable("div");
    throw_unreachable("div");
  return builder->create_fdiv(input, other);
}

@@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
    ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
    input = dispatch::cast(input, ret_ty, builder);
    other = dispatch::cast(other, ret_ty, builder);
    return builder->create_sdiv(input, other);
    if (ret_ty->is_integer_signed()) {
      return builder->create_sdiv(input, other);
    } else {
      return builder->create_udiv(input, other);
    }
  }
  return throw_unreachable("floordiv");
  throw_unreachable("floordiv");
}

ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  // float % int
  if (scalar_ty->is_floating_point_ty())
    return builder->create_frem(input, other);
  // int % int
  else if (scalar_ty->is_integer_ty())
    return builder->create_srem(input, other);
  return throw_unreachable("mod");
  else if (scalar_ty->is_integer_ty()) {
    if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
      throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
    }
    if (scalar_ty->is_integer_signed()) {
      return builder->create_srem(input, other);
    } else {
      return builder->create_urem(input, other);
    }
  }
  throw_unreachable("mod");
}
@@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder
  ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
  if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
    throw_incompatible_types(input_sca_ty, other_sca_ty);
  if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
    input = dispatch::cast(input, other_sca_ty, builder);
  else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
    other = dispatch::cast(other, input_sca_ty, builder);
  ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
  if (ret_sca_ty != input_sca_ty)
    input = dispatch::cast(input, ret_sca_ty, builder);
  if (ret_sca_ty != other_sca_ty)
    other = dispatch::cast(other, ret_sca_ty, builder);
}

ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOGT(input, other);
  // int > int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSGT(input, other);
  return throw_unreachable("greater_than");
  else if (scalar_ty->is_integer_ty()) {
    if (scalar_ty->is_integer_signed()) {
      return builder->create_icmpSGT(input, other);
    } else {
      return builder->create_icmpUGT(input, other);
    }
  }
  throw_unreachable("greater_than");
}

ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOGE(input, other);
  // int >= int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSGE(input, other);
  return throw_unreachable("greater_equal");
  else if (scalar_ty->is_integer_ty()) {
    if (scalar_ty->is_integer_signed()) {
      return builder->create_icmpSGE(input, other);
    } else {
      return builder->create_icmpUGE(input, other);
    }
  }
  throw_unreachable("greater_equal");
}

ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOLT(input, other);
  // int < int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSLT(input, other);
  return throw_unreachable("less_than");
  else if (scalar_ty->is_integer_ty()) {
    if (scalar_ty->is_integer_signed()) {
      return builder->create_icmpSLT(input, other);
    } else {
      return builder->create_icmpULT(input, other);
    }
  }
  throw_unreachable("less_than");
}

ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpOLE(input, other);
  // int < int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpSLE(input, other);
  return throw_unreachable("less_equal");
  else if (scalar_ty->is_integer_ty()) {
    if (scalar_ty->is_integer_signed()) {
      return builder->create_icmpSLE(input, other);
    } else {
      return builder->create_icmpULE(input, other);
    }
  }
  throw_unreachable("less_equal");
}

ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
  // int == int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpEQ(input, other);
  return throw_unreachable("equal");
  throw_unreachable("equal");
}

ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
  // int == int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpNE(input, other);
  return throw_unreachable("equal");
  throw_unreachable("equal");
}

//===----------------------------------------------------------------------===//
@@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
    return builder->create_fp_ext(input, dst_ty);
  // Int cast
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
      src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
    return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
      (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
       src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
    bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
    return builder->create_int_cast(input, dst_ty, sign_extend);
  }
  // Float -> Int
  if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
    if(dst_sca_ty->is_bool_ty())
@@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
  }
  // int -> Float
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
    if(src_sca_ty->is_bool_ty())
    if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
      return builder->create_ui_to_fp(input, dst_ty);
    else
      return builder->create_si_to_fp(input, dst_ty);
@@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
    other = builder->create_splat(other, src_ty->get_block_shapes());
    return builder->create_icmpNE(input, other);
  }
  return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
  throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}

//===----------------------------------------------------------------------===//
@@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type* sca_ty = val->get_type()->get_scalar_ty();
  // direct call to atomic_max for integers
  if(sca_ty->is_integer_ty())
    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
  if(sca_ty->is_integer_ty()) {
    if (sca_ty->is_integer_signed()) {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
    } else {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
    }
  }
  // for float
  // return atomic_smax(i_ptr, i_val) if val >= 0
  // return atomic_umin(i_ptr, i_val) if val < 0
@@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type* sca_ty = val->get_type()->get_scalar_ty();
  // direct call to atomic_max for integers
  if(sca_ty->is_integer_ty())
    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
  // direct call to atomic_min for integers
  if(sca_ty->is_integer_ty()) {
    if (sca_ty->is_integer_signed()) {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
    } else {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
    }
  }
  // for float
  // return atomic_smin(i_ptr, i_val) if val >= 0
  // return atomic_umax(i_ptr, i_val) if val < 0
@@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
    return builder->create_reduce(input, FLOAT_OP, axis);
  else if (scalar_ty->is_integer_ty())
    return builder->create_reduce(input, INT_OP, axis);
  return throw_unreachable(name);
  throw_unreachable(name);
}

ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }

signedness type::get_integer_signedness() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }

bool type::is_integer_signed() const {
  if (id_ != IntegerTyID) {
    throw std::logic_error("type is " + repr() + ", not integer");
  }
  return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
}

unsigned type::get_tile_bitwidth() const
{ return ((block_type*)(this))->get_bitwidth(); }

@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }
@@ -109,6 +109,24 @@ std::string pow2_divisor(long N){
  return "1";
}

// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object& dtype) {
  if (py::hasattr(dtype, "cache_key_part")) {
    // Presumed to be a triton.language.dtype.
    return std::string(py::str(py::getattr(dtype, "cache_key_part")));
  } else {
    // Remove 'torch.' prefix from repr of torch.dtype.
    py::object repr = py::repr(dtype);
    size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
    const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
    if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
      throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
    }
    return std::string(repr_ptr + 6, repr_len - 6);
  }
}

// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
      cache_key += "1";
      continue;
    }
    // long and int have different kernels
    if(!overflow & (std::abs(value) <= 0xffffffff)){
    // int32, uint32, int64, and uint64 have different kernels
    if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
      cache_key += "int32";
      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
    }
    else{
    } else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) {
      cache_key += "uint32";
      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
    } else if (!overflow) {
      cache_key += "int64";
      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
      if(overflow){
        unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
        std::memcpy(&value, &uvalue, 8);
      }
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
    } else {
      if (PyErr_Occurred()) {
        throw std::logic_error("An error occurred?");
      }
      unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
      if (PyErr_Occurred()) {
        throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg)));
      }
      cache_key += "uint64";
      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
      std::memcpy(params_ptr, &unsigned_value, 8);
      params_ptr += 8;
    }
    if(!specialize)
      continue;
@@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
      py::object dtype = arg.attr("dtype");
      py::object repr = py::repr(dtype);
      assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
      size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
      cache_key += std::string(start, len);
      cache_key += dtype_cache_key_part(arg.attr("dtype"));
      cache_key += "*";
      cache_key += "[multipleof(";
      cache_key += pow2_divisor(value);
@@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) {
      .def("get_int16", &ir::type::get_int16_ty, ret::reference)
      .def("get_int32", &ir::type::get_int32_ty, ret::reference)
      .def("get_int64", &ir::type::get_int64_ty, ret::reference)
      .def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
      .def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
      .def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
      .def("get_uint64", &ir::type::get_uint64_ty, ret::reference)

      .def("is_void", &ir::type::is_void_ty)
      .def("is_fp8", &ir::type::is_fp8_ty)
@@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) {
      .def("is_bf16", &ir::type::is_bf16_ty)
      .def("is_fp32", &ir::type::is_fp32_ty)
      .def("is_fp64", &ir::type::is_fp64_ty)
      .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
      .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
      .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
      .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
      .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
      .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
      .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
      .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
      .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
      .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
      .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
      .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
      .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
      .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })

      .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
      .def_property_readonly("scalar", &ir::type::get_scalar_ty)
@@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) {
      .def("get_int1", &ir::builder::get_int1, ret::reference)
      .def("get_int32", &ir::builder::get_int32, ret::reference)
      .def("get_int64", &ir::builder::get_int64, ret::reference)
      .def("get_uint32", &ir::builder::get_uint32, ret::reference)
      .def("get_uint64", &ir::builder::get_uint64, ret::reference)
      .def("get_float16", &ir::builder::get_float16, ret::reference)
      .def("get_float32", &ir::builder::get_float32, ret::reference)
      .def("get_range", &ir::builder::get_range, ret::reference);
@@ -1,7 +1,7 @@
import copy
import itertools
import re
from typing import Optional
from typing import Optional, Union

import numpy as np
import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState

import triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
dtypes = int_dtypes + uint_dtypes + float_dtypes


def _bitwidth(dtype: str) -> int:
    # ex.: "int64" -> 64
    return int(re.search(r'(\d+)$', dtype).group(1))

def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
    """
    Override `rs` if you're calling this function twice and don't want the same
    result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
    if rs is None:
        rs = RandomState(seed=17)
    dtype = getattr(np, dtype_str)
    if dtype_str in int_dtypes:
    if dtype_str in int_dtypes + uint_dtypes:
        iinfo = np.iinfo(getattr(np, dtype_str))
        x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
        low = iinfo.min if low is None else max(low, iinfo.min)
        high = iinfo.max if high is None else min(high, iinfo.max)
        x = rs.randint(low, high, shape, dtype=dtype)
        x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
        return x
    elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
        raise RuntimeError(f'Unknown dtype {dtype_str}')


def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
    # For now, this always converts to a torch tensor, but when we add unsigned
    # integers, it will also support TensorWrapper, since torch doesn't have
    # unsigned support.
    return torch.tensor(x, device=device)
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
    t = x.dtype.name
    if t in uint_dtypes:
        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
        x_signed = x.astype(getattr(np, signed_type_name))
        return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
    else:
        return torch.tensor(x, device=device)


def torch_dtype_name(dtype) -> str:
    if isinstance(dtype, triton.language.dtype):
        return dtype.name
    elif isinstance(dtype, torch.dtype):
        # 'torch.int64' -> 'int64'
        m = re.match(r'^torch\.(\w+)$', str(dtype))
        return m.group(1)
    else:
        raise TypeError(f'not a triton or torch dtype: {type(dtype)}')


def to_numpy(x):
    if isinstance(x, torch.Tensor):
    if isinstance(x, TensorWrapper):
        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
    elif isinstance(x, torch.Tensor):
        return x.cpu().numpy()
    else:
        raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
    Given two dtype strings, returns the numpy dtype Triton thinks binary
    operations on the two types should return. Returns None if the return value
    matches numpy. This is generally needed because Triton and pytorch return
    narrower floating point types than numpy in mixed operations.
    narrower floating point types than numpy in mixed operations, and because
    Triton follows C/C++ semantics around mixed signed/unsigned operations, and
    numpy/pytorch do not.
    """
    overrides = {
        ('float16', 'int16'): np.float16,
        ('float16', 'int32'): np.float16,
        ('float16', 'int64'): np.float16,
        ('float16', 'uint16'): np.float16,
        ('float16', 'uint32'): np.float16,
        ('float16', 'uint64'): np.float16,
        ('int8', 'uint8'): np.uint8,
        ('int8', 'uint16'): np.uint16,
        ('int8', 'uint32'): np.uint32,
        ('int8', 'uint64'): np.uint64,
        ('int16', 'uint16'): np.uint16,
        ('int16', 'uint32'): np.uint32,
        ('int16', 'uint64'): np.uint64,
        ('int32', 'uint32'): np.uint32,
        ('int32', 'uint64'): np.uint64,
        ('int64', 'uint64'): np.uint64,
    }
    key = (a, b) if a < b else (b, a)
    return overrides.get(key)


def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
    SIZE = 128
    # define the kernel / launch-grid
    @triton.jit
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
    # inputs
    rs = RandomState(17)
    x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
    if mode_x == 'nan':
        x[:] = float('nan')
    if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
        ('int64', 'float16'),
        ('int64', 'float32'),
        ('int64', 'float64'),
        ('uint16', 'float16'),
        ('uint16', 'float32'),
        ('uint32', 'float16'),
        ('uint32', 'float32'),
        ('uint64', 'float16'),
        ('uint64', 'float32'),
        ('uint64', 'float64'),
    ]

# ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f' x {op} y'
    if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
        numpy_expr = 'np.fmod(x, y)'
    elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
        # are no native div or FRem operations on float16. Since we have to
        # convert anyway, we may as well take the accuracy bump.
        numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
    elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
        with pytest.raises(AssertionError, match='Not equal to tolerance'):
            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
    elif (op in ('%', '/') and
          ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
           (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
        assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
    else:
        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


@pytest.mark.parametrize("dtype_x, dtype_y",
                         [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
                         [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
                         )
def test_floordiv(dtype_x, dtype_y, device='cuda'):
    # Triton has IEEE, not numpy/torch, semantics for %, and those carry
    # through to //, so we have to use a nonstandard expression to get a
    # reference result for //.
    expr = 'x // y'
    numpy_expr = '((x - np.fmod(x, y)) / y)'
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


# ---------------
# test bitwise ops
# ---------------
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f'x {op} y'
    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    if 'float' in dtype_x + dtype_y:
        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
            _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
        # The CompilationError must have been caused by a C++ exception with this text.
        assert re.match('invalid operands of type', str(exc_info.value.__cause__))
    else:
        _test_binary(dtype_x, dtype_y, expr, device=device)
        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


@pytest.mark.parametrize("dtype_x, dtype_y, op", [
    (dtype_x, dtype_y, op)
    for op in ['<<', '>>']
    for dtype_x in int_dtypes + uint_dtypes
    for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
    expr = f'x {op} y'
    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
    dtype_z = f'uint{bw}'
    numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)


# ---------------
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
    expr = f'x {op} y'
    _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
    else:
        numpy_expr = None
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


# ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
    (dtype_x, ' -x') for dtype_x in dtypes
] + [\
] + [
    (dtype_x, ' ~x') for dtype_x in int_dtypes
])
])
def test_unary_op(dtype_x, expr, device='cuda'):
    _test_unary(dtype_x, expr, device=device)
@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):


@pytest.mark.parametrize("expr, dtype_str", [
    (f'x[{s}]', 'int32')
    (f'x[{s}]', d)
    for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
    for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
    rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
    [
        ('add', 'float16', mode),
        ('add', 'int32', mode), ('add', 'float32', mode),
        ('max', 'int32', mode), ('max', 'float32', mode),
        ('min', 'int32', mode), ('min', 'float32', mode),
        ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
        ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
        ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
    ]
    for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    if exact:
        assert z_ref.item() == to_numpy(z_tri).item()
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


# ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
    ('float32', 'bfloat16', False),
    ('bfloat16', 'float32', False),
    ('float32', 'int32', True),
]
)
] + [
    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [
    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
    x0 = 43 if dtype_x in int_dtypes else 43.5
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):


@pytest.mark.parametrize("dtype_str, shape, axis", [
    ('float32', (1, 1024), 1)
    (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
])
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
    # triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
        pass
    x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
    kernel[(1, )](x)


@pytest.mark.parametrize("value, value_type", [
    (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
    (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
    (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

    @triton.jit
    def kernel(VALUE, X):
        pass

    x = torch.tensor([3.14159], device='cuda')
    pgm = kernel[(1, )](value, x)

    # Parse out the type of the 'VALUE' parameter from the Triton IR.
    triton_ir = pgm.asm['ttir']
    ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
    ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
    assert ir_value_type == value_type


@pytest.mark.parametrize(
    "value, overflow",
    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
)
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:

    @triton.jit
    def kernel(VALUE, X):
        pass

    x = torch.tensor([3.14159], device='cuda')

    if overflow:
        with pytest.raises(RuntimeError, match='integer overflow'):
            kernel[(1, )](value, x)
    else:
        kernel[(1, )](value, x)
@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
    N = x.numel()
    grid = (triton.cdiv(N, BLOCK),)
    kernel[grid](x, N, seed)
    assert all((x >= 0) & (x <= 1))
    assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01

# test normal PRNG
@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
            return triton.language.constexpr(not op)
        if isinstance(op, triton.language.core.constexpr):
            op = op.value
        # print(op)
        fn = {
            ast.USub: '__neg__',
            ast.UAdd: '__pos__',
@@ -503,6 +502,7 @@ class Binary:
        self.shared_mem = shared_mem
        self.num_warps = num_warps


class LoadedBinary:
    def __init__(self, device: int, bin: Binary):
        module, kernel = _triton.code_gen.load_binary(bin.backend,
@@ -571,24 +571,33 @@ class Kernel:
            torch.int16: 'i16',
            torch.int32: 'i32',
            torch.int64: 'i64',
            triton.language.uint8: 'u8',
            triton.language.uint16: 'u16',
            triton.language.uint32: 'u32',
            triton.language.uint64: 'u64',
        }
        if hasattr(obj, 'data_ptr'):
            return type_names[obj.dtype]
        if isinstance(obj, triton.language.core.constexpr):
            obj = obj.value
        if isinstance(obj, int):
            if abs(obj) <= 0xffffffff:
                return 'I'
            return 'L'
            if -2**31 <= obj < 2**31:
                return 'i32'
            elif 2**31 <= obj < 2**32:
                return 'u32'
            elif -2**63 <= obj < 2**63:
                return 'i64'
            elif 2**63 <= obj < 2**64:
                return 'u64'
            else:
                raise ValueError(f'integer overflow representing {obj}')
        if isinstance(obj, float):
            return 'f'
        if isinstance(obj, bool):
            return 'B'
        if isinstance(obj, str):
            return 'str'
        assert False

        raise NotImplementedError(f'could not compute type name for {obj}')

    @staticmethod
    def _to_triton_ir(context, obj):
@@ -607,6 +616,10 @@ class Kernel:
            'i16': _triton.ir.type.get_int16,
            'i32': _triton.ir.type.get_int32,
            'i64': _triton.ir.type.get_int64,
            'u8': _triton.ir.type.get_uint8,
            'u16': _triton.ir.type.get_uint16,
            'u32': _triton.ir.type.get_uint32,
            'u64': _triton.ir.type.get_uint64,
        }
        # convert torch.Tensor to Triton IR pointers
        if hasattr(obj, 'data_ptr'):
@@ -1165,4 +1178,15 @@ class TensorWrapper:


def reinterpret(tensor, dtype):
    return TensorWrapper(tensor, dtype)
    if isinstance(tensor, TensorWrapper):
        if dtype == tensor.base.dtype:
            # Reinterpreting to the original interpretation; return the base.
            return tensor.base
        else:
            # Reinterpreting a wrapped tensor to a different type.
            return TensorWrapper(tensor.base, dtype)
    elif isinstance(tensor, torch.Tensor):
        # A new wrapper is needed around an unwrapped tensor.
        return TensorWrapper(tensor, dtype)
    else:
        raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
@@ -9,9 +9,16 @@ def _to_ir(x, builder):
    if isinstance(x, bool):
        return builder.get_int1(x)
    elif isinstance(x, int):
        if x.__abs__() <= 2**31:
        if -2**31 <= x < 2**31:
            return builder.get_int32(x)
        return builder.get_int64(x)
        elif 2**31 <= x < 2**32:
            return builder.get_uint32(x)
        elif -2**63 <= x < 2**63:
            return builder.get_int64(x)
        elif 2**63 <= x < 2**64:
            return builder.get_uint64(x)
        else:
            raise RuntimeError(f'Nonrepresentable integer {x}.')
    elif isinstance(x, float):
        return builder.get_float32(x)
    elif isinstance(x, constexpr):
@@ -83,6 +90,14 @@ class dtype:
    def __str__(self):
        return self.name

    @property
    def cache_key_part(self) -> str:
        """See cache_key_part() in triton.cc."""
        return self.name

    def __repr__(self):
        return f'triton.language.{self.name}'


class pointer_dtype:
    def __init__(self, element_ty):
@@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
uint8 = dtype(ir.type.get_uint8)
uint16 = dtype(ir.type.get_uint16)
uint32 = dtype(ir.type.get_uint32)
uint64 = dtype(ir.type.get_uint64)
float8 = dtype(ir.type.get_fp8)
float16 = dtype(ir.type.get_fp16)
bfloat16 = dtype(ir.type.get_bf16)
@@ -120,6 +139,10 @@ class block:
        if ir_type.is_int16(): return int16
        if ir_type.is_int32(): return int32
        if ir_type.is_int64(): return int64
        if ir_type.is_uint8(): return uint8
        if ir_type.is_uint16(): return uint16
        if ir_type.is_uint32(): return uint32
        if ir_type.is_uint64(): return uint64
        if ir_type.is_fp8(): return float8
        if ir_type.is_fp16(): return float16
        if ir_type.is_bf16(): return bfloat16