uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types. Also:

- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
Parent: d8db0308cb
Commit: 0ab9d67bad
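The typing rule at the heart of this change follows C's "usual arithmetic conversions": when a binary op mixes signed and unsigned integers, the result is unsigned whenever the unsigned operand is at least as wide as the signed one, and otherwise the wider signed type wins. Below is a minimal standalone sketch of that rule, assuming a hypothetical `Ty` struct and `promote` helper rather than the repository's actual IR classes:

```cpp
#include <cassert>

// Hypothetical stand-in for the integer types touched by this commit.
struct Ty { unsigned bits; bool is_signed; };

// Mirrors the "usual arithmetic conversions" rule for mixed
// signed/unsigned operands: same signedness -> wider type;
// mixed -> the unsigned type unless the signed type is strictly wider.
Ty promote(Ty a, Ty b) {
  if (a.is_signed == b.is_signed)
    return a.bits > b.bits ? a : b;
  if (!a.is_signed)
    return a.bits >= b.bits ? a : b;
  return b.bits >= a.bits ? b : a;
}

int main() {
  assert(promote({32, true}, {32, false}).is_signed == false); // int32 + uint32 -> uint32
  assert(promote({64, true}, {32, false}).is_signed == true);  // int64 + uint32 -> int64
  assert(promote({16, true}, {32, false}).bits == 32);         // int16 + uint32 -> uint32
  return 0;
}
```

Division and modulus between operands of different signedness are rejected with a semantic error rather than promoted, since the promoted result is rarely what the user meant.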
@@ -40,6 +40,8 @@ public:
   value *get_int1(bool val);
   value *get_int32(int32_t val);
   value *get_int64(int64_t val);
+  value *get_uint32(uint32_t val);
+  value *get_uint64(uint64_t val);
   value *get_float16(float val);
   value *get_float32(float val);
   value *get_range(int32_t lo, int32_t hi);
@@ -50,6 +52,10 @@ public:
   type *get_int16_ty();
   type *get_int32_ty();
   type *get_int64_ty();
+  type *get_uint8_ty();
+  type *get_uint16_ty();
+  type *get_uint32_ty();
+  type *get_uint64_ty();
   type *get_half_ty();
   type *get_float_ty();
   type *get_double_ty();
@@ -28,6 +28,7 @@ public:
   type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
   // integer types
   integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
+  integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
   // Pointer types
   std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
   // Block types
@@ -15,6 +15,8 @@ class value;
 class integer_type;
 class constant_int;

+enum class signedness { SIGNED, UNSIGNED };
+
 /* Type */
 class type {
 public:
@@ -58,6 +60,8 @@ public:
   // type attributes
   unsigned get_fp_mantissa_width() const;
   unsigned get_integer_bitwidth() const;
+  signedness get_integer_signedness() const;
+  bool is_integer_signed() const;
   unsigned get_tile_bitwidth() const;
   unsigned get_primitive_size_in_bits() const;
   type *get_scalar_ty() const;
@@ -80,8 +84,9 @@ public:
   bool is_metadata_ty() const { return id_ == MetadataTyID; }
   bool is_token_ty() const { return id_ == TokenTyID; }
   bool is_integer_ty() const { return id_ == IntegerTyID; }
-  bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
-                                                 get_integer_bitwidth() == bitwidth;}
+  bool is_integer_ty(unsigned bitwidth, signedness sn) {
+    return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
+  }
   bool is_bool_ty() const { return is_integer_ty(1); }
   bool is_pointer_ty() const { return id_ == PointerTyID; }
   bool is_block_ty() const { return id_ == BlockTyID; }
@@ -109,6 +114,10 @@ public:
   static integer_type *get_int32_ty(context &ctx);
   static integer_type *get_int64_ty(context &ctx);
   static integer_type *get_int128_ty(context &ctx);
+  static integer_type *get_uint8_ty(context &ctx);
+  static integer_type *get_uint16_ty(context &ctx);
+  static integer_type *get_uint32_ty(context &ctx);
+  static integer_type *get_uint64_ty(context &ctx);

   // repr
   std::string tile_repr() const {
@@ -135,7 +144,7 @@ public:
     case LabelTyID: return "label";
     case MetadataTyID: return "md";
     case TokenTyID: return "tok";
-    case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
+    case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
     case FunctionTyID: return "fn";
     case PointerTyID: return get_pointer_element_ty()->repr() + "*";
     case StructTyID: return "struct";
@@ -158,18 +167,21 @@ class integer_type: public type {

 private:
   // constructors
-  integer_type(context &ctx, unsigned bitwidth)
-    : type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
+  integer_type(context &ctx, unsigned bitwidth, signedness sn)
+    : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }

 public:
   // accessors
   unsigned get_bitwidth() const { return bitwidth_; }
+  signedness get_signedness() const { return signedness_; }

   // factory methods
   static integer_type* get(context &ctx, unsigned width);

 private:
   unsigned bitwidth_;
+  signedness signedness_;
 };

 class composite_type: public type{
@@ -51,9 +51,15 @@ value *builder::get_int1(bool val)
 value *builder::get_int32(int32_t val)
 { return constant_int::get(type::get_int32_ty(ctx_), val);}

+value *builder::get_uint32(uint32_t val)
+{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
+
 value *builder::get_int64(int64_t val)
 { return constant_int::get(type::get_int64_ty(ctx_), val);}

+value *builder::get_uint64(uint64_t val)
+{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
+
 value *builder::get_float16(float val)
 { return constant_fp::get(type::get_fp16_ty(ctx_), val); }

@@ -84,6 +90,18 @@ type *builder::get_int32_ty()
 type *builder::get_int64_ty()
 { return type::get_int64_ty(ctx_); }

+type *builder::get_uint8_ty()
+{ return type::get_uint8_ty(ctx_); }
+
+type *builder::get_uint16_ty()
+{ return type::get_uint16_ty(ctx_); }
+
+type *builder::get_uint32_ty()
+{ return type::get_uint32_ty(ctx_); }
+
+type *builder::get_uint64_ty()
+{ return type::get_uint64_ty(ctx_); }
+
 type *builder::get_half_ty()
 { return type::get_fp16_ty(ctx_); }
@@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx)
   fp32_ty(ctx, type::FP32TyID),
   fp64_ty(ctx, type::FP64TyID),
   // integers
-  int1_ty(ctx, 1),
-  int8_ty(ctx, 8),
-  int16_ty(ctx, 16),
-  int32_ty(ctx, 32),
-  int64_ty(ctx, 64),
-  int128_ty(ctx, 128){
+  int1_ty(ctx, 1, signedness::SIGNED),
+  int8_ty(ctx, 8, signedness::SIGNED),
+  int16_ty(ctx, 16, signedness::SIGNED),
+  int32_ty(ctx, 32, signedness::SIGNED),
+  int64_ty(ctx, 64, signedness::SIGNED),
+  int128_ty(ctx, 128, signedness::SIGNED),
+  uint8_ty(ctx, 8, signedness::UNSIGNED),
+  uint16_ty(ctx, 16, signedness::UNSIGNED),
+  uint32_ty(ctx, 32, signedness::UNSIGNED),
+  uint64_ty(ctx, 64, signedness::UNSIGNED){

 }
@@ -1,14 +1,12 @@
 #include "triton/ir/dispatch.h"
-#include <iostream>

-namespace triton{
-namespace ir{
+namespace triton {
+namespace ir {


-ir::value* throw_unreachable(std::string key) {
+[[ noreturn ]] void throw_unreachable(std::string key) {
   throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
                            "This is likely a bug on our side.");
-  return 0;
 }

 //===----------------------------------------------------------------------===//
@@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
 ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
   int a_rank = a_ty->get_integer_bitwidth();
   int b_rank = b_ty->get_integer_bitwidth();
+  auto a_sn = a_ty->get_integer_signedness();
+  auto b_sn = b_ty->get_integer_signedness();
+  // Rules for signedness taken from "Usual arithmetic conversions" on
+  // https://en.cppreference.com/w/c/language/conversion.
+  if (a_sn == b_sn) {
     return a_rank > b_rank ? a_ty : b_ty;
+  } else if (a_sn == signedness::UNSIGNED) {
+    return a_rank >= b_rank ? a_ty : b_ty;
+  } else if (b_sn == signedness::UNSIGNED) {
+    return b_rank >= a_rank ? b_ty : a_ty;
+  } else {
+    throw_unreachable("integer_promote");
+  }
 }

 enum class DivOrMod { NO, YES };
@@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod)
     throw_unreachable("computation_type");
   // 4 ) both operands are integer and undergo
   //    integer promotion
+  if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
+    throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
+  }
   return integer_promote(a_ty, b_ty);
 }
@@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde
   // int + int
   else if (input_scalar_ty->is_integer_ty())
     return builder->create_add(input, other);
-  return throw_unreachable("add");
+  throw_unreachable("add");
 }

 ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builde
   // int + int
   else if (scalar_ty->is_integer_ty())
     return builder->create_sub(input, other);
-  return throw_unreachable("sub");
+  throw_unreachable("sub");
 }

 ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
   // int * int
   else if (scalar_ty->is_integer_ty())
     return builder->create_mul(input, other);
-  return throw_unreachable("mul");
+  throw_unreachable("mul");
 }

 ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
   }
   // unreachable
   else
-    return throw_unreachable("div");
+    throw_unreachable("div");
   return builder->create_fdiv(input, other);
 }
@@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
     ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
     input = dispatch::cast(input, ret_ty, builder);
     other = dispatch::cast(other, ret_ty, builder);
+    if (ret_ty->is_integer_signed()) {
       return builder->create_sdiv(input, other);
+    } else {
+      return builder->create_udiv(input, other);
     }
-  return throw_unreachable("floordiv");
+  }
+  throw_unreachable("floordiv");
 }

 ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
   binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
   ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
   // float % int
   if (scalar_ty->is_floating_point_ty())
     return builder->create_frem(input, other);
   // int % int
-  else if (scalar_ty->is_integer_ty())
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
+      throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
+    }
+    if (scalar_ty->is_integer_signed()) {
       return builder->create_srem(input, other);
-  return throw_unreachable("mod");
+    } else {
+      return builder->create_urem(input, other);
+    }
+  }
+  throw_unreachable("mod");
 }
@@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder
   ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
   if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
     throw_incompatible_types(input_sca_ty, other_sca_ty);
-  if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
-    input = dispatch::cast(input, other_sca_ty, builder);
-  else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
-    other = dispatch::cast(other, input_sca_ty, builder);
+  ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
+  if (ret_sca_ty != input_sca_ty)
+    input = dispatch::cast(input, ret_sca_ty, builder);
+  if (ret_sca_ty != other_sca_ty)
+    other = dispatch::cast(other, ret_sca_ty, builder);
 }

 ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
   if (scalar_ty->is_floating_point_ty())
     return builder->create_fcmpOGT(input, other);
   // int > int
-  else if (scalar_ty->is_integer_ty())
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
       return builder->create_icmpSGT(input, other);
-  return throw_unreachable("greater_than");
+    } else {
+      return builder->create_icmpUGT(input, other);
+    }
+  }
+  throw_unreachable("greater_than");
 }

 ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
   if (scalar_ty->is_floating_point_ty())
     return builder->create_fcmpOGE(input, other);
   // int >= int
-  else if (scalar_ty->is_integer_ty())
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
       return builder->create_icmpSGE(input, other);
-  return throw_unreachable("greater_equal");
+    } else {
+      return builder->create_icmpUGE(input, other);
+    }
+  }
+  throw_unreachable("greater_equal");
 }

 ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
   if (scalar_ty->is_floating_point_ty())
     return builder->create_fcmpOLT(input, other);
   // int < int
-  else if (scalar_ty->is_integer_ty())
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
       return builder->create_icmpSLT(input, other);
-  return throw_unreachable("less_than");
+    } else {
+      return builder->create_icmpULT(input, other);
+    }
+  }
+  throw_unreachable("less_than");
 }

 ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
   if (scalar_ty->is_floating_point_ty())
     return builder->create_fcmpOLE(input, other);
   // int < int
-  else if (scalar_ty->is_integer_ty())
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
       return builder->create_icmpSLE(input, other);
-  return throw_unreachable("less_equal");
+    } else {
+      return builder->create_icmpULE(input, other);
+    }
+  }
+  throw_unreachable("less_equal");
 }

 ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
   // int == int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpEQ(input, other);
-  return throw_unreachable("equal");
+  throw_unreachable("equal");
 }

 ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
   // int == int
   else if (scalar_ty->is_integer_ty())
     return builder->create_icmpNE(input, other);
-  return throw_unreachable("equal");
+  throw_unreachable("equal");
 }

 //===----------------------------------------------------------------------===//
@@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
     return builder->create_fp_ext(input, dst_ty);
   // Int cast
   if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
-      src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
-    return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
+      (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
+       src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
+    bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
+    return builder->create_int_cast(input, dst_ty, sign_extend);
+  }
   // Float -> Int
   if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
     if(dst_sca_ty->is_bool_ty())
@@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
   }
   // int -> Float
   if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
-    if(src_sca_ty->is_bool_ty())
+    if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
       return builder->create_ui_to_fp(input, dst_ty);
     else
       return builder->create_si_to_fp(input, dst_ty);
@@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
     other = builder->create_splat(other, src_ty->get_block_shapes());
     return builder->create_icmpNE(input, other);
   }
-  return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
+  throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
 }

 //===----------------------------------------------------------------------===//
@@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
   atom_red_typechecking(ptr, val, mask, builder);
   ir::type* sca_ty = val->get_type()->get_scalar_ty();
   // direct call to atomic_max for integers
-  if(sca_ty->is_integer_ty())
+  if(sca_ty->is_integer_ty()) {
+    if (sca_ty->is_integer_signed()) {
       return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
+    } else {
+      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
+    }
+  }
   // for float
   // return atomic_smax(i_ptr, i_val) if val >= 0
   // return atomic_umin(i_ptr, i_val) if val < 0
@@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
 ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
   atom_red_typechecking(ptr, val, mask, builder);
   ir::type* sca_ty = val->get_type()->get_scalar_ty();
-  // direct call to atomic_max for integers
-  if(sca_ty->is_integer_ty())
+  // direct call to atomic_min for integers
+  if(sca_ty->is_integer_ty()) {
+    if (sca_ty->is_integer_signed()) {
       return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
+    } else {
+      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
+    }
+  }
   // for float
   // return atomic_smin(i_ptr, i_val) if val >= 0
   // return atomic_umax(i_ptr, i_val) if val < 0
@@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
     return builder->create_reduce(input, FLOAT_OP, axis);
   else if (scalar_ty->is_integer_ty())
     return builder->create_reduce(input, INT_OP, axis);
-  return throw_unreachable(name);
+  throw_unreachable(name);
 }

 ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
 unsigned type::get_integer_bitwidth() const
 { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }

+signedness type::get_integer_signedness() const
+{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
+
+bool type::is_integer_signed() const {
+  if (id_ != IntegerTyID) {
+    throw std::logic_error("type is " + repr() + ", not integer");
+  }
+  return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
+}
+
 unsigned type::get_tile_bitwidth() const
 { return ((block_type*)(this))->get_bitwidth(); }

@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
 integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
 integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
 integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
+integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
+integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
+integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
+integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }
@@ -109,6 +109,24 @@ std::string pow2_divisor(long N){
   return "1";
 }

+// Returns something like "int16", whether dtype is a torch.dtype or
+// triton.language.dtype.
+std::string dtype_cache_key_part(const py::object& dtype) {
+  if (py::hasattr(dtype, "cache_key_part")) {
+    // Presumed to be a triton.language.dtype.
+    return std::string(py::str(py::getattr(dtype, "cache_key_part")));
+  } else {
+    // Remove 'torch.' prefix from repr of torch.dtype.
+    py::object repr = py::repr(dtype);
+    size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
+    const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
+    if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
+      throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
+    }
+    return std::string(repr_ptr + 6, repr_len - 6);
+  }
+}
+
 // Launch
 void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                 std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       cache_key += "1";
       continue;
     }
-    // long and int have different kernels
-    if(!overflow & (std::abs(value) <= 0xffffffff)){
+    // int32, uint32, int64, and uint64 have different kernels
+    if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
      cache_key += "int32";
      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
-    }
-    else{
+    } else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) {
+      cache_key += "uint32";
+      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
+      std::memcpy(params_ptr, &value, 4);
+      params_ptr += 4;
+    } else if (!overflow) {
      cache_key += "int64";
      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
-      if(overflow){
-        unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
-        std::memcpy(&value, &uvalue, 8);
-      }
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
+    } else {
+      if (PyErr_Occurred()) {
+        throw std::logic_error("An error occurred?");
+      }
+      unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
+      if (PyErr_Occurred()) {
+        throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg)));
+      }
+      cache_key += "uint64";
+      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+      std::memcpy(params_ptr, &unsigned_value, 8);
+      params_ptr += 8;
     }
     if(!specialize)
       continue;
@@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
-      py::object dtype = arg.attr("dtype");
-      py::object repr = py::repr(dtype);
-      assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
-      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
-      size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
-      cache_key += std::string(start, len);
+      cache_key += dtype_cache_key_part(arg.attr("dtype"));
      cache_key += "*";
      cache_key += "[multipleof(";
      cache_key += pow2_divisor(value);
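For scalar integer arguments, the launcher now picks one of four kernel specializations based on where the Python value falls on the number line, which is what the `test_value_specialization` cases added below exercise. A hedged sketch of that bucketing, using a hypothetical `int_cache_key` helper rather than the actual `parse_args` logic:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical helper mirroring the ranges parse_args now distinguishes:
// a value is specialized as int32, uint32, int64, or uint64 depending on
// the smallest of those types that can represent it.
std::string int_cache_key(__int128 v) {
  if (v >= INT32_MIN && v <= INT32_MAX) return "int32";
  if (v >= 0 && v <= UINT32_MAX) return "uint32";
  if (v >= INT64_MIN && v <= INT64_MAX) return "int64";
  return "uint64";  // fits only as an unsigned 64-bit value
}

int main() {
  assert(int_cache_key(-1) == "int32");
  assert(int_cache_key((__int128)1 << 31) == "uint32");  // 2**31
  assert(int_cache_key((__int128)1 << 32) == "int64");   // 2**32
  assert(int_cache_key((__int128)1 << 63) == "uint64");  // 2**63
  return 0;
}
```

Values outside the uint64 range are rejected as an integer-overflow error rather than silently wrapped, as the new `test_value_specialization_overflow` test checks.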
@@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) {
     .def("get_int16", &ir::type::get_int16_ty, ret::reference)
     .def("get_int32", &ir::type::get_int32_ty, ret::reference)
     .def("get_int64", &ir::type::get_int64_ty, ret::reference)
+    .def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
+    .def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
+    .def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
+    .def("get_uint64", &ir::type::get_uint64_ty, ret::reference)

     .def("is_void", &ir::type::is_void_ty)
     .def("is_fp8", &ir::type::is_fp8_ty)
@@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) {
     .def("is_bf16", &ir::type::is_bf16_ty)
     .def("is_fp32", &ir::type::is_fp32_ty)
     .def("is_fp64", &ir::type::is_fp64_ty)
-    .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
-    .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
-    .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
-    .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
-    .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
+    .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
+    .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
+    .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
+    .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
+    .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
+    .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
+    .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
+    .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
+    .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })

     .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
     .def_property_readonly("scalar", &ir::type::get_scalar_ty)
@@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) {
     .def("get_int1", &ir::builder::get_int1, ret::reference)
     .def("get_int32", &ir::builder::get_int32, ret::reference)
     .def("get_int64", &ir::builder::get_int64, ret::reference)
+    .def("get_uint32", &ir::builder::get_uint32, ret::reference)
+    .def("get_uint64", &ir::builder::get_uint64, ret::reference)
     .def("get_float16", &ir::builder::get_float16, ret::reference)
     .def("get_float32", &ir::builder::get_float32, ret::reference)
     .def("get_range", &ir::builder::get_range, ret::reference);
@@ -1,7 +1,7 @@
 import copy
 import itertools
 import re
-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState

 import triton
 import triton.language as tl
+from triton.code_gen import TensorWrapper, reinterpret

 int_dtypes = ['int8', 'int16', 'int32', 'int64']
+uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
-dtypes = int_dtypes + float_dtypes
+dtypes = int_dtypes + uint_dtypes + float_dtypes


 def _bitwidth(dtype: str) -> int:
     # ex.: "int64" -> 64
     return int(re.search(r'(\d+)$', dtype).group(1))


-def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
+def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
     """
     Override `rs` if you're calling this function twice and don't want the same
     result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
     if rs is None:
         rs = RandomState(seed=17)
     dtype = getattr(np, dtype_str)
-    if dtype_str in int_dtypes:
+    if dtype_str in int_dtypes + uint_dtypes:
         iinfo = np.iinfo(getattr(np, dtype_str))
-        x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
+        low = iinfo.min if low is None else max(low, iinfo.min)
+        high = iinfo.max if high is None else min(high, iinfo.max)
+        x = rs.randint(low, high, shape, dtype=dtype)
         x[x == 0] = 1  # Hack. Never return zero so tests of division don't error out.
         return x
     elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
     raise RuntimeError(f'Unknown dtype {dtype_str}')


-def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
-    # For now, this always converts to a torch tensor, but when we add unsigned
-    # integers, it will also support TensorWrapper, since torch doesn't have
-    # unsigned support.
-    return torch.tensor(x, device=device)
+def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
+    t = x.dtype.name
+    if t in uint_dtypes:
+        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
+        x_signed = x.astype(getattr(np, signed_type_name))
+        return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
+    else:
+        return torch.tensor(x, device=device)
+
+
+def torch_dtype_name(dtype) -> str:
+    if isinstance(dtype, triton.language.dtype):
+        return dtype.name
+    elif isinstance(dtype, torch.dtype):
+        # 'torch.int64' -> 'int64'
+        m = re.match(r'^torch\.(\w+)$', str(dtype))
+        return m.group(1)
+    else:
+        raise TypeError(f'not a triton or torch dtype: {type(dtype)}')


 def to_numpy(x):
-    if isinstance(x, torch.Tensor):
+    if isinstance(x, TensorWrapper):
+        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
+    elif isinstance(x, torch.Tensor):
         return x.cpu().numpy()
     else:
         raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
     Given two dtype strings, returns the numpy dtype Triton thinks binary
     operations on the two types should return. Returns None if the return value
     matches numpy. This is generally needed because Triton and pytorch return
-    narrower floating point types than numpy in mixed operations.
+    narrower floating point types than numpy in mixed operations, and because
+    Triton follows C/C++ semantics around mixed signed/unsigned operations, and
+    numpy/pytorch do not.
     """
     overrides = {
         ('float16', 'int16'): np.float16,
         ('float16', 'int32'): np.float16,
         ('float16', 'int64'): np.float16,
+        ('float16', 'uint16'): np.float16,
+        ('float16', 'uint32'): np.float16,
+        ('float16', 'uint64'): np.float16,
+        ('int8', 'uint8'): np.uint8,
+        ('int8', 'uint16'): np.uint16,
+        ('int8', 'uint32'): np.uint32,
+        ('int8', 'uint64'): np.uint64,
+        ('int16', 'uint16'): np.uint16,
+        ('int16', 'uint32'): np.uint32,
+        ('int16', 'uint64'): np.uint64,
+        ('int32', 'uint32'): np.uint32,
+        ('int32', 'uint64'): np.uint64,
+        ('int64', 'uint64'): np.uint64,
     }
     key = (a, b) if a < b else (b, a)
     return overrides.get(key)


-def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
+def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
     SIZE = 128
     # define the kernel / launch-grid
     @triton.jit
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
     # inputs
     rs = RandomState(17)
     x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
-    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
+    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
     if mode_x == 'nan':
         x[:] = float('nan')
     if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
         ('int64', 'float16'),
         ('int64', 'float32'),
         ('int64', 'float64'),
+        ('uint16', 'float16'),
+        ('uint16', 'float32'),
+        ('uint32', 'float16'),
+        ('uint32', 'float32'),
+        ('uint64', 'float16'),
+        ('uint64', 'float32'),
+        ('uint64', 'float64'),
     ]

 # ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
 ])
 def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f' x {op} y'
-    if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
+    if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
         # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
         numpy_expr = 'np.fmod(x, y)'
     elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
         # are no native div or FRem operations on float16. Since we have to
         # convert anyway, we may as well take the accuracy bump.
         numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
+    elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
     else:
         numpy_expr = None
     if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
         with pytest.raises(AssertionError, match='Not equal to tolerance'):
             _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
+    elif (op in ('%', '/') and
+          ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
+           (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
+        with pytest.raises(triton.code_gen.CompilationError) as exc_info:
+            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
+        assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
     else:
         _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


+@pytest.mark.parametrize("dtype_x, dtype_y",
+                         [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
+                         [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
+                         )
+def test_floordiv(dtype_x, dtype_y, device='cuda'):
+    # Triton has IEEE, not numpy/torch, semantics for %, and those carry
+    # through to //, so we have to use a nonstandard expression to get a
+    # reference result for //.
+    expr = 'x // y'
+    numpy_expr = '((x - np.fmod(x, y)) / y)'
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


 # ---------------
 # test bitwise ops
 # ---------------
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
 ])
 def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
     expr = f'x {op} y'
+    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+    else:
+        numpy_expr = None
     if 'float' in dtype_x + dtype_y:
         with pytest.raises(triton.code_gen.CompilationError) as exc_info:
             _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
         # The CompilationError must have been caused by a C++ exception with this text.
         assert re.match('invalid operands of type', str(exc_info.value.__cause__))
     else:
-        _test_binary(dtype_x, dtype_y, expr, device=device)
+        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


+@pytest.mark.parametrize("dtype_x, dtype_y, op", [
+    (dtype_x, dtype_y, op)
+    for op in ['<<', '>>']
+    for dtype_x in int_dtypes + uint_dtypes
+    for dtype_y in int_dtypes + uint_dtypes
+])
+def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
+    expr = f'x {op} y'
+    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
+    dtype_z = f'uint{bw}'
+    numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)


 # ---------------
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
 ])
 def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
     expr = f'x {op} y'
-    _test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
+    if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
+        numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
+    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
+        numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
+    else:
+        numpy_expr = None
+    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)


 # ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
 # ---------------
 @pytest.mark.parametrize("dtype_x, expr", [
     (dtype_x, ' -x') for dtype_x in dtypes
-] + [\
+] + [
     (dtype_x, ' ~x') for dtype_x in int_dtypes
 ])
 def test_unary_op(dtype_x, expr, device='cuda'):
     _test_unary(dtype_x, expr, device=device)
@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):

 @pytest.mark.parametrize("expr, dtype_str", [
-    (f'x[{s}]', 'int32')
+    (f'x[{s}]', d)
     for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
+    for d in ['int32', 'uint32', 'uint16']
 ])
 def test_index1d(expr, dtype_str, device='cuda'):
     rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
|
|||||||
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
|
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
|
||||||
[
|
[
|
||||||
('add', 'float16', mode),
|
('add', 'float16', mode),
|
||||||
('add', 'int32', mode), ('add', 'float32', mode),
|
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
|
||||||
('max', 'int32', mode), ('max', 'float32', mode),
|
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
|
||||||
('min', 'int32', mode), ('min', 'float32', mode),
|
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
|
||||||
]
|
]
|
||||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||||
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||||
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     if exact:
         assert z_ref.item() == to_numpy(z_tri).item()
     else:
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)


 # ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
     ('float32', 'bfloat16', False),
     ('bfloat16', 'float32', False),
     ('float32', 'int32', True),
-]
-)
+] + [
+    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
+] + [
+    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
+])
 def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     x0 = 43 if dtype_x in int_dtypes else 43.5
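The new uint<->int pairs are marked bitcast=True: the cast reinterprets the underlying bits rather than converting the value. A small NumPy sketch of the round-trip behaviour these cases expect (illustrative, not part of the test file):

import numpy as np

x = np.array([-1], dtype=np.int8)
assert x.view(np.uint8)[0] == 255                 # same bits, read as unsigned
assert x.view(np.uint8).view(np.int8)[0] == -1    # and back again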
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):


 @pytest.mark.parametrize("dtype_str, shape, axis", [
-    ('float32', (1, 1024), 1)
+    (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
 ])
 def test_reduce2d(dtype_str, shape, axis, device='cuda'):
     # triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
         pass
     x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
     kernel[(1, )](x)
+
+
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
+    (2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
+
+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)
+
+    # Parse out the type of the 'VALUE' parameter from the Triton IR.
+    triton_ir = pgm.asm['ttir']
+    ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
+    ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
+    assert ir_value_type == value_type
+
+
+@pytest.mark.parametrize(
+    "value, overflow",
+    [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
+)
+def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
+
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
+
+    x = torch.tensor([3.14159], device='cuda')
+
+    if overflow:
+        with pytest.raises(RuntimeError, match='integer overflow'):
+            kernel[(1, )](value, x)
+    else:
+        kernel[(1, )](value, x)
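The expected value_type strings follow the signed/unsigned ranges that the runtime maps scalar arguments to (see the Kernel._type_name change further down); the (1, None) case expects no VALUE parameter in the IR at all, presumably because arguments equal to 1 are specialized into the kernel as constants. A standalone sketch of the mapping these cases encode (hypothetical helper, not part of the test file):

def expected_value_type(v: int):
    # Mirrors the ranges in the parametrization above.
    if -2**31 <= v < 2**31:
        return 'i32'
    if 2**31 <= v < 2**32:
        return 'u32'
    if -2**63 <= v < 2**63:
        return 'i64'
    if 2**63 <= v < 2**64:
        return 'u64'
    raise OverflowError(v)

assert expected_value_type(-2**31) == 'i32'
assert expected_value_type(2**31) == 'u32'
assert expected_value_type(2**32) == 'i64'
assert expected_value_type(2**63) == 'u64'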
@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
     N = x.numel()
     grid = (triton.cdiv(N, BLOCK),)
     kernel[grid](x, N, seed)
+    assert all((x >= 0) & (x <= 1))
     assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01

 # test normal PRNG
@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
             return triton.language.constexpr(not op)
         if isinstance(op, triton.language.core.constexpr):
            op = op.value
-        # print(op)
         fn = {
             ast.USub: '__neg__',
             ast.UAdd: '__pos__',
@@ -503,6 +502,7 @@ class Binary:
         self.shared_mem = shared_mem
         self.num_warps = num_warps

+
 class LoadedBinary:
     def __init__(self, device: int, bin: Binary):
         module, kernel = _triton.code_gen.load_binary(bin.backend,
@@ -571,24 +571,33 @@ class Kernel:
             torch.int16: 'i16',
             torch.int32: 'i32',
             torch.int64: 'i64',
+            triton.language.uint8: 'u8',
+            triton.language.uint16: 'u16',
+            triton.language.uint32: 'u32',
+            triton.language.uint64: 'u64',
         }
         if hasattr(obj, 'data_ptr'):
             return type_names[obj.dtype]
         if isinstance(obj, triton.language.core.constexpr):
             obj = obj.value
         if isinstance(obj, int):
-            if abs(obj) <= 0xffffffff:
-                return 'I'
-            return 'L'
+            if -2**31 <= obj < 2**31:
+                return 'i32'
+            elif 2**31 <= obj < 2**32:
+                return 'u32'
+            elif -2**63 <= obj < 2**63:
+                return 'i64'
+            elif 2**63 <= obj < 2**64:
+                return 'u64'
+            else:
+                raise ValueError(f'integer overflow representing {obj}')
         if isinstance(obj, float):
             return 'f'
        if isinstance(obj, bool):
             return 'B'
         if isinstance(obj, str):
             return 'str'
-        assert False
+        raise NotImplementedError(f'could not compute type name for {obj}')

     @staticmethod
     def _to_triton_ir(context, obj):
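Each of these strings feeds into the kernel cache key, so the same @triton.jit function launched with, say, 7 and 2**31 is compiled as two separate specializations, and integers outside [-2**63, 2**64) now fail loudly instead of silently truncating. A hypothetical pre-check a caller could apply before launching (not part of this diff):

def fits_triton_scalar(v: int) -> bool:
    # Mirrors the accepted range in Kernel._type_name above.
    return -2**63 <= v < 2**64

assert fits_triton_scalar(2**63)      # launches as a 'u64' argument
assert not fits_triton_scalar(2**64)  # would raise 'integer overflow representing ...'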
@@ -607,6 +616,10 @@ class Kernel:
             'i16': _triton.ir.type.get_int16,
             'i32': _triton.ir.type.get_int32,
             'i64': _triton.ir.type.get_int64,
+            'u8': _triton.ir.type.get_uint8,
+            'u16': _triton.ir.type.get_uint16,
+            'u32': _triton.ir.type.get_uint32,
+            'u64': _triton.ir.type.get_uint64,
         }
         # convert torch.Tensor to Triton IR pointers
         if hasattr(obj, 'data_ptr'):
@@ -1165,4 +1178,15 @@ class TensorWrapper:


 def reinterpret(tensor, dtype):
-    return TensorWrapper(tensor, dtype)
+    if isinstance(tensor, TensorWrapper):
+        if dtype == tensor.base.dtype:
+            # Reinterpreting to the original interpretation; return the base.
+            return tensor.base
+        else:
+            # Reinterpreting a wrapped tensor to a different type.
+            return TensorWrapper(tensor.base, dtype)
+    elif isinstance(tensor, torch.Tensor):
+        # A new wrapper is needed around an unwrapped tensor.
+        return TensorWrapper(tensor, dtype)
+    else:
+        raise TypeError(f'Cannot reinterpret a {type(tensor)}.')
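A short usage sketch of reinterpret as defined above, assuming a CUDA device and the uint dtypes added by this commit (illustrative, not part of the diff):

import torch
import triton.language as tl

x = torch.arange(4, dtype=torch.int32, device='cuda')
u = reinterpret(x, tl.uint32)             # plain tensor -> TensorWrapper viewed as uint32
v = reinterpret(u, tl.uint16)             # re-wrap the same base storage as uint16
assert reinterpret(u, torch.int32) is x   # back to the base dtype -> base tensor returned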
@@ -9,9 +9,16 @@ def _to_ir(x, builder):
     if isinstance(x, bool):
         return builder.get_int1(x)
     elif isinstance(x, int):
-        if x.__abs__() <= 2**31:
+        if -2**31 <= x < 2**31:
             return builder.get_int32(x)
-        return builder.get_int64(x)
+        elif 2**31 <= x < 2**32:
+            return builder.get_uint32(x)
+        elif -2**63 <= x < 2**63:
+            return builder.get_int64(x)
+        elif 2**63 <= x < 2**64:
+            return builder.get_uint64(x)
+        else:
+            raise RuntimeError(f'Nonrepresentable integer {x}.')
     elif isinstance(x, float):
         return builder.get_float32(x)
     elif isinstance(x, constexpr):
@@ -83,6 +90,14 @@ class dtype:
     def __str__(self):
         return self.name

+    @property
+    def cache_key_part(self) -> str:
+        """See cache_key_part() in triton.cc."""
+        return self.name
+
+    def __repr__(self):
+        return f'triton.language.{self.name}'
+

 class pointer_dtype:
     def __init__(self, element_ty):
@@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8)
 int16 = dtype(ir.type.get_int16)
 int32 = dtype(ir.type.get_int32)
 int64 = dtype(ir.type.get_int64)
+uint8 = dtype(ir.type.get_uint8)
+uint16 = dtype(ir.type.get_uint16)
+uint32 = dtype(ir.type.get_uint32)
+uint64 = dtype(ir.type.get_uint64)
 float8 = dtype(ir.type.get_fp8)
 float16 = dtype(ir.type.get_fp16)
 bfloat16 = dtype(ir.type.get_bf16)
@@ -120,6 +139,10 @@ class block:
         if ir_type.is_int16(): return int16
         if ir_type.is_int32(): return int32
         if ir_type.is_int64(): return int64
+        if ir_type.is_uint8(): return uint8
+        if ir_type.is_uint16(): return uint16
+        if ir_type.is_uint32(): return uint32
+        if ir_type.is_uint64(): return uint64
         if ir_type.is_fp8(): return float8
         if ir_type.is_fp16(): return float16
         if ir_type.is_bf16(): return bfloat16
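With the dtype table above closing the loop (IR uints map back to triton.language.uintN), unsigned buffers can round-trip through a kernel. A rough end-to-end sketch in the Triton 1.x style used by these tests; the triton.reinterpret spelling and the meta-parameter launch are assumptions for illustration only:

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(X, Z, **meta):
    off = tl.arange(0, meta['BLOCK'])
    tl.store(Z + off, tl.load(X + off))


x = torch.arange(16, dtype=torch.int32, device='cuda')
z = torch.empty_like(x)
# reinterpret() wraps the int32 storage so the kernel sees uint32 pointers (assumed export).
copy_kernel[(1,)](triton.reinterpret(x, tl.uint32),
                  triton.reinterpret(z, tl.uint32),
                  BLOCK=16)
assert torch.equal(x, z)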