uint8, uint16, uint32, and uint64 in kernels (#413)

A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
This commit is contained in:
Madeleine Thompson
2022-01-05 15:27:17 -08:00
committed by GitHub
parent d8db0308cb
commit 0ab9d67bad
12 changed files with 444 additions and 110 deletions

View File

@@ -40,6 +40,8 @@ public:
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_uint32(uint32_t val);
value *get_uint64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
@@ -50,6 +52,10 @@ public:
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_uint8_ty();
type *get_uint16_ty();
type *get_uint32_ty();
type *get_uint64_ty();
type *get_half_ty();
type *get_float_ty();
type *get_double_ty();

View File

@@ -28,6 +28,7 @@ public:
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
// Block types

View File

@@ -15,6 +15,8 @@ class value;
class integer_type;
class constant_int;
enum class signedness { SIGNED, UNSIGNED };
/* Type */
class type {
public:
@@ -58,6 +60,8 @@ public:
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const;
signedness get_integer_signedness() const;
bool is_integer_signed() const;
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
@@ -80,8 +84,9 @@ public:
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_integer_ty(unsigned bitwidth, signedness sn) {
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
@@ -109,6 +114,10 @@ public:
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
static integer_type *get_uint8_ty(context &ctx);
static integer_type *get_uint16_ty(context &ctx);
static integer_type *get_uint32_ty(context &ctx);
static integer_type *get_uint64_ty(context &ctx);
// repr
std::string tile_repr() const {
@@ -135,7 +144,7 @@ public:
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
@@ -158,18 +167,21 @@ class integer_type: public type {
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
integer_type(context &ctx, unsigned bitwidth, signedness sn)
: type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
signedness get_signedness() const { return signedness_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
signedness signedness_;
};
class composite_type: public type{

View File

@@ -51,9 +51,15 @@ value *builder::get_int1(bool val)
value *builder::get_int32(int32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
value *builder::get_uint32(uint32_t val)
{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
value *builder::get_int64(int64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
value *builder::get_uint64(uint64_t val)
{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
value *builder::get_float16(float val)
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
@@ -84,6 +90,18 @@ type *builder::get_int32_ty()
type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); }
type *builder::get_uint8_ty()
{ return type::get_uint8_ty(ctx_); }
type *builder::get_uint16_ty()
{ return type::get_uint16_ty(ctx_); }
type *builder::get_uint32_ty()
{ return type::get_uint32_ty(ctx_); }
type *builder::get_uint64_ty()
{ return type::get_uint64_ty(ctx_); }
type *builder::get_half_ty()
{ return type::get_fp16_ty(ctx_); }

View File

@@ -19,12 +19,16 @@ context_impl::context_impl(context &ctx)
fp32_ty(ctx, type::FP32TyID),
fp64_ty(ctx, type::FP64TyID),
// integers
int1_ty(ctx, 1),
int8_ty(ctx, 8),
int16_ty(ctx, 16),
int32_ty(ctx, 32),
int64_ty(ctx, 64),
int128_ty(ctx, 128){
int1_ty(ctx, 1, signedness::SIGNED),
int8_ty(ctx, 8, signedness::SIGNED),
int16_ty(ctx, 16, signedness::SIGNED),
int32_ty(ctx, 32, signedness::SIGNED),
int64_ty(ctx, 64, signedness::SIGNED),
int128_ty(ctx, 128, signedness::SIGNED),
uint8_ty(ctx, 8, signedness::UNSIGNED),
uint16_ty(ctx, 16, signedness::UNSIGNED),
uint32_ty(ctx, 32, signedness::UNSIGNED),
uint64_ty(ctx, 64, signedness::UNSIGNED){
}

View File

@@ -1,14 +1,12 @@
#include "triton/ir/dispatch.h"
#include <iostream>
namespace triton{
namespace ir{
namespace triton {
namespace ir {
ir::value* throw_unreachable(std::string key) {
[[ noreturn ]] void throw_unreachable(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
"This is likely a bug on our side.");
return 0;
}
//===----------------------------------------------------------------------===//
@@ -30,7 +28,19 @@ ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
int a_rank = a_ty->get_integer_bitwidth();
int b_rank = b_ty->get_integer_bitwidth();
return a_rank > b_rank ? a_ty : b_ty;
auto a_sn = a_ty->get_integer_signedness();
auto b_sn = b_ty->get_integer_signedness();
// Rules for signedness taken from "Usual arithmetic conversions" on
// https://en.cppreference.com/w/c/language/conversion.
if (a_sn == b_sn) {
return a_rank > b_rank ? a_ty : b_ty;
} else if (a_sn == signedness::UNSIGNED) {
return a_rank >= b_rank ? a_ty : b_ty;
} else if (b_sn == signedness::UNSIGNED) {
return b_rank >= a_rank ? b_ty : a_ty;
} else {
throw_unreachable("integer_promote");
}
}
enum class DivOrMod { NO, YES };
@@ -58,6 +68,9 @@ ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod)
throw_unreachable("computation_type");
// 4 ) both operands are integer and undergo
// integer promotion
if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
return integer_promote(a_ty, b_ty);
}
@@ -116,7 +129,7 @@ ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builde
// int + int
else if (input_scalar_ty->is_integer_ty())
return builder->create_add(input, other);
return throw_unreachable("add");
throw_unreachable("add");
}
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -131,7 +144,7 @@ ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builde
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
return throw_unreachable("sub");
throw_unreachable("sub");
}
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -143,7 +156,7 @@ ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builde
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(input, other);
return throw_unreachable("mul");
throw_unreachable("mul");
}
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -170,7 +183,7 @@ ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *bu
}
// unreachable
else
return throw_unreachable("div");
throw_unreachable("div");
return builder->create_fdiv(input, other);
}
@@ -182,21 +195,34 @@ ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *b
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
input = dispatch::cast(input, ret_ty, builder);
other = dispatch::cast(other, ret_ty, builder);
return builder->create_sdiv(input, other);
if (ret_ty->is_integer_signed()) {
return builder->create_sdiv(input, other);
} else {
return builder->create_udiv(input, other);
}
}
return throw_unreachable("floordiv");
throw_unreachable("floordiv");
}
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float % int
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(input, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(input, other);
return throw_unreachable("mod");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
if (scalar_ty->is_integer_signed()) {
return builder->create_srem(input, other);
} else {
return builder->create_urem(input, other);
}
}
throw_unreachable("mod");
}
@@ -206,10 +232,11 @@ void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty);
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
input = dispatch::cast(input, other_sca_ty, builder);
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
if (ret_sca_ty != input_sca_ty)
input = dispatch::cast(input, ret_sca_ty, builder);
if (ret_sca_ty != other_sca_ty)
other = dispatch::cast(other, ret_sca_ty, builder);
}
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -276,9 +303,14 @@ ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builde
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(input, other);
return throw_unreachable("greater_than");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGT(input, other);
} else {
return builder->create_icmpUGT(input, other);
}
}
throw_unreachable("greater_than");
}
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -288,9 +320,14 @@ ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::build
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(input, other);
return throw_unreachable("greater_equal");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGE(input, other);
} else {
return builder->create_icmpUGE(input, other);
}
}
throw_unreachable("greater_equal");
}
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -300,9 +337,14 @@ ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(input, other);
return throw_unreachable("less_than");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLT(input, other);
} else {
return builder->create_icmpULT(input, other);
}
}
throw_unreachable("less_than");
}
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -312,9 +354,14 @@ ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(input, other);
return throw_unreachable("less_equal");
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLE(input, other);
} else {
return builder->create_icmpULE(input, other);
}
}
throw_unreachable("less_equal");
}
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -326,7 +373,7 @@ ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *buil
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
return throw_unreachable("equal");
throw_unreachable("equal");
}
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
@@ -338,7 +385,7 @@ ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
return throw_unreachable("equal");
throw_unreachable("equal");
}
//===----------------------------------------------------------------------===//
@@ -461,8 +508,11 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
(src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
return builder->create_int_cast(input, dst_ty, sign_extend);
}
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
@@ -472,7 +522,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if(src_sca_ty->is_bool_ty())
if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
@@ -493,7 +543,7 @@ ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *build
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
return throw_unreachable("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
//===----------------------------------------------------------------------===//
@@ -594,8 +644,13 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty())
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
}
}
// for float
// return atomic_smax(i_ptr, i_val) if val >= 0
// return atomic_umin(i_ptr, i_val) if val < 0
@@ -611,9 +666,14 @@ ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask,
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty())
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
// direct call to atomic_min for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
}
}
// for float
// return atomic_smin(i_ptr, i_val) if val >= 0
// return atomic_umax(i_ptr, i_val) if val < 0
@@ -699,7 +759,7 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
return throw_unreachable(name);
throw_unreachable(name);
}
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {

View File

@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
signedness type::get_integer_signedness() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
bool type::is_integer_signed() const {
if (id_ != IntegerTyID) {
throw std::logic_error("type is " + repr() + ", not integer");
}
return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
}
unsigned type::get_tile_bitwidth() const
{ return ((block_type*)(this))->get_bitwidth(); }
@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }

View File

@@ -109,6 +109,24 @@ std::string pow2_divisor(long N){
return "1";
}
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object& dtype) {
if (py::hasattr(dtype, "cache_key_part")) {
// Presumed to be a triton.language.dtype.
return std::string(py::str(py::getattr(dtype, "cache_key_part")));
} else {
// Remove 'torch.' prefix from repr of torch.dtype.
py::object repr = py::repr(dtype);
size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
}
return std::string(repr_ptr + 6, repr_len - 6);
}
}
// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
@@ -136,22 +154,34 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
cache_key += "1";
continue;
}
// long and int have different kernels
if(!overflow & (std::abs(value) <= 0xffffffff)){
// int32, uint32, int64, and uint64 have different kernels
if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
cache_key += "int32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
}
else{
} else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) {
cache_key += "uint32";
params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
std::memcpy(params_ptr, &value, 4);
params_ptr += 4;
} else if (!overflow) {
cache_key += "int64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
if(overflow){
unsigned long long uvalue = PyLong_AsUnsignedLongLong(arg_ptr);
std::memcpy(&value, &uvalue, 8);
}
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
} else {
if (PyErr_Occurred()) {
throw std::logic_error("An error occurred?");
}
unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
if (PyErr_Occurred()) {
throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg)));
}
cache_key += "uint64";
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &unsigned_value, 8);
params_ptr += 8;
}
if(!specialize)
continue;
@@ -185,12 +215,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
py::object dtype = arg.attr("dtype");
py::object repr = py::repr(dtype);
assert(!strncmp((const char*)PyUnicode_1BYTE_DATA(repr.ptr()), "torch.", 6));
const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()) + 6; // remove 'torch.'
size_t len = PyUnicode_GET_LENGTH(repr.ptr()) - 6;
cache_key += std::string(start, len);
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
cache_key += pow2_divisor(value);
@@ -628,6 +653,10 @@ void init_triton_ir(py::module &&m) {
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
.def("is_void", &ir::type::is_void_ty)
.def("is_fp8", &ir::type::is_fp8_ty)
@@ -635,11 +664,15 @@ void init_triton_ir(py::module &&m) {
.def("is_bf16", &ir::type::is_bf16_ty)
.def("is_fp32", &ir::type::is_fp32_ty)
.def("is_fp64", &ir::type::is_fp64_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
@@ -703,6 +736,8 @@ void init_triton_ir(py::module &&m) {
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)
.def("get_int64", &ir::builder::get_int64, ret::reference)
.def("get_uint32", &ir::builder::get_uint32, ret::reference)
.def("get_uint64", &ir::builder::get_uint64, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference);

View File

@@ -1,7 +1,7 @@
import copy
import itertools
import re
from typing import Optional
from typing import Optional, Union
import numpy as np
import pytest
@@ -10,17 +10,20 @@ from numpy.random import RandomState
import triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + float_dtypes
dtypes = int_dtypes + uint_dtypes + float_dtypes
def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
return int(re.search(r'(\d+)$', dtype).group(1))
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
"""
Override `rs` if you're calling this function twice and don't want the same
result for both calls.
@@ -30,9 +33,11 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
if rs is None:
rs = RandomState(seed=17)
dtype = getattr(np, dtype_str)
if dtype_str in int_dtypes:
if dtype_str in int_dtypes + uint_dtypes:
iinfo = np.iinfo(getattr(np, dtype_str))
x = rs.randint(iinfo.min, iinfo.max, shape, dtype=dtype)
low = iinfo.min if low is None else max(low, iinfo.min)
high = iinfo.max if high is None else min(high, iinfo.max)
x = rs.randint(low, high, shape, dtype=dtype)
x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out.
return x
elif dtype_str in float_dtypes:
@@ -41,15 +46,31 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None):
raise RuntimeError(f'Unknown dtype {dtype_str}')
def to_triton(x: np.ndarray, device='cuda') -> torch.Tensor:
# For now, this always converts to a torch tensor, but when we add unsigned
# integers, it will also support TensorWrapper, since torch doesn't have
# unsigned support.
return torch.tensor(x, device=device)
def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]:
t = x.dtype.name
if t in uint_dtypes:
signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16"
x_signed = x.astype(getattr(np, signed_type_name))
return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
else:
return torch.tensor(x, device=device)
def torch_dtype_name(dtype) -> str:
if isinstance(dtype, triton.language.dtype):
return dtype.name
elif isinstance(dtype, torch.dtype):
# 'torch.int64' -> 'int64'
m = re.match(r'^torch\.(\w+)$', str(dtype))
return m.group(1)
else:
raise TypeError(f'not a triton or torch dtype: {type(dtype)}')
def to_numpy(x):
if isinstance(x, torch.Tensor):
if isinstance(x, TensorWrapper):
return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
elif isinstance(x, torch.Tensor):
return x.cpu().numpy()
else:
raise ValueError(f"Not a triton-compatible tensor: {x}")
@@ -103,18 +124,33 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
Given two dtype strings, returns the numpy dtype Triton thinks binary
operations on the two types should return. Returns None if the return value
matches numpy. This is generally needed because Triton and pytorch return
narrower floating point types than numpy in mixed operations.
narrower floating point types than numpy in mixed operations, and because
Triton follows C/C++ semantics around mixed signed/unsigned operations, and
numpy/pytorch do not.
"""
overrides = {
('float16', 'int16'): np.float16,
('float16', 'int32'): np.float16,
('float16', 'int64'): np.float16,
('float16', 'uint16'): np.float16,
('float16', 'uint32'): np.float16,
('float16', 'uint64'): np.float16,
('int8', 'uint8'): np.uint8,
('int8', 'uint16'): np.uint16,
('int8', 'uint32'): np.uint32,
('int8', 'uint64'): np.uint64,
('int16', 'uint16'): np.uint16,
('int16', 'uint32'): np.uint32,
('int16', 'uint64'): np.uint64,
('int32', 'uint32'): np.uint32,
('int32', 'uint64'): np.uint64,
('int64', 'uint64'): np.uint64,
}
key = (a, b) if a < b else (b, a)
return overrides.get(key)
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda'):
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None):
SIZE = 128
# define the kernel / launch-grid
@triton.jit
@@ -129,7 +165,7 @@ def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y=
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
if mode_x == 'nan':
x[:] = float('nan')
if mode_y == 'nan':
@@ -158,6 +194,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
('int64', 'float16'),
('int64', 'float32'),
('int64', 'float64'),
('uint16', 'float16'),
('uint16', 'float32'),
('uint32', 'float16'),
('uint32', 'float32'),
('uint64', 'float16'),
('uint64', 'float32'),
('uint64', 'float64'),
]
# ---------------
@@ -171,7 +214,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
])
def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
expr = f' x {op} y'
if op == '%' and dtype_x in int_dtypes and dtype_y in int_dtypes:
if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes:
# LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
numpy_expr = 'np.fmod(x, y)'
elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'):
@@ -179,15 +222,38 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
# are no native div or FRem operations on float16. Since we have to
# convert anyway, we may as well take the accuracy bump.
numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)'
elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y):
with pytest.raises(AssertionError, match='Not equal to tolerance'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
elif (op in ('%', '/') and
((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
(dtype_x in uint_dtypes and dtype_y in int_dtypes))):
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y",
[(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] +
[(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]
)
def test_floordiv(dtype_x, dtype_y, device='cuda'):
# Triton has IEEE, not numpy/torch, semantics for %, and those carry
# through to //, so we have to use a nonstandard expression to get a
# reference result for //.
expr = 'x // y'
numpy_expr = '((x - np.fmod(x, y)) / y)'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
# ---------------
# test bitwise ops
# ---------------
@@ -199,13 +265,33 @@ def test_bin_op(dtype_x, dtype_y, op, device='cuda'):
])
def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
if 'float' in dtype_x + dtype_y:
with pytest.raises(triton.code_gen.CompilationError) as exc_info:
_test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device)
# The CompilationError must have been caused by a C++ exception with this text.
assert re.match('invalid operands of type', str(exc_info.value.__cause__))
else:
_test_binary(dtype_x, dtype_y, expr, device=device)
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['<<', '>>']
for dtype_x in int_dtypes + uint_dtypes
for dtype_y in int_dtypes + uint_dtypes
])
def test_shift_op(dtype_x, dtype_y, op, device='cuda'):
expr = f'x {op} y'
bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
dtype_z = f'uint{bw}'
numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65)
# ---------------
@@ -230,7 +316,13 @@ ops = ['==', '!=', '>', '<', '>=', '<=']
])
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
expr = f'x {op} y'
_test_binary(dtype_x, dtype_y, expr, mode_x=mode_x, mode_y=mode_y, device=device)
if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})'
elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})'
else:
numpy_expr = None
_test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device)
# ---------------
@@ -238,9 +330,9 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
(dtype_x, ' -x') for dtype_x in dtypes
] + [\
] + [
(dtype_x, ' ~x') for dtype_x in int_dtypes
])
])
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
@@ -275,8 +367,9 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', 'int32')
(f'x[{s}]', d)
for s in ['None, :', ':, None', 'None, :, :', ':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
rank_x = expr.count(':')
@@ -364,9 +457,9 @@ def test_tuples():
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'int32', mode), ('add', 'float32', mode),
('max', 'int32', mode), ('max', 'float32', mode),
('min', 'int32', mode), ('min', 'float32', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
@@ -409,7 +502,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
if exact:
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.001)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
# ---------------
@@ -423,8 +516,11 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True),
]
)
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
] + [
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
@@ -487,7 +583,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, axis", [
('float32', (1, 1024), 1)
(dtype, (1, 1024), 1) for dtype in ['float32', 'uint32']
])
def test_reduce2d(dtype_str, shape, axis, device='cuda'):
# triton kernel
@@ -762,3 +858,43 @@ def test_noop(device='cuda'):
pass
x = to_triton(numpy_random((1,), dtype_str='int32'), device=device)
kernel[(1, )](x)
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31-1, 'i32'),
(2**31, 'u32'), (2**32-1, 'u32'), (2**32, 'i64'), (2**63-1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64-1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
pgm = kernel[(1, )](value, x)
# Parse out the type of the 'VALUE' parameter from the Triton IR.
triton_ir = pgm.asm['ttir']
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
assert ir_value_type == value_type
@pytest.mark.parametrize(
"value, overflow",
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
)
def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
if overflow:
with pytest.raises(RuntimeError, match='integer overflow'):
kernel[(1, )](value, x)
else:
kernel[(1, )](value, x)

View File

@@ -147,6 +147,7 @@ def test_rand(size, seed, device='cuda'):
N = x.numel()
grid = (triton.cdiv(N, BLOCK),)
kernel[grid](x, N, seed)
assert all((x >= 0) & (x <= 1))
assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01
# test normal PRNG

View File

@@ -331,7 +331,6 @@ class CodeGenerator(ast.NodeVisitor):
return triton.language.constexpr(not op)
if isinstance(op, triton.language.core.constexpr):
op = op.value
# print(op)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -503,6 +502,7 @@ class Binary:
self.shared_mem = shared_mem
self.num_warps = num_warps
class LoadedBinary:
def __init__(self, device: int, bin: Binary):
module, kernel = _triton.code_gen.load_binary(bin.backend,
@@ -571,24 +571,33 @@ class Kernel:
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
triton.language.uint8: 'u8',
triton.language.uint16: 'u16',
triton.language.uint32: 'u32',
triton.language.uint64: 'u64',
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
if isinstance(obj, triton.language.core.constexpr):
obj = obj.value
if isinstance(obj, int):
if abs(obj) <= 0xffffffff:
return 'I'
return 'L'
if -2**31 <= obj < 2**31:
return 'i32'
elif 2**31 <= obj < 2**32:
return 'u32'
elif -2**63 <= obj < 2**63:
return 'i64'
elif 2**63 <= obj < 2**64:
return 'u64'
else:
raise ValueError(f'integer overflow representing {obj}')
if isinstance(obj, float):
return 'f'
if isinstance(obj, bool):
return 'B'
if isinstance(obj, str):
return 'str'
assert False
raise NotImplementedError(f'could not compute type name for {obj}')
@staticmethod
def _to_triton_ir(context, obj):
@@ -607,6 +616,10 @@ class Kernel:
'i16': _triton.ir.type.get_int16,
'i32': _triton.ir.type.get_int32,
'i64': _triton.ir.type.get_int64,
'u8': _triton.ir.type.get_uint8,
'u16': _triton.ir.type.get_uint16,
'u32': _triton.ir.type.get_uint32,
'u64': _triton.ir.type.get_uint64,
}
# convert torch.Tensor to Triton IR pointers
if hasattr(obj, 'data_ptr'):
@@ -1165,4 +1178,15 @@ class TensorWrapper:
def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)
if isinstance(tensor, TensorWrapper):
if dtype == tensor.base.dtype:
# Reinterpreting to the original interpretation; return the base.
return tensor.base
else:
# Reinterpreting a wrapped tensor to a different type.
return TensorWrapper(tensor.base, dtype)
elif isinstance(tensor, torch.Tensor):
# A new wrapper is needed around an unwrapped tensor.
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f'Cannot reinterpret a {type(tensor)}.')

View File

@@ -9,9 +9,16 @@ def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
if x.__abs__() <= 2**31:
if -2**31 <= x < 2**31:
return builder.get_int32(x)
return builder.get_int64(x)
elif 2**31 <= x < 2**32:
return builder.get_uint32(x)
elif -2**63 <= x < 2**63:
return builder.get_int64(x)
elif 2**63 <= x < 2**64:
return builder.get_uint64(x)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
return builder.get_float32(x)
elif isinstance(x, constexpr):
@@ -83,6 +90,14 @@ class dtype:
def __str__(self):
return self.name
@property
def cache_key_part(self) -> str:
"""See cache_key_part() in triton.cc."""
return self.name
def __repr__(self):
return f'triton.language.{self.name}'
class pointer_dtype:
def __init__(self, element_ty):
@@ -102,6 +117,10 @@ int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16)
int32 = dtype(ir.type.get_int32)
int64 = dtype(ir.type.get_int64)
uint8 = dtype(ir.type.get_uint8)
uint16 = dtype(ir.type.get_uint16)
uint32 = dtype(ir.type.get_uint32)
uint64 = dtype(ir.type.get_uint64)
float8 = dtype(ir.type.get_fp8)
float16 = dtype(ir.type.get_fp16)
bfloat16 = dtype(ir.type.get_bf16)
@@ -120,6 +139,10 @@ class block:
if ir_type.is_int16(): return int16
if ir_type.is_int32(): return int32
if ir_type.is_int64(): return int64
if ir_type.is_uint8(): return uint8
if ir_type.is_uint16(): return uint16
if ir_type.is_uint32(): return uint32
if ir_type.is_uint64(): return uint64
if ir_type.is_fp8(): return float8
if ir_type.is_fp16(): return float16
if ir_type.is_bf16(): return bfloat16