diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index fe85be947..2b6bc6ab3 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -38,8 +38,10 @@ public:
   iterator get_insert_point() { return insert_point_;}
   // Constants
   value *get_int1(bool val);
-  value *get_int32(uint32_t val);
-  value *get_int64(uint64_t val);
+  value *get_int32(int32_t val);
+  value *get_int64(int64_t val);
+  value *get_uint32(uint32_t val);
+  value *get_uint64(uint64_t val);
   value *get_float16(float val);
   value *get_float32(float val);
   value *get_range(int32_t lo, int32_t hi);
@@ -50,9 +52,11 @@ public:
   type *get_int16_ty();
   type *get_int32_ty();
   type *get_int64_ty();
-  type *get_fp8_ty();
+  type *get_uint8_ty();
+  type *get_uint16_ty();
+  type *get_uint32_ty();
+  type *get_uint64_ty();
   type *get_half_ty();
-  type *get_bf16_ty();
   type *get_float_ty();
   type *get_double_ty();
   // Insert
@@ -70,9 +74,7 @@ public:
   value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
   value* create_ret_void();
   // Cast instructions
-  value* create_bitcast(value *src, type *dest_ty);
   value *create_cast(cast_op_t op, value *v, type *dst_ty);
-  value* create_int_to_ptr(value *src, type *dst_ty);
   value* create_ptr_to_int(value *src, type *dst_ty);
   value* create_si_to_fp(value *src, type *dst_ty);
   value* create_ui_to_fp(value *src, type *dst_ty);
@@ -91,11 +93,11 @@ public:
   value *create_frem(value *lhs, value *rhs);
   value *create_fadd(value *lhs, value *rhs);
   value *create_fsub(value *lhs, value *rhs);
+  value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_sdiv(value *lhs, value *rhs);
   value *create_udiv(value *lhs, value *rhs);
   value *create_srem(value *lhs, value *rhs);
   value *create_urem(value *lhs, value *rhs);
-  value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -143,22 +145,11 @@ public:
   value *create_reshape(value *arg, const type::block_shapes_t &shapes);
   value *create_cat(value *lhs, value *rhs);
   value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
-  // Atomic instruction
-  value *create_atomic_cas(value *ptr, value *cmp, value *val);
-  value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
-  value *create_atomic_max(value *ptr, value *val, value *msk);
-  value *create_atomic_umax(value *ptr, value *val, value *msk);
-  value *create_atomic_min(value *ptr, value *val, value *msk);
-  value *create_atomic_umin(value *ptr, value *val, value *msk);
-  value *create_atomic_fadd(value *ptr, value *val, value *msk);
-  value *create_atomic_add(value *ptr, value *val, value *msk);
-  value *create_atomic_and(value *ptr, value *val, value *msk);
-  value *create_atomic_or(value *ptr, value *val, value *msk);
-  value *create_atomic_xor(value *ptr, value *val, value *msk);
-  value *create_atomic_xchg(value *ptr, value *val, value *msk);
   // Built-in instruction
   value *create_get_program_id(unsigned axis);
   value *create_get_num_programs(unsigned axis);
+  value *create_atomic_cas(value *ptr, value *cmp, value *val);
+  value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
   value *create_exp(value* arg);
   value *create_cos(value* arg);
   value *create_sin(value* arg);
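For orientation, a minimal sketch of how the reworked constant and type helpers would be used from C++ (the `ctx`/`bld` names are hypothetical, not part of this diff):

    // assumes: triton::ir::context ctx; triton::ir::builder bld(ctx);
    ir::value *a  = bld.get_int32(-1);            // i32 constant; the setter now takes int32_t
    ir::value *b  = bld.get_uint32(0xFFFFFFFFu);  // u32 constant via the new helper
    ir::type  *u8 = bld.get_uint8_ty();           // unsigned type accessors replace the removed fp8/bf16 ones here

The same bit pattern now carries its signedness in the type, which is what lets the dispatch layer introduced below pick signed versus unsigned instructions.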
diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h
index ef20af6b7..081ea249d 100644
--- a/include/triton/ir/context_impl.h
+++ b/include/triton/ir/context_impl.h
@@ -26,6 +26,7 @@ public:
   type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
   // integer types
   integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
+  integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
   // Pointer types
   std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
   // Block types
diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h
new file mode 100644
index 000000000..ef14043dd
--- /dev/null
+++ b/include/triton/ir/dispatch.h
@@ -0,0 +1,113 @@
+#pragma once
+
+#ifndef _TRITON_IR_DISPATCH_H_
+#define _TRITON_IR_DISPATCH_H_
+
+#include "triton/ir/builder.h"
+#include <stdexcept>
+
+namespace triton{
+namespace ir{
+
+
+/*----------------------------------------------
+  higher level functions that follow the likely
+  semantics of most expected frontends
+  ----------------------------------------------*/
+
+struct semantic_error: public std::runtime_error {
+  semantic_error(const std::string& msg):
+    std::runtime_error(msg) { }
+};
+
+struct dispatch{
+  typedef ir::type::block_shapes_t shape_t;
+
+  // programming model
+  static ir::value *program_id(int axis, ir::builder *builder);
+  static ir::value *num_programs(int axis, ir::builder *builder);
+
+  // binary operators
+  static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
+  static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
+
+  // unary operators
+  static ir::value *plus(ir::value *input, ir::builder *builder);
+  static ir::value *minus(ir::value *input, ir::builder *builder);
+  static ir::value *invert(ir::value *input, ir::builder *builder);
+
+  // comparison operators
+  static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
+  static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
+
+  // block creation
+  static ir::value* arange(int start, int end, ir::builder *builder);
+  static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
+
+  // casting ops
+  static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
+  static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
+  static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
+  static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
+  static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
+  static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
+
+  // memory operators
+  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
+                         const std::string& eviction_policy, int is_volatile, ir::builder *builder);
+  static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
+  static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
+  static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+  static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
+
+  // linear algebra
+  static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
+
+  // indexing
+  static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
+
+  // reduction
+  static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
+  static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
+  static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
+  static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
+
+  // math
+  static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
+  static ir::value *exp(ir::value *x, ir::builder *builder);
+  static ir::value *log(ir::value *x, ir::builder *builder);
+  static ir::value *cos(ir::value *x, ir::builder *builder);
+  static ir::value *sin(ir::value *x, ir::builder *builder);
+  static ir::value *sqrt(ir::value *x, ir::builder *builder);
+
+  // internal (debug/optimization)
+  static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
+  static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
+  static ir::value *debug_barrier(ir::builder *builder);
+};
+
+}
+}
+
+#endif
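A hedged sketch of what this header enables (identifiers hypothetical): instead of picking builder instructions by hand, a frontend can emit, say, `a / b` through the dispatch layer and get broadcasting, type promotion, and signedness-aware instruction selection for free:

    // assumes: ir::value *a, *b; ir::builder *bld;
    ir::value *q = ir::dispatch::truediv(a, b, bld);  // int/int operands are cast to fp32, then fdiv
    ir::value *r = ir::dispatch::mod(a, b, bld);      // lowers to srem or urem depending on signedness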
diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h
index ea64dfc6e..30881fd49 100644
--- a/include/triton/ir/module.h
+++ b/include/triton/ir/module.h
@@ -57,10 +57,26 @@ private:
   void push_function(function *fn) { functions_.push_back(fn); }

 public:
-  module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
-  builder &get_builder() { return builder_; };
-  const std::string& get_name() { return name_; };
+  module(const std::string &name, builder& builder);
+  builder& get_builder();
+  // Setters
+  void set_value(const std::string& name, basic_block* block, value *x);
+  void set_value(const std::string& name, value* x);
+  void set_const(const std::string& name);
+  void set_continue_fn(std::function<ir::value*()> fn);
+  // Getters
+  const std::map<val_key_t, value*>& get_values() { return values_; }
+  const std::map<std::string, type*>& get_types() { return types_; }
+  void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
+  void set_types(const std::map<std::string, type*>& types) { types_ = types; }
+  value *get_value(const std::string& name, basic_block* block);
+  value *get_value(const std::string& name);
+  void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
+  const std::string& get_name();
+  std::function<ir::value*()> get_continue_fn();
+  // Seal block -- no more predecessors will be added
+  void seal_block(basic_block *block);
   // Functions
   const functions_list_t &get_function_list() const { return functions_; }
   functions_list_t &get_function_list() { return functions_; }
@@ -73,14 +89,21 @@ public:
   const std::map<std::string, value*>& globals() const { return globals_; }
   // Metadata
   void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
-  const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
+  void print(std::ostream &os);

 private:
   std::string name_;
-  builder &builder_;
+  builder& builder_;
+  std::map<val_key_t, value*> values_;
+  std::map<std::string, type*> types_;
+  std::set<std::string> const_;
+  std::set<basic_block*> sealed_blocks_;
+  std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
   functions_list_t functions_;
   symbols_map_t symbols_;
+  std::function<ir::value*()> continue_fn_;
+  std::map<value*, value**> current_phi_;
   std::vector<alloc_const*> allocs_;
   std::map<std::string, value*> globals_;
   std::map<std::string, md_pair_t> metadatas_;
diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h
index b1ef1ad22..47c9b5f85 100644
--- a/include/triton/ir/type.h
+++ b/include/triton/ir/type.h
@@ -16,6 +16,8 @@ class value;
 class integer_type;
 class constant_int;

+enum class signedness { SIGNED, UNSIGNED };
+
 /* Type */
 class type {
 public:
@@ -59,6 +61,8 @@ public:
   // type attributes
   unsigned get_fp_mantissa_width() const;
   unsigned get_integer_bitwidth() const;
+  signedness get_integer_signedness() const;
+  bool is_integer_signed() const;
   unsigned get_tile_bitwidth() const;
   unsigned get_primitive_size_in_bits() const;
   type *get_scalar_ty() const;
@@ -81,6 +85,9 @@ public:
   bool is_metadata_ty() const { return id_ == MetadataTyID; }
   bool is_token_ty() const { return id_ == TokenTyID; }
   bool is_integer_ty() const { return id_ == IntegerTyID; }
+  bool is_integer_ty(unsigned bitwidth, signedness sn) {
+    return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
+  }
   bool is_bool_ty() const { return is_integer_ty(1); }
   bool is_pointer_ty() const { return id_ == PointerTyID; }
   bool is_block_ty() const { return id_ == BlockTyID; }
@@ -108,6 +115,10 @@ public:
   static integer_type *get_int32_ty(context &ctx);
   static integer_type *get_int64_ty(context &ctx);
   static integer_type *get_int128_ty(context &ctx);
+  static integer_type *get_uint8_ty(context &ctx);
+  static integer_type *get_uint16_ty(context &ctx);
+  static integer_type *get_uint32_ty(context &ctx);
+  static integer_type *get_uint64_ty(context &ctx);

   // repr
   std::string tile_repr() const {
@@ -134,7 +145,7 @@ public:
     case LabelTyID: return "label";
     case MetadataTyID: return "md";
     case TokenTyID: return "tok";
-    case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
+    case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
     case FunctionTyID: return "fn";
     case PointerTyID: return get_pointer_element_ty()->repr() + "*";
     case StructTyID: return "struct";
@@ -157,18 +168,21 @@ class integer_type: public type {
 private:
   // constructors
-  integer_type(context &ctx, unsigned bitwidth)
-    : type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
+  integer_type(context &ctx, unsigned bitwidth, signedness sn)
+    : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }

 public:
   // accessors
   unsigned get_bitwidth() const { return bitwidth_; }
+  signedness get_signedness() const { return signedness_; }
+
   // factory methods
   static integer_type* get(context &ctx, unsigned width);

 private:
   unsigned bitwidth_;
+  signedness signedness_;
 };

 class composite_type: public type{
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index 9b8a2a45e..fff73e665 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -48,12 +48,18 @@ void builder::set_insert_point(basic_block *block){
 value *builder::get_int1(bool val)
 { return constant_int::get(type::get_int1_ty(ctx_), val); }

-value *builder::get_int32(uint32_t val)
+value *builder::get_int32(int32_t val)
 { return constant_int::get(type::get_int32_ty(ctx_), val);}

+value *builder::get_uint32(uint32_t val)
+{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
+
-value *builder::get_int64(uint64_t val)
+value *builder::get_int64(int64_t val)
 { return constant_int::get(type::get_int64_ty(ctx_), val);}

+value *builder::get_uint64(uint64_t val)
+{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
+
 value *builder::get_float16(float val)
 { return constant_fp::get(type::get_fp16_ty(ctx_), val); }
@@ -84,15 +90,21 @@ type *builder::get_int32_ty()
 type *builder::get_int64_ty()
 { return type::get_int64_ty(ctx_); }

-type *builder::get_fp8_ty()
-{ return type::get_fp8_ty(ctx_); }
+type *builder::get_uint8_ty()
+{ return type::get_uint8_ty(ctx_); }
+
+type *builder::get_uint16_ty()
+{ return type::get_uint16_ty(ctx_); }
+
+type *builder::get_uint32_ty()
+{ return type::get_uint32_ty(ctx_); }
+
+type *builder::get_uint64_ty()
+{ return type::get_uint64_ty(ctx_); }

 type *builder::get_half_ty()
 { return type::get_fp16_ty(ctx_); }

-type *builder::get_bf16_ty()
-{ return type::get_bf16_ty(ctx_); }
-
 type *builder::get_float_ty()
 { return type::get_fp32_ty(ctx_); }
@@ -127,8 +139,6 @@ value *builder::create_ret_void() {
     return create_cast(OPCODE, src, dst_ty);\
   }

-DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
-DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
 DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
 DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
 DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
@@ -321,28 +331,6 @@ value *builder::create_downcast(value *arg) {
   return insert(downcast_inst::create(arg));
 }

-//
-
-value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
-  return insert(atomic_rmw_inst::create(op, ptr, val, msk));
-}
-
-#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
-  value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
-    return create_atomic_rmw(OPCODE, ptr, val, mask);\
-  }
-
-DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
-DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
-DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
-DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
-DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
-DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
-DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
-DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
-DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
-DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
-
 //===----------------------------------------------------------------------===//
 //                               built-in instructions
 //===----------------------------------------------------------------------===//
@@ -359,6 +347,9 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
   return insert(atomic_cas_inst::create(ptr, cmp, val));
 }

+value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
+  return insert(atomic_rmw_inst::create(op, ptr, val, msk));
+}

 value *builder::create_exp(value *arg){
   return insert(exp_inst::create(arg));
diff --git a/lib/ir/context.cc b/lib/ir/context.cc
index 0fc65ddc2..90b109b9b 100644
--- a/lib/ir/context.cc
+++ b/lib/ir/context.cc
@@ -19,12 +19,18 @@ context_impl::context_impl(context &ctx)
     fp32_ty(ctx, type::FP32TyID),
     fp64_ty(ctx, type::FP64TyID),
     // integers
-    int1_ty(ctx, 1),
-    int8_ty(ctx, 8),
-    int16_ty(ctx, 16),
-    int32_ty(ctx, 32),
-    int64_ty(ctx, 64),
-    int128_ty(ctx, 128) {}
+    int1_ty(ctx, 1, signedness::SIGNED),
+    int8_ty(ctx, 8, signedness::SIGNED),
+    int16_ty(ctx, 16, signedness::SIGNED),
+    int32_ty(ctx, 32, signedness::SIGNED),
+    int64_ty(ctx, 64, signedness::SIGNED),
+    int128_ty(ctx, 128, signedness::SIGNED),
+    uint8_ty(ctx, 8, signedness::UNSIGNED),
+    uint16_ty(ctx, 16, signedness::UNSIGNED),
+    uint32_ty(ctx, 32, signedness::UNSIGNED),
+    uint64_ty(ctx, 64, signedness::UNSIGNED){
+
+}

 //===----------------------------------------------------------------------===//
 //                                    context
diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc
new file mode 100644
index 000000000..664fbb983
--- /dev/null
+++ b/lib/ir/dispatch.cc
@@ -0,0 +1,882 @@
+#include "triton/ir/dispatch.h"
+
+namespace triton {
+namespace ir {
+
+
+[[ noreturn ]] void throw_unreachable(std::string key) {
+  throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
+                           "This is likely a bug on our side.");
+}
+
+//===----------------------------------------------------------------------===//
+//                              Programming Model
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::program_id(int axis, ir::builder *builder) {
+  return builder->create_get_program_id(axis);
+}
+
+ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
+  return builder->create_get_num_programs(axis);
+}
+
+//===----------------------------------------------------------------------===//
+//                          Implicit Casting Utilities
+//===----------------------------------------------------------------------===//
+
+ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
+  int a_rank = a_ty->get_integer_bitwidth();
+  int b_rank = b_ty->get_integer_bitwidth();
+  auto a_sn = a_ty->get_integer_signedness();
+  auto b_sn = b_ty->get_integer_signedness();
+  // Rules for signedness taken from "Usual arithmetic conversions" on
+  // https://en.cppreference.com/w/c/language/conversion.
+  if (a_sn == b_sn) {
+    return a_rank > b_rank ? a_ty : b_ty;
+  } else if (a_sn == signedness::UNSIGNED) {
+    return a_rank >= b_rank ? a_ty : b_ty;
+  } else if (b_sn == signedness::UNSIGNED) {
+    return b_rank >= a_rank ? b_ty : a_ty;
+  } else {
+    throw_unreachable("integer_promote");
+  }
+}
+
+enum class DivOrMod { NO, YES };
+
+ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
+  context &ctx = a_ty->get_context();
+  // 1) if one operand is double, the other is implicitly
+  //    converted to double
+  if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
+    return type::get_fp64_ty(ctx);
+  // 2) if one operand is float, the other is implicitly
+  //    converted to float
+  if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
+    return type::get_fp32_ty(ctx);
+  // 3) if one operand is half, the other is implicitly converted to half,
+  //    unless we're doing / or %, which do not exist natively in PTX for fp16.
+  if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
+    if (div_or_mod == DivOrMod::YES) {
+      return type::get_fp32_ty(ctx);
+    } else {
+      return type::get_fp16_ty(ctx);
+    }
+  }
+  if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
+    throw_unreachable("computation_type");
+  // 4) both operands are integer and undergo
+  //    integer promotion
+  if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
+    throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() +
+                         " because they have different signedness; this is unlikely to result in a useful answer. "
+                         "Cast them to the same signedness.");
+  }
+  return integer_promote(a_ty, b_ty);
+}
+
+//===----------------------------------------------------------------------===//
+//                               Binary Operators
+//===----------------------------------------------------------------------===//
+
+void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
+  throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
+}
+
+void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
+  if(type_a->is_pointer_ty()){
+    if(!allow_ptr_a)
+      throw_incompatible_types(type_a, type_b);
+    // T* + U* with T != U
+    if(type_b->is_pointer_ty() && (type_a != type_b))
+      throw_incompatible_types(type_a, type_b);
+    // T* + float
+    if(type_b->is_floating_point_ty())
+      throw_incompatible_types(type_a, type_b);
+  }
+}
+
+void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
+                             bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
+                             bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
+  // implicit broadcasting
+  std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
+  // implicit typecasting
+  ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
+  ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
+  check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
+  check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
+  if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
+    ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
+    lhs = dispatch::cast(lhs, ret_sca_ty, builder);
+    rhs = dispatch::cast(rhs, ret_sca_ty, builder);
+  }
+}
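To make the promotion rules concrete, a few outcomes implied by the code above (inferred from `integer_promote`, not from tests in this diff):

    // same signedness: the higher rank wins
    //   integer_promote(i32, i64) == i64,   integer_promote(u8, u32) == u32
    // mixed signedness: unsigned wins at equal-or-higher rank
    //   integer_promote(i32, u32) == u32,   integer_promote(i16, u32) == u32
    // signed wins only with strictly higher rank
    //   integer_promote(i64, u32) == i64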
+
+ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder, true, true);
+  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
+  // offset + ptr
+  // ptr + offset
+  if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty())
+    std::swap(input, other);
+  if (input_scalar_ty->is_pointer_ty())
+    return builder->create_gep(input, {other});
+  // float + float
+  else if (input_scalar_ty->is_floating_point_ty())
+    return builder->create_fadd(input, other);
+  // int + int
+  else if (input_scalar_ty->is_integer_ty())
+    return builder->create_add(input, other);
+  throw_unreachable("add");
+}
+
+ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder, true, false);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // ptr - offset
+  if (scalar_ty->is_pointer_ty())
+    return builder->create_gep(input, {dispatch::minus(other, builder)});
+  // float - float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fsub(input, other);
+  // int - int
+  else if (scalar_ty->is_integer_ty())
+    return builder->create_sub(input, other);
+  throw_unreachable("sub");
+}
+
+ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float * float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fmul(input, other);
+  // int * int
+  else if (scalar_ty->is_integer_ty())
+    return builder->create_mul(input, other);
+  throw_unreachable("mul");
+}
+
+ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
+  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
+  // float / int
+  if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
+    other = cast(other, input_scalar_ty, builder);
+  // int / float
+  else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
+    input = cast(input, other_scalar_ty, builder);
+  // int / int (cast to float32)
+  else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
+    input = cast(input, builder->get_float_ty(), builder);
+    other = cast(other, builder->get_float_ty(), builder);
+  }
+  // float / float (cast to highest exponent type)
+  else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
+    if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
+      other = cast(other, input_scalar_ty, builder);
+    else
+      input = cast(input, other_scalar_ty, builder);
+  }
+  // unreachable
+  else
+    throw_unreachable("div");
+  return builder->create_fdiv(input, other);
+}
+
+ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
+  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
+  if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
+    ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
+    input = dispatch::cast(input, ret_ty, builder);
+    other = dispatch::cast(other, ret_ty, builder);
+    if (ret_ty->is_integer_signed()) {
+      return builder->create_sdiv(input, other);
+    } else {
+      return builder->create_udiv(input, other);
+    }
+  }
+  throw_unreachable("floordiv");
+}
+
+ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
+  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
+  if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
+    throw semantic_error("both operands of fdiv must have floating point scalar type");
+  binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
+  ir::value* ret = builder->create_fdiv(input, other);
+  if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret))
+    binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
+  return ret;
+}
+
+ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
+  // float % float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_frem(input, other);
+  // int % int
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
+      throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() +
+                           " because they have different signedness; this is unlikely to result in a useful answer. "
+                           "Cast them to the same signedness.");
+    }
+    if (scalar_ty->is_integer_signed()) {
+      return builder->create_srem(input, other);
+    } else {
+      return builder->create_urem(input, other);
+    }
+  }
+  throw_unreachable("mod");
+}
+
+
+void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder, false, false, false);
+  ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
+  ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
+  if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
+    throw_incompatible_types(input_sca_ty, other_sca_ty);
+  ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
+  if (ret_sca_ty != input_sca_ty)
+    input = dispatch::cast(input, ret_sca_ty, builder);
+  if (ret_sca_ty != other_sca_ty)
+    other = dispatch::cast(other, ret_sca_ty, builder);
+}
+
+ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
+  bitwise_op_type_checking(input, other, builder);
+  return builder->create_and(input, other);
+}
+
+ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
+  bitwise_op_type_checking(input, other, builder);
+  return builder->create_or(input, other);
+}
+
+ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
+  bitwise_op_type_checking(input, other, builder);
+  return builder->create_xor(input, other);
+}
+
+ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
+  bitwise_op_type_checking(input, other, builder);
+  return builder->create_lshr(input, other);
+}
+
+ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
+  bitwise_op_type_checking(input, other, builder);
+  return builder->create_shl(input, other);
+}
+
+//===----------------------------------------------------------------------===//
+//                               Unary Operators
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::plus(ir::value *input, ir::builder *) {
+  return input;
+}
+
+ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
+  ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
+  if(input_sca_ty->is_pointer_ty())
+    throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")");
+  ir::value *_0 = ir::constant::get_null_value(input_sca_ty);
+  return dispatch::sub(_0, input, builder);
+}
+
+ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
+  ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
+  if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty())
+    throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")");
+  ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty);
+  return dispatch::xor_(input, _1, builder);
+}
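Both unary operators lower onto the binary paths above, so they inherit the same promotion and broadcasting behavior (illustrative):

    // minus(x)  is emitted as  sub(0, x)    with 0  = constant::get_null_value(type of x)
    // invert(x) is emitted as  xor_(x, -1)  with -1 = constant::get_all_ones_value(type of x)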
+
+//===----------------------------------------------------------------------===//
+//                             Comparison Operators
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float > float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fcmpOGT(input, other);
+  // int > int
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
+      return builder->create_icmpSGT(input, other);
+    } else {
+      return builder->create_icmpUGT(input, other);
+    }
+  }
+  throw_unreachable("greater_than");
+}
+
+ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float >= float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fcmpOGE(input, other);
+  // int >= int
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
+      return builder->create_icmpSGE(input, other);
+    } else {
+      return builder->create_icmpUGE(input, other);
+    }
+  }
+  throw_unreachable("greater_equal");
+}
+
+ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float < float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fcmpOLT(input, other);
+  // int < int
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
+      return builder->create_icmpSLT(input, other);
+    } else {
+      return builder->create_icmpULT(input, other);
+    }
+  }
+  throw_unreachable("less_than");
+}
+
+ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float <= float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fcmpOLE(input, other);
+  // int <= int
+  else if (scalar_ty->is_integer_ty()) {
+    if (scalar_ty->is_integer_signed()) {
+      return builder->create_icmpSLE(input, other);
+    } else {
+      return builder->create_icmpULE(input, other);
+    }
+  }
+  throw_unreachable("less_equal");
+}
+
+ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float == float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fcmpOEQ(input, other);
+  // int == int
+  else if (scalar_ty->is_integer_ty())
+    return builder->create_icmpEQ(input, other);
+  throw_unreachable("equal");
+}
+
+ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
+  binary_op_type_checking(input, other, builder);
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // float != float
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_fcmpUNE(input, other);
+  // int != int
+  else if (scalar_ty->is_integer_ty())
+    return builder->create_icmpNE(input, other);
+  throw_unreachable("not_equal");
+}
+
+//===----------------------------------------------------------------------===//
+//                               Block Creation
+//===----------------------------------------------------------------------===//
+
+ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
+  return builder->get_range(start, end);
+}
+
+ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
+  ir::value *_0 = ir::constant::get_null_value(dtype);
+  return builder->create_splat(_0, shape);
+}
+
+//===----------------------------------------------------------------------===//
+//                             Shape Manipulation
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
+  unsigned numel = 1;
+  for(unsigned s: dst_shape) numel *= s;
+  if(input->get_type()->get_tile_num_elements() != numel)
+    throw semantic_error("cannot reshape block of different shape");
+  return builder->create_reshape(input, dst_shape);
+}
+
+ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
+  return builder->create_cat(lhs, rhs);
+}
+
+ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
+  if (!input->get_type()->is_block_ty())
+    return builder->create_splat(input, shape);
+  auto src_shape = input->get_type()->get_block_shapes();
+  if (src_shape.size() != shape.size())
+    throw std::runtime_error("Cannot broadcast");
+  if(shape == src_shape)
+    return input;
+  return builder->create_broadcast(input, shape);
+}
+
+std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
+  ir::type *lhs_ty = lhs->get_type();
+  ir::type *rhs_ty = rhs->get_type();
+
+  // make_shape_compatible(block, scalar)
+  if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
+    rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
+  // make_shape_compatible(scalar, block)
+  else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
+    lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
+  // make_shape_compatible(block, block)
+  else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
+    auto lhs_shape = lhs_ty->get_block_shapes();
+    auto rhs_shape = rhs_ty->get_block_shapes();
+    if (lhs_shape.size() != rhs_shape.size())
+      throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
+    ir::type::block_shapes_t ret_shape;
+    for (size_t i = 0; i < lhs_shape.size(); ++i) {
+      unsigned left = lhs_shape[i];
+      unsigned right = rhs_shape[i];
+      if (left == 1)
+        ret_shape.push_back(right);
+      else if (right == 1)
+        ret_shape.push_back(left);
+      else if (left == right)
+        ret_shape.push_back(left);
+      else
+        throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
+                                 ": " + std::to_string(left) + " and " + std::to_string(right));
+    }
+    if (lhs_shape != ret_shape)
+      lhs = builder->create_broadcast(lhs, ret_shape);
+    if (rhs_shape != ret_shape)
+      rhs = builder->create_broadcast(rhs, ret_shape);
+  }
+  return std::make_tuple(lhs, rhs);
+}
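The pairwise overload follows NumPy-style broadcasting; for example (shapes illustrative):

    // [128, 1] op [1, 64]   -> both sides broadcast to [128, 64]
    // scalar   op [128]     -> the scalar is splatted to [128]
    // [128, 2] op [128, 3]  -> error: incompatible dimensions at index 1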
+
+ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
+  ir::type *src_ty = input->get_type();
+  if (src_ty->is_block_ty())
+    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
+  if(src_ty == dst_ty)
+    return input;
+  ir::type *src_sca_ty = src_ty->get_scalar_ty();
+  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
+  if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
+    return cast(input, dst_ty, builder);
+  // Bitcast
+  int src_bits = src_sca_ty->get_primitive_size_in_bits();
+  int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
+  if(src_bits != dst_bits)
+    throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
+                             " to data-type of size " + std::to_string(dst_bits));
+  return builder->create_cast(ir::cast_op_t::BitCast, input, dst_ty);
+}
+
+ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
+  ir::type *src_ty = input->get_type();
+  if (src_ty->is_block_ty())
+    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
+  if(src_ty == dst_ty)
+    return input;
+  ir::type *src_sca_ty = src_ty->get_scalar_ty();
+  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
+  // bf16 casts are routed through fp32
+  if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
+     (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
+    return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
+  }
+  // FP Truncation
+  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
+                     dst_sca_ty->is_floating_point_ty() &&
+                     src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
+  if (truncate_fp)
+    return builder->create_fp_trunc(input, dst_ty);
+  // FP Extension
+  bool ext_fp = src_sca_ty->is_floating_point_ty() &&
+                dst_sca_ty->is_floating_point_ty() &&
+                src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
+  if (ext_fp)
+    return builder->create_fp_ext(input, dst_ty);
+  // Int cast
+  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
+      (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
+       src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
+    bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
+    return builder->create_int_cast(input, dst_ty, sign_extend);
+  }
+  // Float -> Int
+  if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
+    if(dst_sca_ty->is_bool_ty())
+      return builder->create_fp_to_ui(input, dst_ty);
+    else
+      return builder->create_fp_to_si(input, dst_ty);
+  }
+  // Int -> Float
+  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
+    if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
+      return builder->create_ui_to_fp(input, dst_ty);
+    else
+      return builder->create_si_to_fp(input, dst_ty);
+  }
+  // Ptr -> Int
+  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){
+    int bitwidth = dst_sca_ty->get_integer_bitwidth();
+    if(bitwidth == 64)
+      return builder->create_cast(ir::cast_op_t::PtrToInt, input, dst_ty);
+    if(bitwidth == 1)
+      return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder),
+                                 builder->get_int64(0),
+                                 builder);
+  }
+  // Int -> Ptr
+  if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
+    return builder->create_cast(ir::cast_op_t::IntToPtr, input, dst_ty);
+  // Ptr -> Ptr
+  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
+    return builder->create_cast(ir::cast_op_t::BitCast, input, dst_ty);
+  // * -> Bool
+  if (dst_sca_ty->is_bool_ty()) {
+    if (src_sca_ty->is_pointer_ty())
+      input = cast(input, builder->get_int64_ty(), builder);
+    ir::value *other = builder->get_int64(0);
+    if (src_ty->is_block_ty())
+      other = builder->create_splat(other, src_ty->get_block_shapes());
+    return builder->create_icmpNE(input, other);
+  }
+  throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
+}
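A few consequences of the branch ordering above, spelled out (derived by reading the code, not an authoritative spec):

    // u8  -> i32 : zero-extended  (source unsigned => sign_extend == false)
    // i8  -> i32 : sign-extended
    // i32 -> u32 : same width     => create_int_cast degenerates to a bitcast
    // T*  -> i1  : becomes  ptr_to_int(p, i64) != 0
    // f32 -> i1  : fp_to_ui; every other float->int conversion uses fp_to_si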
+
+//===----------------------------------------------------------------------===//
+//                               Memory Operators
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier,
+                          const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
+  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
+    throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
+  if(ptr->get_type()->is_block_ty()){
+    if(mask)
+      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
+    if(other)
+      other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
+  }
+  if(other)
+    other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
+  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
+  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
+  // treat bool* as int8*
+  if(elt_ty == builder->get_int1_ty()){
+    elt_ty = builder->get_int8_ty();
+    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
+    ptr = dispatch::cast(ptr, ptr_ty, builder);
+  }
+  // cache modifier
+  load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
+  if (!cache_modifier.empty()) {
+    if (cache_modifier == ".ca")
+      cache = load_inst::CA;
+    else if (cache_modifier == ".cg")
+      cache = load_inst::CG;
+    else
+      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
+  }
+  // eviction policy
+  load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; // default
+  if(!eviction_policy.empty()){
+    if (eviction_policy == "evict_last")
+      eviction = load_inst::EVICT_LAST;
+    else if(eviction_policy == "evict_first")
+      eviction = load_inst::EVICT_FIRST;
+    else
+      throw std::runtime_error(std::string("Eviction policy ") + eviction_policy + " not supported");
+  }
+
+  if (!mask && !other)
+    return builder->create_load(ptr, cache, eviction, is_volatile);
+  if (!mask)
+    throw std::runtime_error("`other` cannot be provided without `mask`");
+  auto shape = ptr->get_type()->get_block_shapes();
+  if(!other){
+    other = ir::undef_value::get(elt_ty);
+    if(ptr->get_type()->is_block_ty())
+      other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
+  }
+  return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
+}
+
+ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
+  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
+    throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
+  if(ptr->get_type()->is_block_ty())
+    val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
+  if(mask)
+    mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
+  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
+  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
+  // treat bool* as int8*
+  if(elt_ty == builder->get_int1_ty()){
+    elt_ty = builder->get_int8_ty();
+    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
+    ptr = dispatch::cast(ptr, ptr_ty, builder);
+  }
+  // cast to target data-type
+  val = dispatch::cast(val, elt_ty, builder);
+  if (!mask)
+    return builder->create_store(ptr, val);
+  if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
+    throw semantic_error("Mask must have boolean scalar type");
+  return builder->create_masked_store(ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
+  return builder->create_atomic_cas(ptr, cmp, val);
+}
+
+void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
+  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
+    throw semantic_error("Pointer argument of atomic operation is " + ptr->get_type()->repr());
+  if(ptr->get_type()->is_block_ty()){
+    if(mask){
+      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
+    }
+    if(val){
+      val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
+    }
+  }
+  val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
+  if(!mask){
+    mask = builder->get_int1(true);
+    if(ptr->get_type()->is_block_ty())
+      mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
+  }
+}
+
+ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  // direct call to atomic_max for integers
+  if(sca_ty->is_integer_ty()) {
+    if (sca_ty->is_integer_signed()) {
+      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
+    } else {
+      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
+    }
+  }
+  // for float
+  // return atomic_smax(i_ptr, i_val) if val >= 0
+  // return atomic_umin(i_ptr, i_val) if val < 0
+  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
+  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
+  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
+  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
+  return where(pos, pos_ret, neg_ret, builder);
+}
+
+ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  // direct call to atomic_min for integers
+  if(sca_ty->is_integer_ty()) {
+    if (sca_ty->is_integer_signed()) {
+      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
+    } else {
+      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
+    }
+  }
+  // for float
+  // return atomic_smin(i_ptr, i_val) if val >= 0
+  // return atomic_umax(i_ptr, i_val) if val < 0
+  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
+  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
+  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
+  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
+  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
+  return where(pos, pos_ret, neg_ret, builder);
+}
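The float fallback in atomic_max/atomic_min leans on an IEEE-754 property: for two floats of the same sign, comparing the raw bit patterns as integers matches the float ordering, with the direction reversed for negative values. Hence the pair of masked RMWs over disjoint sign cases, reconciled by `where`. The invariant, as a plain C++ comment sketch (`bit_cast` illustrative):

    // for floats a, b >= 0:  (a < b) == (bit_cast<int32_t>(a)  < bit_cast<int32_t>(b))
    // for floats a, b <  0:  (a < b) == (bit_cast<uint32_t>(a) > bit_cast<uint32_t>(b))
    // so: float atomic max = signed Max on the bits where val >= 0,
    //     unsigned UMin on the bits where val < 0; `where` merges the two results.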
+
+ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  ir::type* sca_ty = val->get_type()->get_scalar_ty();
+  auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
+  return builder->create_atomic_rmw(op, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
+}
+
+ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
+  atom_red_typechecking(ptr, val, mask, builder);
+  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
+}
+
+//===----------------------------------------------------------------------===//
+//                               Linear Algebra
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
+  ir::value *_0 = nullptr;
+  if (lhs->get_type()->is_int_or_tileint_ty())
+    _0 = builder->get_int32(0);
+  else
+    _0 = builder->get_float32(0);
+  unsigned M = lhs->get_type()->get_block_shapes()[0];
+  unsigned N = rhs->get_type()->get_block_shapes()[1];
+  _0 = builder->create_splat(_0, {M, N});
+  bool _allow_tf32 = allow_tf32->get_value() != 0;
+  return builder->create_dot(lhs, rhs, _0, _allow_tf32);
+}
+
+//===----------------------------------------------------------------------===//
+//                                  Indexing
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
+  condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
+  if(condition->get_type()->is_block_ty()){
+    x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
+    y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
+  }
+  ir::type* x_ty = x->get_type()->get_scalar_ty();
+  ir::type* y_ty = y->get_type()->get_scalar_ty();
+  ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO);
+  x = dispatch::cast(x, ty, builder);
+  y = dispatch::cast(y, ty, builder);
+  return builder->create_select(condition, x, y);
+}
+
+//===----------------------------------------------------------------------===//
+//                                 Reductions
+//===----------------------------------------------------------------------===//
+
+ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
+                       ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  // input is extended to 32-bits if necessary
+  // this increases numerical accuracy and can be done pretty much for free
+  // on GPUs
+  if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
+    input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
+  if (scalar_ty->is_floating_point_ty())
+    return builder->create_reduce(input, FLOAT_OP, axis);
+  else if (scalar_ty->is_integer_ty())
+    return builder->create_reduce(input, INT_OP, axis);
+  throw_unreachable(name);
+}
+
+ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
+}
+
+ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
+}
+
+ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
+  return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
+}
+
+ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
+  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
+  if (!scalar_ty->is_integer_ty())
+    throw semantic_error("xor_sum only supported for integers");
+  return reduce_impl(input, axis, builder, "xor_sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
+}
+
+//===----------------------------------------------------------------------===//
+//                                    Math
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
+  binary_op_type_checking(x, y, builder);
+  return builder->insert(umulhi_inst::create(x, y));
+}
+
+ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
+  return builder->create_exp(x);
+}
+
+ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
+  return builder->create_log(x);
+}
+
+ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
+  return builder->create_cos(x);
+}
+
+ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
+  return builder->create_sin(x);
+}
+
+ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
+  return builder->create_sqrt(x);
+}
+
+//===----------------------------------------------------------------------===//
+//                        Internal (debug/optimization)
+//===----------------------------------------------------------------------===//
+
+ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
+  ir::instruction* i = dynamic_cast<ir::instruction*>(x);
+  if(!i)
+    throw_unreachable("multiple_of");
+  i->set_metadata(ir::metadata::multiple_of, value);
+  return i;
+}
+
+ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
+  ir::instruction* i = dynamic_cast<ir::instruction*>(x);
+  if(!i)
+    throw_unreachable("max_contiguous");
+  i->set_metadata(ir::metadata::max_contiguous, value);
+  return i;
+}
+
+ir::value *dispatch::debug_barrier(ir::builder *builder) {
+  return builder->create_barrier();
+}
+
+
+}
+}
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index 39bd945bc..c225b315f 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -312,8 +312,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
   unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
   unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
   cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
-                 (arg_bits > dst_bits  ? cast_op_t::Trunc  :
-                 (is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
+                  (arg_bits > dst_bits ? cast_op_t::Trunc :
+                   (is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
   return create(op, arg, ty, name, next);
 }
diff --git a/lib/ir/module.cc b/lib/ir/module.cc
index a37d3048f..33b39de3a 100644
--- a/lib/ir/module.cc
+++ b/lib/ir/module.cc
@@ -9,6 +9,146 @@ namespace triton{
 namespace ir{

+/* Module */
+module::module(const std::string &name, builder &builder)
+  : name_(name), builder_(builder) {
+  sealed_blocks_.insert(nullptr);
+}
+
+ir::builder& module::get_builder() {
+  return builder_;
+}
+
+void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
+  values_[val_key_t{name, block}] = value;
+  auto it = metadatas_.find(name);
+  if(auto *x = dynamic_cast<ir::instruction*>(value))
+    if(it != metadatas_.end()){
+      x->set_metadata(it->second.first, it->second.second);
+    }
+//  value->set_name(name);
+}
+
+void module::set_value(const std::string& name, ir::value *value){
+  return set_value(name, builder_.get_insert_block(), value);
+}
+
+void module::set_const(const std::string& name){
+  const_.insert(name);
+}
+
+void module::set_continue_fn(std::function<ir::value*()> fn) {
+  continue_fn_ = fn;
+}
+
+std::function<ir::value*()> module::get_continue_fn() {
+  return continue_fn_;
+}
+
+ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
+  basic_block::iterator insert = block->get_first_non_phi();
+  if(insert != block->end()){
+    builder_.set_insert_point(insert);
+  }
+  ir::phi_node *res = builder_.create_phi(ty, num_values);
+  if(insert != block->end())
+    builder_.set_insert_point(block);
+  return res;
+}
+
+ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
+  // find non-self references
+  std::set<ir::value*> non_self_ref;
+  std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
+               [phi](ir::value* op){ return op != phi && op; });
+  // non-trivial
+  if(non_self_ref.size() != 1)
+    return phi;
+  // unique value or self-reference
+  ir::value *same = *non_self_ref.begin();
+  assert(same != nullptr);
+  phi->replace_all_uses_with(same);
+  phi->erase_from_parent();
+  std::set<ir::user*> users = phi->get_users();
+  for(ir::user* u: users)
+    if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
+      if(uphi != phi)
+        try_remove_trivial_phis(uphi);
+  return same;
+}
+
+
+ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
+  // already initialized
+  if(phi->get_num_operands())
+    return phi;
+  ir::basic_block *block = phi->get_parent();
+  for(ir::basic_block *pred: block->get_predecessors()){
+    ir::value *value = get_value(name, pred);
+    phi->add_incoming(value, pred);
+  }
+  return phi;
+}
+
+ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
+  ir::value *result;
+  bool is_const = const_.find(name) != const_.end();
+  auto &preds = block->get_predecessors();
+  ir::type *ty = types_.at(name);
+  if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
+    incomplete_phis_[block][name] = make_phi(ty, 1, block);
+    result = (ir::value*)incomplete_phis_[block][name];
+  }
+  else if(preds.size() <= 1){
+    bool has_pred = preds.size();
+    result = get_value(name, has_pred?preds.front():nullptr);
+  }
+  else{
+    ir::phi_node* phi = make_phi(ty, 1, block);
+    set_value(name, block, phi);
+    result = add_phi_operands(name, phi);
+    if(auto *phi = dynamic_cast<ir::phi_node*>(result))
+      result = try_remove_trivial_phis(phi);
+  }
+  if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
+    result = try_remove_trivial_phis(phi);
+  }
+  set_value(name, block, result);
+  return result;
+}
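make_phi, add_phi_operands, try_remove_trivial_phis and get_value_recursive together implement on-the-fly SSA construction in the style of Braun et al. ("Simple and Efficient Construction of Static Single Assignment Form", CC 2013): reads in unsealed blocks get placeholder phis parked in incomplete_phis_, reads in sealed blocks recurse into predecessors, and phis whose operands are all one value (or self-references) are folded away. A hedged sketch of the intended driving pattern (block and variable names hypothetical):

    mod.set_value("x", x0);            // define x in the entry block
    mod.seal_block(entry);             // entry's predecessor set is final
    ir::value *x = mod.get_value("x"); // in a loop header: creates an incomplete phi
    ...                                // emit the body, add the back-edge
    mod.seal_block(header);            // back-edge known: the phi's operands are filled in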
 /* functions */
 function *module::get_or_insert_function(const std::string &name, function_type *ty) {
   function *&fn = (function*&)symbols_[name];
diff --git a/lib/ir/type.cc b/lib/ir/type.cc
index 056ae99e6..7e4e4e5d7 100644
--- a/lib/ir/type.cc
+++ b/lib/ir/type.cc
@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
 unsigned type::get_integer_bitwidth() const
 { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
 
+signedness type::get_integer_signedness() const
+{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
+
+bool type::is_integer_signed() const {
+  if (id_ != IntegerTyID) {
+    throw std::logic_error("type is " + repr() + ", not integer");
+  }
+  return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
+}
+
 unsigned type::get_tile_bitwidth() const
 { return ((block_type*)(this))->get_bitwidth(); }
@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
 integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
 integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
 integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
+integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
+integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
+integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
+integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }
diff --git a/python/src/triton.cc b/python/src/triton.cc
index b66761ec3..9e53cc341 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -3,6 +3,7 @@
 #include "triton/driver/error.h"
 #include "triton/driver/llvm.h"
 #include "triton/ir/builder.h"
+#include "triton/ir/dispatch.h"
 #include "triton/ir/enums.h"
 #include "triton/ir/function.h"
 #include "triton/ir/module.h"
@@ -11,12 +12,10 @@
 #include
 #include
 #include
-#include
 #include
 #include "Python.h"
 #include
 #include
-#include
 #include
 #include "llvm/IR/Module.h"
 #include "llvm/IR/LegacyPassManager.h"
@@ -542,6 +541,84 @@ void init_triton_codegen(py::module &&m) {
   }, py::return_value_policy::take_ownership);
 }
 
+/*****************************************************************************/
+/* User-facing language features                                             */
+/*****************************************************************************/
+
+void init_triton_frontend(py::module &&m) {
+  using ret = py::return_value_policy;
+
+  // programming model
+  m.def("program_id", &ir::dispatch::program_id, ret::reference);
m.def("num_programs", &ir::dispatch::num_programs, ret::reference); + // binary + m.def("add", &ir::dispatch::add, ret::reference); + m.def("sub", &ir::dispatch::sub, ret::reference); + m.def("mul", &ir::dispatch::mul, ret::reference); + m.def("truediv", &ir::dispatch::truediv, ret::reference); + m.def("floordiv", &ir::dispatch::floordiv, ret::reference); + m.def("fdiv", &ir::dispatch::fdiv, ret::reference); + m.def("mod", &ir::dispatch::mod, ret::reference); + m.def("and_", &ir::dispatch::and_, ret::reference); + m.def("or_", &ir::dispatch::or_, ret::reference); + m.def("xor_", &ir::dispatch::xor_, ret::reference); + m.def("lshr", &ir::dispatch::lshr, ret::reference); + m.def("shl", &ir::dispatch::shl, ret::reference); + // unary + m.def("plus", &ir::dispatch::plus, ret::reference); + m.def("minus", &ir::dispatch::minus, ret::reference); + m.def("invert", &ir::dispatch::invert, ret::reference); + // comparison + m.def("greater_than", &ir::dispatch::greater_than, ret::reference); + m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference); + m.def("less_than", &ir::dispatch::less_than, ret::reference); + m.def("less_equal", &ir::dispatch::less_equal, ret::reference); + m.def("equal", &ir::dispatch::equal, ret::reference); + m.def("not_equal", &ir::dispatch::not_equal, ret::reference); + // block creation + m.def("arange", &ir::dispatch::arange, ret::reference); + m.def("zeros", &ir::dispatch::zeros, ret::reference); + // type manipuatation + m.def("cat", &ir::dispatch::cat, ret::reference); + m.def("reshape", &ir::dispatch::reshape, ret::reference); + typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); + typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); + m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference); + m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference); + m.def("bitcast", &ir::dispatch::bitcast, ret::reference); + m.def("cast", &ir::dispatch::cast, ret::reference); + // memory + m.def("load", &ir::dispatch::load, ret::reference); + m.def("store", &ir::dispatch::store, ret::reference); + m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); + m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); + m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference); + m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference); + m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference); + m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference); + m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference); + m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference); + // linear algebra + m.def("dot", &ir::dispatch::dot, ret::reference); + // indexing + m.def("where", &ir::dispatch::where, ret::reference); + // reduction + m.def("min", &ir::dispatch::min, ret::reference); + m.def("max", &ir::dispatch::max, ret::reference); + m.def("sum", &ir::dispatch::sum, ret::reference); + m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); + // math + m.def("umulhi", &ir::dispatch::umulhi, ret::reference); + m.def("exp", &ir::dispatch::exp, ret::reference); + m.def("log", &ir::dispatch::log, ret::reference); + m.def("cos", &ir::dispatch::cos, ret::reference); + m.def("sin", &ir::dispatch::sin, ret::reference); + m.def("sqrt", &ir::dispatch::sqrt, ret::reference); + // internal (debugging only) + m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); + m.def("max_contiguous", &ir::dispatch::max_contiguous, 
 /*****************************************************************************/
 /* Python bindings for triton::ir                                            */
 /*****************************************************************************/
@@ -551,86 +628,16 @@ void init_triton_ir(py::module &&m) {
   using ret = py::return_value_policy;
   using namespace pybind11::literals;
 
-  py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
-      .value("NONE", ir::load_inst::NONE)
-      .value("CA", ir::load_inst::CA)
-      .value("CG", ir::load_inst::CG)
-      .export_values();
-
-  py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
-      .value("NORMAL", ir::load_inst::NORMAL)
-      .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
-      .value("EVICT_LAST", ir::load_inst::EVICT_LAST)
-      .export_values();
-
-  py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
-      .value("ADD", ir::reduce_inst::ADD)
-      .value("FADD", ir::reduce_inst::FADD)
-      .value("MIN", ir::reduce_inst::MIN)
-      .value("MAX", ir::reduce_inst::MAX)
-      .value("FMIN", ir::reduce_inst::FMIN)
-      .value("FMAX", ir::reduce_inst::FMAX)
-      .value("XOR", ir::reduce_inst::XOR);
-
-  py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
-      .value("ADD", ir::atomic_rmw_op_t::Add)
-      .value("FADD", ir::atomic_rmw_op_t::FAdd)
-      .value("AND", ir::atomic_rmw_op_t::And)
-      .value("OR", ir::atomic_rmw_op_t::Or)
-      .value("XOR", ir::atomic_rmw_op_t::Xor)
-      .value("XCHG", ir::atomic_rmw_op_t::Xchg)
-      .value("MAX", ir::atomic_rmw_op_t::Max)
-      .value("MIN", ir::atomic_rmw_op_t::Min)
-      .value("UMIN", ir::atomic_rmw_op_t::UMin)
-      .value("UMAX", ir::atomic_rmw_op_t::UMax);
-
   py::class_<ir::context>(m, "context")
       .def(py::init<>());
 
-  py::class_<ir::value>(m, "value")
-      .def("multiple_of", [](ir::value *self, int val) {
-        if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
-          instr->set_metadata(ir::metadata::multiple_of, val);
-        } else
-          throw std::runtime_error("multiple_of");
-      })
-      .def("max_contiguous", [](ir::value *self, int val) {
-        if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
-          instr->set_metadata(ir::metadata::max_contiguous, val);
-        } else
-          throw std::runtime_error("max_contiguous");
-      })
-      .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
-        if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
-          instr->set_fdiv_ieee_rounding(val);
-        else
-          throw std::runtime_error("set_fdiv_ieee_rounding");
-      })
-      .def("is_phi", [](ir::value *self) {
-        if (auto *pn = dynamic_cast<ir::phi_node*>(self))
-          return true;
-        return false;
-      })
-      .def("ops", [](ir::value *self) {
-        if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
-          return instr->ops();
-        }
-        throw std::runtime_error("cannot use ops()");
-      })
-      .def("replace_all_uses_with", &ir::value::replace_all_uses_with)
-      .def("erase_from_parent", [](ir::value *self) {
-        if (auto *instr = dynamic_cast<ir::instruction*>(self))
-          return instr->erase_from_parent();
-        throw std::runtime_error("cannot use erase_from_parent");
-      })
-      .def_property("name", &ir::value::get_name, &ir::value::set_name)
-      .def_property_readonly("type", &ir::value::get_type);
+  auto value = py::class_<ir::value>(m, "value");
+  value.def_property("name", &ir::value::get_name, &ir::value::set_name);
+  value.def_property_readonly("type", &ir::value::get_type);
 
   py::class_<ir::user, ir::value>(m, "user");
 
-  py::class_<ir::constant, ir::user>(m, "constant")
-      .def("get_null_value", &ir::constant::get_null_value, ret::reference)
-      .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);
+  py::class_<ir::constant, ir::user>(m, "constant");
 
   py::class_<ir::undef_value, ir::constant>(m, "undef")
       .def("get", &ir::undef_value::get, ret::reference);
@@ -641,17 +648,16 @@ void init_triton_ir(py::module &&m) {
       .def("__bool__", [](ir::constant_int *self) { return self->get_value(); });
 
   py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
-      .def_property_readonly("value", &ir::constant_fp::get_value)
-      .def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);
+      .def_property_readonly("value", &ir::constant_fp::get_value);
 
-  py::class_<ir::instruction, ir::user>(m, "instruction")
-      .def("get_parent", [](ir::instruction *self) {
-        return self->get_parent();
-      }, ret::reference);
-  py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
-      .def("add_incoming", &ir::phi_node::add_incoming);
+  py::class_<ir::instruction, ir::user>(m, "instruction");
+  py::class_<ir::phi_node, ir::instruction>(m, "phi_node");
 
   py::class_<ir::type>(m, "type")
+      .def("is_ptr", &ir::type::is_pointer_ty)
+      .def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
+      .def("is_floating", &ir::type::is_floating_point_ty)
+      .def("is_block", &ir::type::is_block_ty)
       .def("make_ptr", &ir::pointer_type::get, ret::reference)
       .def("make_function", &ir::function_type::get, ret::reference)
      .def("make_block", &ir::block_type::get, ret::reference)
@@ -666,38 +672,34 @@ void init_triton_ir(py::module &&m) {
       .def("get_int16", &ir::type::get_int16_ty, ret::reference)
       .def("get_int32", &ir::type::get_int32_ty, ret::reference)
       .def("get_int64", &ir::type::get_int64_ty, ret::reference)
-      .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference)
+      .def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
+      .def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
+      .def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
+      .def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
-      .def("get_block_shapes", &ir::type::get_block_shapes)
-
-      .def("is_ptr", &ir::type::is_pointer_ty)
-      .def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
-      .def("is_floating", &ir::type::is_floating_point_ty)
-      .def("is_block", &ir::type::is_block_ty)
       .def("is_void", &ir::type::is_void_ty)
-      .def("is_bool", &ir::type::is_bool_ty)
       .def("is_fp8", &ir::type::is_fp8_ty)
       .def("is_fp16", &ir::type::is_fp16_ty)
       .def("is_bf16", &ir::type::is_bf16_ty)
       .def("is_fp32", &ir::type::is_fp32_ty)
       .def("is_fp64", &ir::type::is_fp64_ty)
-      .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
-      .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
-      .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
-      .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
-      .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
-      .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty)
+      .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
+      .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
+      .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
+      .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
+      .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
+      .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
+      .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
+      .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
+      .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
       .def("repr", &ir::type::repr)
       .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
       .def_property_readonly("scalar", &ir::type::get_scalar_ty)
-      .def_property_readonly("context", &ir::type::get_context, ret::reference)
-      .def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
-      .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);
+      .def_property_readonly("context", &ir::type::get_context, ret::reference);
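[editor's note] A quick sanity check of the new signedness-aware predicates from the Python side; this is a hypothetical REPL-style probe, assuming a built libtriton extension, and uses the same class-level calling convention as code_gen.py further below:

import triton._C.libtriton.triton as _triton

ctx = _triton.ir.context()
i32 = _triton.ir.type.get_int32(ctx)   # signed 32-bit integer type
u32 = _triton.ir.type.get_uint32(ctx)  # unsigned 32-bit integer type

assert i32.is_int32() and not i32.is_uint32()
assert u32.is_uint32() and not u32.is_int32()
assert i32.is_int() and u32.is_int()   # both are integer types
print(i32.repr(), u32.repr())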
&ir::type::get_integer_bitwidth) - .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits); + .def_property_readonly("context", &ir::type::get_context, ret::reference); py::class_(m, "pointer_type") - .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) - .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference); + .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); py::class_(m, "function_type"); py::class_(m, "integer_type"); @@ -707,15 +709,16 @@ void init_triton_ir(py::module &&m) { py::class_(m, "module") .def(py::init()) - .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { - const auto metadatas = self->get_metadatas(); - auto it = metadatas.find(name); - if (it != metadatas.end()) - if (auto *instr = dynamic_cast(value)) { - instr->set_metadata(it->second.first, it->second.second); - } - }) - .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); + .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) + .def("seal_block", &ir::module::seal_block) + .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) + .def("set_type", &ir::module::set_type) + .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) + .def("get_values", &ir::module::get_values, ret::reference) + .def("set_values", &ir::module::set_values) + .def("get_types", &ir::module::get_types, ret::reference) + .def("set_types", &ir::module::set_types) + .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; py::enum_(m, "attribute_kind") @@ -739,13 +742,6 @@ void init_triton_ir(py::module &&m) { py::class_(m, "basic_block") .def("create", &ir::basic_block::create, ret::reference) - .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) - .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* { - ir::basic_block::iterator it = self->get_first_non_phi(); - if (it == self->end()) - return nullptr; - return *it; - }, ret::reference) .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); py::class_(m, "builder", py::dynamic_attr()) @@ -756,162 +752,17 @@ void init_triton_ir(py::module &&m) { .def("br", &ir::builder::create_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) - // insertion block/point, insert points are represented as (*bb, *instr) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) - .def("get_insert_point", [](ir::builder *self) { - ir::basic_block *bb = self->get_insert_block(); - ir::basic_block::iterator it = self->get_insert_point(); - ir::instruction *instr = it == bb->end() ? 
nullptr : *it; - return std::make_pair(bb, instr); - }, ret::reference) - .def("set_insert_point", [](ir::builder *self, std::pair pt) { - ir::basic_block *bb = pt.first; - ir::instruction *instr = pt.second; - if (instr) { - if (bb != instr->get_parent()) - throw std::runtime_error("invalid insertion point, instr not in bb"); - self->set_insert_point(instr); - } else { - assert(bb); - self->set_insert_point(bb); - } - }) - // Constants + // constants .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) - .def("get_uint32", &ir::builder::get_int32, ret::reference) - .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) - .def("get_uint64", &ir::builder::get_int64, ret::reference) + .def("get_int32", &ir::builder::get_int32, ret::reference) + .def("get_int64", &ir::builder::get_int64, ret::reference) + .def("get_uint32", &ir::builder::get_uint32, ret::reference) + .def("get_uint64", &ir::builder::get_uint64, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference) - .def("get_range", &ir::builder::get_range, ret::reference) - // Types - .def("get_void_ty", &ir::builder::get_void_ty, ret::reference) - .def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference) - .def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference) - .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference) - .def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference) - .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference) - .def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference) - .def("get_half_ty", &ir::builder::get_half_ty, ret::reference) - .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference) - .def("get_float_ty", &ir::builder::get_float_ty, ret::reference) - .def("get_double_ty", &ir::builder::get_double_ty, ret::reference) - // terminator instructions - .def("create_br", &ir::builder::create_br, ret::reference) - .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) - .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) - // Cast instructions - .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) - .def("create_cast", &ir::builder::create_cast, ret::reference) - .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) - .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) - .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) - .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) - .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) - .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference) - .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) - .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) - .def("create_downcast", &ir::builder::create_downcast, ret::reference) - // phi - .def("create_phi", &ir::builder::create_phi, ret::reference) - // Binary instructions - .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) - .def("create_fmul", &ir::builder::create_fmul, ret::reference) - .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) - .def("create_frem", &ir::builder::create_frem, ret::reference) - .def("create_fadd", &ir::builder::create_fadd, ret::reference) - 
.def("create_fsub", &ir::builder::create_fsub, ret::reference) - .def("create_mul", &ir::builder::create_mul, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) - .def("create_udiv", &ir::builder::create_udiv, ret::reference) - .def("create_srem", &ir::builder::create_srem, ret::reference) - .def("create_urem", &ir::builder::create_urem, ret::reference) - .def("create_add", &ir::builder::create_add, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_sub", &ir::builder::create_sub, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_shl", &ir::builder::create_shl, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_lshr", &ir::builder::create_lshr, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - .def("create_ashr", &ir::builder::create_ashr, ret::reference, - py::arg("lhs"), py::arg("rhs"), - py::arg("has_nuw")=false, py::arg("has_nsw")=false) - // GEP - .def("create_gep", &ir::builder::create_gep, ret::reference) - // Comparison (int) - .def("create_icmp", &ir::builder::create_icmp, ret::reference) - .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference) - .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference) - .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference) - .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference) - .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference) - .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference) - .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference) - .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference) - .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference) - .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference) - // Comparison (float) - .def("create_fcmp", &ir::builder::create_fcmp, ret::reference) - .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference) - .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) - .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) - .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) - .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) - .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) - .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) - .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) - .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) - .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) - .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) - .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) - // Logical - .def("create_and", &ir::builder::create_and, ret::reference) - .def("create_xor", &ir::builder::create_xor, ret::reference) - .def("create_or", &ir::builder::create_or, ret::reference) - // Input/Output - .def("create_load", &ir::builder::create_load, ret::reference) - .def("create_store", &ir::builder::create_store, ret::reference) - .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) - .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) - // 
Block instruction - .def("create_splat", &ir::builder::create_splat, ret::reference) - .def("create_reshape", &ir::builder::create_reshape, ret::reference) - .def("create_cat", &ir::builder::create_cat, ret::reference) - .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) - // atomic - .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) - .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) - - // Built-in instruction - .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) - .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) - .def("create_exp", &ir::builder::create_exp, ret::reference) - .def("create_cos", &ir::builder::create_cos, ret::reference) - .def("create_sin", &ir::builder::create_sin, ret::reference) - .def("create_log", &ir::builder::create_log, ret::reference) - .def("create_dot", &ir::builder::create_dot, ret::reference) - .def("create_trans", &ir::builder::create_trans, ret::reference) - .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) - .def("create_reduce", &ir::builder::create_reduce, ret::reference) - .def("create_select", &ir::builder::create_select, ret::reference) - // Intrinsics - // These have no place in the IR, and hopefully they can be removed at some point - .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) - .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference) - .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference) - .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference) - .def("create_barrier", &ir::builder::create_barrier, ret::reference) - .def("create_async_wait", &ir::builder::create_async_wait, ret::reference) - .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference); + .def("get_range", &ir::builder::get_range, ret::reference); } void init_triton(py::module &m) { @@ -919,4 +770,5 @@ void init_triton(py::module &m) { init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_ir(std::move(subm.def_submodule("ir"))); + init_triton_frontend(std::move(subm.def_submodule("frontend"))); } diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index f30b203bb..1df3a0b49 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -37,7 +37,7 @@ matmul_data = { (256, 256, 256): {'float16': 0.027}, (512, 512, 512): {'float16': 0.158}, (1024, 1024, 1024): {'float16': 0.466}, - (2048, 2048, 2048): {'float16': 0.695}, + (2048, 2048, 2048): {'float16': 0.680}, (4096, 4096, 4096): {'float16': 0.831}, (8192, 8192, 8192): {'float16': 0.849}, # tall-skinny diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3561f7af4..a49b47585 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,4 +1,5 @@ # flake8: noqa: F821,F841 +import copy import itertools import re from typing import Optional, Union @@ -584,6 +585,7 @@ def test_f8_f16_roundtrip(): f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) + print(f16.dtype, f8_output.dtype) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -991,6 +993,27 @@ def 
test_noop(device='cuda'):
     kernel[(1, )](x)
 
 
+@pytest.mark.parametrize("value, value_type", [
+    (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
+    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
+    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
+])
+def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
+
+    @triton.jit
+    def kernel(VALUE, X):
+        pass
+
+    x = torch.tensor([3.14159], device='cuda')
+    pgm = kernel[(1, )](value, x)
+
+    # Parse out the type of the 'VALUE' parameter from the Triton IR.
+    triton_ir = pgm.asm['ttir']
+    ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
+    ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
+    assert ir_value_type == value_type
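[editor's note] The parametrization above encodes the specialization rule for Python integer arguments: the literal 1 is specialized away entirely, and every other value gets the narrowest fitting type, signed first. A pure-Python mirror of that rule, with a hypothetical helper name, purely for illustration:

def expected_value_type(v: int):
    """Mirror of the expectations encoded in test_value_specialization."""
    if v == 1:
        return None                  # 1 is specialized to a compile-time constant
    if -2**31 <= v < 2**31:
        return 'i32'
    if 2**31 <= v < 2**32:
        return 'u32'
    if -2**63 <= v < 2**63:
        return 'i64'
    if 2**63 <= v < 2**64:
        return 'u64'
    raise OverflowError(f'{v} does not fit in 64 bits')

assert expected_value_type(-1) == 'i32'
assert expected_value_type(2**31) == 'u32'
assert expected_value_type(2**63) == 'u64'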
+
+
 @pytest.mark.parametrize(
     "value, overflow",
     [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index d866d6983..8ac01bcc8 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -1,5 +1,4 @@
 import os
-import re
 import shutil
 
 import pytest
@@ -103,30 +102,3 @@ def test_specialize(mode):
     for i in [1, 2, 4, 8, 16, 32]:
         function[(1,)](x, i, BLOCK=512)
     assert counter == target
-
-
-@pytest.mark.parametrize("value, value_type", [
-    (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
-    (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
-    (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
-])
-def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
-
-    @triton.jit
-    def kernel(VALUE, X):
-        pass
-
-    cache_str = None
-
-    def get_cache_str(*args, **kwargs):
-        nonlocal cache_str
-        cache_str = kwargs['key'].split('-')
-    triton.code_gen.JITFunction.cache_hook = get_cache_str
-    reset_tmp_dir()
-    x = torch.tensor([3.14159], device='cuda')
-    kernel[(1, )](value, x)
-    triton.code_gen.JITFunction.cache_hook = None
-
-    cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
-    spec_type = None if cache_str_match is None else cache_str_match.group(1)
-    assert spec_type == value_type
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index 37ba46efc..f9982939c 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -6,8 +6,7 @@ __version__ = '2.0.0'
 # or pybind11 shows `munmap_chunk(): invalid pointer`
 import torch
 # submodules
-from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
-    JITFunction, Config, Autotuner, reinterpret
+from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
 from . import language
 from . import code_gen
 from . import testing
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 23d460f29..cb705aaa6 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import ast
 import builtins
 import functools
@@ -13,7 +11,7 @@ import tempfile
 import textwrap
 import time
 import warnings
-from typing import Dict, Optional, Set, Tuple, Union
+from typing import Dict
 
 import torch
 from filelock import FileLock
@@ -24,13 +22,48 @@ from .tools.disasm import extract
 
 
 class CodeGenerator(ast.NodeVisitor):
+    def get_value(self, name):
+        # search node.id in local scope
+        ret = None
+        if name in self.lscope:
+            ret = self.lscope[name]
+        # search node.id in global scope
+        elif name in self.gscope:
+            ret = self.gscope[name]
+        # search node.id in builtins
+        elif name in self.builtins:
+            ret = self.builtins[name]
+        else:
+            raise ValueError(f'{name} is not defined')
+        if isinstance(ret, triton.language.block):
+            handle = self.module.get_value(name)
+            return triton.language.block(handle)
+        return ret
+
+    def set_value(self, name, value):
+        if isinstance(value, _triton.ir.value):
+            value = triton.language.block(value)
+        if isinstance(value, triton.language.block):
+            self.module.set_value(name, value.handle)
+            self.module.set_type(name, value.handle.type)
+        self.lscope[name] = value
+
+    def is_triton_object(self, value):
+        return isinstance(value, triton.language.block)
+
+    def visit_compound_statement(self, stmts):
+        for stmt in stmts:
+            self.last_ret = self.visit(stmt)
+            if isinstance(stmt, ast.Return):
+                break
+        return stmts and isinstance(stmt, ast.Return)
+
     def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
         self.builder = _triton.ir.builder(context)
         self.module = _triton.ir.module('', self.builder)
         self.prototype = prototype
         self.gscope = gscope
         self.lscope = dict()
-        self.is_arg_lscope = dict()  # name => is_arg: {str: bool}
         self.attributes = attributes
         self.constants = constants
         self.kwargs = kwargs
@@ -44,146 +77,6 @@ class CodeGenerator(ast.NodeVisitor):
             'isinstance': isinstance,
             'getattr': getattr,
         }
-        # SSA-construction
-        # [name, bb] => triton.language.tensor
-        self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
-        # bb => {name => phi}
-        self.incomplete_phis = {}
-        self.sealed_blocks: Set[_triton.ir.basic_block] = set()
-
-    def get_value(self, name):
-        ''' This function:
-        1. make sure `name` is defined
-        2. if `name` is triton.language.tensor, get stored tensor by calling
-           `self._get_tensor()`
-        '''
-        # search node.id in local scope
-        ret = None
-        if name in self.lscope:
-            ret = self.lscope[name]
-        # search node.id in global scope
-        elif name in self.gscope:
-            ret = self.gscope[name]
-        # search node.id in builtins
-        elif name in self.builtins:
-            ret = self.builtins[name]
-        else:
-            raise ValueError(f'{name} is not defined')
-        if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]:
-            return self._get_tensor(name)
-        return ret
-
-    def set_value(self, name: str,
-                  value: Union[triton.language.tensor, triton.language.constexpr],
-                  is_arg: bool = False) -> None:
-        ''' This function:
-        called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
-        1. record local defined name (FIXME: should consider control flow)
-        2. 
store tensor in self.lvalue - ''' - self.lscope[name] = value - # if this value is an argument, we don't need to create phis for it - self.is_arg_lscope[name] = is_arg - if isinstance(value, triton.language.tensor) and not is_arg: - self._set_value(name, self.builder.get_insert_block(), value) - - # - # SSA-construction - # - def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor: - if not bb: - bb = self.builder.get_insert_block() - # local value numbering - if (name, bb) in self.lvalues: - return self.lvalues[(name, bb)] - # global value numbering - saved_insert_point = self.builder.get_insert_point() - result = self._get_tensor_recursive(name, bb) - self.builder.set_insert_point(saved_insert_point) - return result - - def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: - preds = bb.get_predecessors() - type = self.lscope[name].type - # some preds haven't been filled, create a phi as a proxy of the value - if bb not in self.sealed_blocks: - result = self._make_phi(type, len(preds), bb) - if bb in self.incomplete_phis: - self.incomplete_phis[bb][name] = result - else: - self.incomplete_phis[bb] = {name: result} - elif len(preds) == 1: - # one predecessor: no phi needed, try get value from pred - result = self._get_tensor(name, preds[0]) - else: # multiple preds - assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)' - phi = self._make_phi(type, len(preds), bb) - self._set_value(name, bb, phi) - result = self._add_phi_operands(name, phi) - self._set_value(name, bb, result) - return result - - # returns a new phi tensor, which encausulate an ir.phi_node - def _make_phi(self, - type: triton.language.dtype, - num_values: int, - bb: _triton.ir.basic_block) -> triton.language.tensor: - instr = bb.get_first_non_phi() - self.builder.set_insert_point((bb, instr)) - ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values) - if instr: - self.builder.set_insert_block(bb) - return triton.language.tensor(ir_phi, type) - - # complete a phi node. (TODO: rename this as _complete_phis?) - # Note: since we try to remove tryival phi, the return tensor might not be a phi - def _add_phi_operands(self, name: str, - phi: triton.language.tensor) -> triton.language.tensor: - bb = phi.handle.get_parent() - for pred in bb.get_predecessors(): - v = self._get_tensor(name, pred) - phi.handle.add_incoming(v.handle, pred) - phi = self._try_remove_trivial_phi(phi) - return phi - - def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None: - self.lvalues[(name, bb)] = value - # TODO: why we need this? 
- self.module.set_instr_metadata(name, value.handle) - - def _seal_block(self, bb: _triton.ir.basic_block): - # complete all incomplete phis - if bb in self.incomplete_phis: - for name, phi in self.incomplete_phis[bb].items(): - result = self._add_phi_operands(name, phi) - # it's possible that this phi is trivial - if self._get_tensor(name, bb).handle == phi.handle: - self._set_value(name, bb, result) - del self.incomplete_phis[bb] - self.sealed_blocks.add(bb) - - def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor: - unique_handles = {op for op in phi.handle.ops() if op != phi.handle} - if len(unique_handles) != 1: # non-trivial phi - return phi - v = unique_handles.pop() - phi.handle.replace_all_uses_with(v) - phi.handle.erase_from_parent() - # TODO: remove trivial phis recursively - return triton.language.tensor(v, phi.type) - - def is_triton_tensor(self, value): - return isinstance(value, triton.language.tensor) - - # - # AST visitor - # - def visit_compound_statement(self, stmts): - for stmt in stmts: - self.last_ret = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) def visit_Module(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -220,7 +113,7 @@ class CodeGenerator(ast.NodeVisitor): if inline: pass else: - fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder)) + fn = self.module.get_or_insert_function(node.name, self.prototype) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -237,17 +130,17 @@ class CodeGenerator(ast.NodeVisitor): attr = _triton.ir.attribute(attr, self.attributes[i]) fn.add_attr(idx + 1, attr) fn.args[idx].name = arg_name - arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) + arg_values.append(fn.args[idx]) idx += 1 for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value, is_arg=True) + self.set_value(arg_name, arg_value) if inline: self.visit_compound_statement(node.body) return self.last_ret else: entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) - self._seal_block(entry) + self.module.seal_block(entry) self.builder.set_insert_block(entry) # visit function body self.visit_compound_statement(node.body) @@ -294,12 +187,11 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): - # TODO: can we store constexpr here to support constant folding? 
# by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value - if not isinstance(value, triton.language.tensor): - value = triton.language.core._to_tensor(value, self.builder) + if not isinstance(value, triton.language.block): + value = triton.language.core._to_ir(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -328,9 +220,9 @@ class CodeGenerator(ast.NodeVisitor): def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) - if isinstance(lhs, triton.language.constexpr): + if isinstance(lhs, triton.language.core.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): + if isinstance(rhs, triton.language.core.constexpr): rhs = rhs.value fn = { ast.Add: '__add__', @@ -346,9 +238,9 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - if self.is_triton_tensor(lhs): + if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_tensor(rhs): + elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -356,15 +248,15 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if isinstance(cond, triton.language.tensor): + if isinstance(cond, triton.language.block): cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self._seal_block(then_bb) + self.module.seal_block(then_bb) if else_bb: - self._seal_block(else_bb) + self.module.seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -379,7 +271,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
if not is_terminator: self.builder.br(endif_bb) - self._seal_block(endif_bb) + self.module.seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -404,9 +296,9 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.constexpr): + if isinstance(lhs, triton.language.core.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): + if isinstance(rhs, triton.language.core.constexpr): rhs = rhs.value if type(node.ops[0]) == ast.Is: return triton.language.constexpr(lhs is rhs) @@ -420,9 +312,9 @@ class CodeGenerator(ast.NodeVisitor): ast.Gt: '__gt__', ast.GtE: '__ge__', }[type(node.ops[0])] - if self.is_triton_tensor(lhs): + if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_tensor(rhs): + elif self.is_triton_object(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -433,21 +325,21 @@ class CodeGenerator(ast.NodeVisitor): if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.constexpr): + if isinstance(op, triton.language.core.constexpr): op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Invert: '__invert__', }[type(node.op)] - if self.is_triton_tensor(op): + if self.is_triton_object(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() def visit_While(self, node): current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) def continue_fn(): cond = self.visit(node.test) @@ -458,9 +350,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self._seal_block(stop_bb) - self._seal_block(loop_bb) - self._seal_block(next_bb) + self.module.seal_block(stop_bb) + self.module.seal_block(loop_bb) + self.module.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -470,7 +362,7 @@ class CodeGenerator(ast.NodeVisitor): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) - if self.is_triton_tensor(lhs): + if self.is_triton_object(lhs): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] @@ -513,8 +405,8 @@ class CodeGenerator(ast.NodeVisitor): step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) def continue_fn(): self.visit(step_node) @@ -529,9 +421,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where 
body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self._seal_block(stop_bb) - self._seal_block(loop_bb) - self._seal_block(next_bb) + self.module.seal_block(stop_bb) + self.module.seal_block(loop_bb) + self.module.seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -559,7 +451,7 @@ class CodeGenerator(ast.NodeVisitor): args = [self.visit(arg) for arg in node.args] if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) - if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ + if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): @@ -699,7 +591,7 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] - if isinstance(obj, triton.language.constexpr): + if isinstance(obj, triton.language.core.constexpr): obj = obj.value if isinstance(obj, int): if -2**31 <= obj < 2**31: @@ -731,34 +623,34 @@ class Kernel: return 'scalar', name @staticmethod - def _to_triton_ir(obj): + def _to_triton_ir(context, obj): which, name = obj type_map = { - 'I': triton.language.int32, - 'L': triton.language.int64, - 'f': triton.language.float32, - 'B': triton.language.int1, - 'f8': triton.language.float8, - 'f16': triton.language.float16, - 'bf16': triton.language.bfloat16, - 'f32': triton.language.float32, - 'f64': triton.language.float64, - 'i1': triton.language.int1, - 'i8': triton.language.int8, - 'i16': triton.language.int16, - 'i32': triton.language.int32, - 'i64': triton.language.int64, - 'u8': triton.language.uint8, - 'u16': triton.language.uint16, - 'u32': triton.language.uint32, - 'u64': triton.language.uint64, + 'I': _triton.ir.type.get_int32, + 'L': _triton.ir.type.get_int64, + 'f': _triton.ir.type.get_fp32, + 'B': _triton.ir.type.get_int1, + 'f8': _triton.ir.type.get_fp8, + 'f16': _triton.ir.type.get_fp16, + 'bf16': _triton.ir.type.get_bf16, + 'f32': _triton.ir.type.get_fp32, + 'f64': _triton.ir.type.get_fp64, + 'i1': _triton.ir.type.get_int1, + 'i8': _triton.ir.type.get_int8, + 'i16': _triton.ir.type.get_int16, + 'i32': _triton.ir.type.get_int32, + 'i64': _triton.ir.type.get_int64, + 'u8': _triton.ir.type.get_uint8, + 'u16': _triton.ir.type.get_uint16, + 'u32': _triton.ir.type.get_uint32, + 'u64': _triton.ir.type.get_uint64, } # convert torch.Tensor to Triton IR pointers if which == 'ptr': - elt_ty = type_map[name] - return triton.language.pointer_type(elt_ty, 1) + elt_ty = type_map[name](context) + return _triton.ir.type.make_ptr(elt_ty, 1) # default path returns triton.ir.type directly - return type_map[name] + return type_map[name](context) @staticmethod def pow2_divisor(N): @@ -1038,31 +930,25 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree - # Called by CodeGenerator.visit_Call() def __call__(self, *args, generator: CodeGenerator, **kwargs): try: from inspect import getcallargs arg_values = getcallargs(self.fn, *args, **kwargs) arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.tensor) + arg_values = [arg if isinstance(arg, triton.language.block) else triton.language.constexpr(arg) for arg in arg_values] - # Record values in the caller (parent scope) gscope = generator.gscope.copy() lscope = generator.lscope.copy() - - # TODO: clear values other than args - lvalues = generator.lvalues.copy() - # types = 
generator.module.get_types().copy() + values = generator.module.get_values().copy() + types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope - - generator.lvalues = lvalues - # generator.module.set_types(types) - + generator.module.set_values(values) + generator.module.set_types(types) return ret except Exception as e: node = generator.last_node @@ -1147,9 +1033,9 @@ class JITFunction: # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type(ret_type, arg_types) + arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] + ret_type = _triton.ir.type.get_void(context) + prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 81b9fe790..df25e59fb 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,36 +1,63 @@ -from __future__ import annotations - -from enum import Enum from functools import wraps -from typing import List import triton -from . import semantic -from triton._C.libtriton.triton import ir +from triton._C.libtriton.triton import frontend, ir -def _to_tensor(x, builder): +# convert block/dtype to ir values +def _to_ir(x, builder): if isinstance(x, bool): - return tensor(builder.get_int1(x), int1) - # Note: compile-time const integers are represented by unsigned values + return builder.get_int1(x) elif isinstance(x, int): if -2**31 <= x < 2**31: - return tensor(builder.get_int32(x), int32) + return builder.get_int32(x) elif 2**31 <= x < 2**32: - return tensor(builder.get_uint32(x), uint32) + return builder.get_uint32(x) elif -2**63 <= x < 2**63: - return tensor(builder.get_int64(x), int64) + return builder.get_int64(x) elif 2**63 <= x < 2**64: - return tensor(builder.get_uint64(x), uint64) + return builder.get_uint64(x) else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - return tensor(builder.get_float32(x), float32) + return builder.get_float32(x) elif isinstance(x, constexpr): - return _to_tensor(x.value, builder) - elif isinstance(x, tensor): + return _to_ir(x.value, builder) + elif isinstance(x, block): + return x.handle + elif isinstance(x, dtype): + return x.handle(builder) + return x + + +def _patch(fn): + def _from_ir(x): + if isinstance(x, ir.value): + if x.type.is_void(): + return None + return block(x) return x - assert False, f'cannot convert {x} to tensor' + + def wrapper(*args, **kwargs): + builder = args[-1] + assert isinstance(builder, ir.builder) + args = [_to_ir(x, builder) for x in args] + # for i, arg in enumerate(args): + # if arg is None: + # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") + kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} + ret = fn(*args, **kwargs) + if isinstance(ret, tuple): + return map(_from_ir, ret) + return _from_ir(ret) + + return wrapper + + +for name in dir(frontend): + fn = getattr(frontend, name) + if callable(fn): + setattr(frontend, name, _patch(fn)) def builtin(fn): @@ -45,147 +72,20 @@ def builtin(fn): class dtype: - 
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] - UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] - FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64'] - OTHER_TYPES = ['void'] - - class SIGNEDNESS(Enum): - SIGNED = 0 - UNSIGNED = 1 - - def __init__(self, name): - self.name = name - assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name - if name in dtype.SINT_TYPES: - self.int_signedness = dtype.SIGNEDNESS.SIGNED - self.int_bitwidth = int(name.split('int')[-1]) - self.primitive_bitwidth = self.int_bitwidth - elif name in dtype.UINT_TYPES: - self.int_signedness = dtype.SIGNEDNESS.UNSIGNED - self.int_bitwidth = int(name.split('int')[-1]) - self.primitive_bitwidth = self.int_bitwidth - elif name in dtype.FP_TYPES: - if name == 'fp8': - self.fp_mantissa_width = 3 - self.primitive_bitwidth = 8 - elif name == 'fp16': - self.fp_mantissa_width = 10 - self.primitive_bitwidth = 16 - elif name == 'bf16': - self.fp_mantissa_width = 7 - self.primitive_bitwidth = 16 - elif name == 'fp32': - self.fp_mantissa_width = 23 - self.primitive_bitwidth = 32 - elif name == 'fp64': - self.fp_mantissa_width = 53 - self.primitive_bitwidth = 64 - elif name == 'void': - self.primitive_bitwidth = 0 - - def is_fp8(self): - return self.name == 'fp8' - - def is_fp16(self): - return self.name == 'fp16' - - def is_bf16(self): - return self.name == 'bf16' - - def is_fp32(self): - return self.name == 'fp32' - - def is_fp64(self): - return self.name == 'fp64' - - def is_int1(self): - return self.name == 'int1' - - def is_int8(self): - return self.name == 'int8' - - def is_int16(self): - return self.name == 'int16' - - def is_int32(self): - return self.name == 'int32' - - def is_int64(self): - return self.name == 'int64' - - def is_uint8(self): - return self.name == 'uint8' - - def is_uint16(self): - return self.name == 'uint16' - - def is_uint32(self): - return self.name == 'uint32' - - def is_uint64(self): - return self.name == 'uint64' - - def is_floating(self): - return self.name in dtype.FP_TYPES - - def is_int_signed(self): - return self.name in dtype.SINT_TYPES - - def is_int(self): - return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES - - def is_bool(self): - return self.is_int1() - - def is_void(self): - raise RuntimeError("Not implemented") - - def is_block(self): - return False - - def is_ptr(self): - return False - - def __eq__(self, other: dtype): - if not isinstance(other, dtype): - return False - return self.name == other.name - - def __ne__(self, other: dtype): - return not self.__eq__(other) - - def __hash__(self): - return hash((self.name,)) + def __init__(self, init): + self.init = init @property - def scalar(self): - return self + def name(self) -> str: + # The init functions are named something like 'get_int8'. Strip the prefix. 
+ nom = self.init.__name__ + prefix = 'get_' + assert nom.startswith(prefix) + return nom[len(prefix):] - def to_ir(self, builder: ir.builder) -> ir.type: - if self.name == 'void': - return builder.get_void_ty() - elif self.name == 'int1': - return builder.get_int1_ty() - elif self.name == 'int8' or self.name == 'uint8': - return builder.get_int8_ty() - elif self.name == 'int16' or self.name == 'uint16': - return builder.get_int16_ty() - elif self.name == 'int32' or self.name == 'uint32': - return builder.get_int32_ty() - elif self.name == 'int64' or self.name == 'uint64': - return builder.get_int64_ty() - elif self.name == 'fp8': - return builder.get_fp8_ty() - elif self.name == 'fp16': - return builder.get_half_ty() - elif self.name == 'bf16': - return builder.get_bf16_ty() - elif self.name == 'fp32': - return builder.get_float_ty() - elif self.name == 'fp64': - return builder.get_double_ty() - raise ValueError(f'fail to covert {self} to ir type') + def handle(self, builder): + ctx = builder.context + return self.init(ctx) def __str__(self): return self.name @@ -199,112 +99,36 @@ class dtype: return f'triton.language.{self.name}' -class pointer_type(dtype): - def __init__(self, element_ty: dtype, address_space: int = 1): +class pointer_dtype: + def __init__(self, element_ty): if not isinstance(element_ty, dtype): raise TypeError('element_ty is a {type(element_ty).__name__}.') self.element_ty = element_ty - self.address_space = address_space - self.name = self.__str__() - - def to_ir(self, builder: ir.builder) -> ir.pointer_type: - return ir.type.make_ptr(self.element_ty.to_ir(builder), 1) + def handle(self, builder): + return ir.type.make_ptr(self.element_ty.handle(builder), 1) def __str__(self): return f'pointer<{self.element_ty}>' - def __repr__(self): - return self.__str__() - - def is_ptr(self): - return True - - def __eq__(self, other: pointer_type) -> bool: - if not isinstance(other, pointer_type): - return False - return self.element_ty == other.element_ty and self.address_space == other.address_space - - def __ne__(self, other: pointer_type) -> bool: - return not self.__eq__(other) - - @property - def scalar(self): - return self - - -class block_type(dtype): - def __init__(self, element_ty: dtype, shape: List[int]): - self.element_ty = element_ty - # FIXME: - # block_type's shape is a list of int - # while tensor's shape is a list of constexpr - self.shape = shape - self.numel = 1 - for s in self.shape: - self.numel *= s - - self.name = self.__str__() - - def to_ir(self, builder: ir.builder) -> ir.block_type: - return ir.type.make_block(self.element_ty.to_ir(builder), self.shape) - - def __str__(self): - return f'<{self.shape}, {self.element_ty}>' - - def __repr__(self): - return self.__str__() - - def is_block(self): - return True - - def get_block_shapes(self) -> List[int]: - return self.shape - - def __eq__(self, other: block_type) -> bool: - if not isinstance(other, block_type): - return False - return self.element_ty == other.element_ty and self.shape == other.shape - - def __ne__(self, other: block_type) -> bool: - return not self.__eq__(other) - - @property - def scalar(self): - return self.element_ty - - -class function_type(dtype): - def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None: - self.ret_type = ret_type - self.param_types = param_types - - def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_type}' - - def to_ir(self, builder: ir.builder): - ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - return 
ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types) - # scalar types -void = dtype('void') -int1 = dtype('int1') -int8 = dtype('int8') -int16 = dtype('int16') -int32 = dtype('int32') -int64 = dtype('int64') -uint8 = dtype('uint8') -uint16 = dtype('uint16') -uint32 = dtype('uint32') -uint64 = dtype('uint64') -float8 = dtype('fp8') -float16 = dtype('fp16') -bfloat16 = dtype('bf16') -float32 = dtype('fp32') -float64 = dtype('fp64') +int1 = dtype(ir.type.get_int1) +int8 = dtype(ir.type.get_int8) +int16 = dtype(ir.type.get_int16) +int32 = dtype(ir.type.get_int32) +int64 = dtype(ir.type.get_int64) +uint8 = dtype(ir.type.get_uint8) +uint16 = dtype(ir.type.get_uint16) +uint32 = dtype(ir.type.get_uint32) +uint64 = dtype(ir.type.get_uint64) +float8 = dtype(ir.type.get_fp8) +float16 = dtype(ir.type.get_fp16) +bfloat16 = dtype(ir.type.get_bf16) +float32 = dtype(ir.type.get_fp32) +float64 = dtype(ir.type.get_fp64) # pointer types -pi32_t = pointer_type(int32) +pi32_t = pointer_dtype(int32) # ----------------------- # constexpr @@ -325,6 +149,7 @@ class constexpr: def __repr__(self) -> str: return f"constexpr[{self.value}]" + # def __add__(self, other): return self.value + other.value @@ -394,33 +219,31 @@ class constexpr: return self.value(*args, **kwds) -class tensor: - # infer dtype from ir type +class block: @staticmethod - def _to_dtype(ir_type): - # block type - if ir_type.is_block(): - scalar_ty = tensor._to_dtype(ir_type.scalar) - return block_type(scalar_ty, ir_type.get_block_shapes()) - # pointer type - if ir_type.is_ptr(): - element_ty = tensor._to_dtype(ir_type.element) - return pointer_type(element_ty) + def _init_dtype(ir_type): # primitive type - if ir_type.is_void(): return void if ir_type.is_int1(): return int1 if ir_type.is_int8(): return int8 if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 + if ir_type.is_uint8(): return uint8 + if ir_type.is_uint16(): return uint16 + if ir_type.is_uint32(): return uint32 + if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16 if ir_type.is_fp32(): return float32 if ir_type.is_fp64(): return float64 - raise ValueError(f"Unsupported type {ir_type.repr()}") + # pointer type + if ir_type.is_ptr(): + element_ty = block._init_dtype(ir_type.element) + return pointer_dtype(element_ty) + raise ValueError(f"Unsupported type {ir_type}") - def __init__(self, handle, type: dtype): + def __init__(self, handle): # IR handle self.handle = handle # Block shape @@ -431,9 +254,9 @@ class tensor: for s in self.shape: self.numel *= s self.numel = constexpr(self.numel) - self.type = type # Tensor type (can be block_type) - # Following the practice in pytorch, dtype is scalar type - self.dtype = type.scalar + # Data-type wrapper + self.dtype = block._init_dtype(self.handle.type.scalar) + # Shape is a constexpr self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: @@ -442,139 +265,116 @@ class tensor: @builtin def __add__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.add(self, other, _builder) + return frontend.add(self, other, _builder) def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.sub(self, other, _builder) + return frontend.sub(self, other, _builder) def __rsub__(self, other, _builder=None): - other = 
_to_tensor(other, _builder) - return semantic.sub(other, self, _builder) + return frontend.sub(other, self, _builder) @builtin def __mul__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.mul(self, other, _builder) + return frontend.mul(self, other, _builder) def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @builtin def __truediv__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.truediv(self, other, _builder) + return frontend.truediv(self, other, _builder) def __rtruediv__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.truediv(other, self, _builder) + return frontend.truediv(other, self, _builder) @builtin def __floordiv__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.floordiv(self, other, _builder) + return frontend.floordiv(self, other, _builder) @builtin def __mod__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.mod(self, other, _builder) + return frontend.mod(self, other, _builder) # unary operators @builtin def __neg__(self, _builder=None): - return semantic.minus(self, _builder) + return frontend.minus(self, _builder) @builtin def __invert__(self, _builder=None): - return semantic.invert(self, _builder) + return frontend.invert(self, _builder) # bitwise operators @builtin def __and__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.and_(self, other, _builder) + return frontend.and_(self, other, _builder) @builtin def __or__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.or_(self, other, _builder) + return frontend.or_(self, other, _builder) @builtin def __xor__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.xor_(self, other, _builder) + return frontend.xor_(self, other, _builder) @builtin def __lshift__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.shl(self, other, _builder) + return frontend.shl(self, other, _builder) @builtin def __rshift__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.lshr(self, other, _builder) + return frontend.lshr(self, other, _builder) # comparison operators # > @builtin def __gt__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_than(self, other, _builder) + return frontend.greater_than(self, other, _builder) @builtin def __rgt__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_than(other, self, _builder) + return frontend.greater_than(other, self, _builder) # >= @builtin def __ge__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_equal(self, other, _builder) + return frontend.greater_equal(self, other, _builder) - @builtin def __rge__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.greater_equal(other, self, _builder) + return frontend.greater_equal(other, self, _builder) # < @builtin def __lt__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.less_than(self, other, _builder) + return frontend.less_than(self, other, _builder) @builtin def __rlt__(self, other, _builder=None): - return semantic.less_than(other, self, _builder) + return frontend.less_than(other, self, _builder) # <= @builtin def __le__(self, other, _builder=None): - other = 
_to_tensor(other, _builder) - return semantic.less_equal(self, other, _builder) + return frontend.less_equal(self, other, _builder) @builtin def __rle__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.less_equal(other, self, _builder) + return frontend.less_equal(other, self, _builder) # == @builtin def __eq__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.equal(self, other, _builder) + return frontend.equal(self, other, _builder) @builtin def __ne__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.not_equal(self, other, _builder) + return frontend.not_equal(self, other, _builder) @builtin def __getitem__(self, slices, _builder=None): @@ -589,25 +389,20 @@ class tensor: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) curr += 1 - ret = semantic.reshape(self, dst_shape, _builder) + ret = frontend.reshape(self, dst_shape, _builder) return ret @builtin def to(self, dtype, bitcast=False, _builder=None): - if isinstance(bitcast, constexpr): - bitcast = bitcast.value + dtype = dtype.handle(_builder) if bitcast: - return semantic.bitcast(self, dtype, _builder) - return semantic.cast(self, dtype, _builder) + return frontend.bitcast(self, dtype, _builder) + return frontend.cast(self, dtype, _builder) # ----------------------- # SPMD Programming Model # ----------------------- -def _constexpr_to_value(v): - if isinstance(v, constexpr): - return v.value - return v @builtin @@ -619,14 +414,13 @@ def program_id(axis, _builder=None): :type axis: int """ # if axis == -1: - # pid0 = program_id(0, _builder) - # pid1 = program_id(1, _builder) - # pid2 = program_id(2, _builder) - # npg0 = num_programs(0, _builder) - # npg1 = num_programs(0, _builder) + # pid0 = frontend.program_id(0, _builder) + # pid1 = frontend.program_id(1, _builder) + # pid2 = frontend.program_id(2, _builder) + # npg0 = frontend.num_programs(0, _builder) + # npg1 = frontend.num_programs(0, _builder) # return pid0 + pid1*npg0 + pid2*npg0*npg1 - axis = _constexpr_to_value(axis) - return semantic.program_id(axis, _builder) + return frontend.program_id(axis, _builder) @builtin @@ -637,8 +431,7 @@ def num_programs(axis, _builder=None): :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - axis = _constexpr_to_value(axis) - return semantic.num_programs(axis, _builder) + return frontend.num_programs(axis, _builder) # ----------------------- @@ -656,15 +449,13 @@ def arange(start, end, _builder=None): :param stop: End of the interval. Must be a power of two >= start. :type stop: int """ - start = _constexpr_to_value(start) - end = _constexpr_to_value(end) - return semantic.arange(start, end, _builder) + return frontend.arange(start, end, _builder) @builtin def zeros(shape, dtype, _builder=None): """ - Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. 
:param shape: Shape of the new array, e.g., (8, 16) or (8, ) :type shape: tuple of ints @@ -677,8 +468,7 @@ def zeros(shape, dtype, _builder=None): if not isinstance(d.value, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") shape = [x.value for x in shape] - dtype = _constexpr_to_value(dtype) - return semantic.zeros(shape, dtype, _builder) + return frontend.zeros(shape, dtype, _builder) # ----------------------- @@ -691,25 +481,25 @@ def broadcast(input, other, _builder=None): """ Tries to broadcast the two given blocks to a common compatible shape. - :param input: The first input tensor. + :param input: The first input block. :type input: Block - :param other: The second input tensor. + :param other: The second input block. :type other: Block """ - return semantic.broadcast_impl_value(input, other, _builder) + return frontend.broadcast(input, other, _builder) @builtin def broadcast_to(input, shape, _builder=None): """ - Tries to broadcast the given tensor to a new :code:`shape`. + Tries to broadcast the given block to a new :code:`shape`. - :param input: The input tensor. + :param input: The input block. :type input: Block :param shape: The desired shape. :type shape: Tuple[int] """ - return semantic.broadcast_impl_shape(input, shape, _builder) + return frontend.broadcast_to(input, shape, _builder) @builtin @@ -717,27 +507,27 @@ def cat(input, other, _builder=None): """ Concatenate the given blocks - :param input: The first input tensor. + :param input: The first input block. :type input: - :param other: The second input tensor. + :param other: The second input block. :type other: """ - return semantic.cat(input, other, _builder) + return frontend.cat(input, other, _builder) @builtin def reshape(input, shape, _builder=None): """ - Tries to reshape the given tensor to a new shape. + Tries to reshape the given block to a new shape. - :param input: The input tensor. + :param input: The input block. :type input: :param shape: The desired shape. :type shape: Tuple[int] """ shape = [x.value for x in shape] - return semantic.reshape(input, shape, _builder) + return frontend.reshape(input, shape, _builder) # ----------------------- @@ -752,13 +542,12 @@ def dot(input, other, allow_tf32=True, _builder=None): The two blocks must be two dimensionals and have compatible inner dimensions. - :param input: The first tensor to be multiplied. - :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} - :param other: The second tensor to be multiplied. - :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input: The first block to be multiplied. + :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second block to be multiplied. + :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ - allow_tf32 = _constexpr_to_value(allow_tf32) - return semantic.dot(input, other, allow_tf32, _builder) + return frontend.dot(input, other, allow_tf32, _builder) # ----------------------- @@ -769,7 +558,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ - Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. 
+ Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`. @@ -784,36 +573,24 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", :param cache_modifier: changes cache option in nvidia ptx 'type cache_modifier: str, optional """ - # mask, other can be constexpr - if mask is not None: - mask = _to_tensor(mask, _builder) - if other is not None: - other = _to_tensor(other, _builder) - cache_modifier = _constexpr_to_value(cache_modifier) - eviction_policy = _constexpr_to_value(eviction_policy) - volatile = _constexpr_to_value(volatile) - return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) + return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder) @builtin def store(pointer, value, mask=None, _builder=None): """ - Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. + Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`. :param pointer: The memory locations where the elements of :code:`value` are stored. :type pointer: Block of dtype=triton.PointerDType - :param value: The tensor of elements to be stored. + :param value: The block of elements to be stored. :type value: Block :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`. :type mask: Block of triton.int1, optional """ - # value can be constexpr - value = _to_tensor(value, _builder) - if mask is not None: - mask = _to_tensor(mask, _builder) - return semantic.store(pointer, value, mask, _builder) + return frontend.store(pointer, value, mask, _builder) # ----------------------- @@ -844,58 +621,49 @@ def _add_atomic_docstr(name): @builtin @_add_atomic_docstr("compare-and-swap") def atomic_cas(pointer, cmp, val, _builder=None): - cmp = _to_tensor(cmp, _builder) - val = _to_tensor(cmp, _builder) - return semantic.atomic_cas(pointer, cmp, val, _builder) + return frontend.atomic_cas(pointer, cmp, val, _builder) @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_xchg(pointer, val, mask, _builder) + return frontend.atomic_xchg(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_add(pointer, val, mask, _builder) + return frontend.atomic_add(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("max") def atomic_max(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_max(pointer, val, mask, _builder) + return frontend.atomic_max(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("min") def atomic_min(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_min(pointer, val, mask, _builder) + return frontend.atomic_min(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical and") def atomic_and(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_and(pointer, val, mask, _builder) + return frontend.atomic_and(pointer, val, mask, _builder) 
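[Review note, not part of the patch] The wrappers above now forward their raw arguments straight to `frontend`, so the Python-side coercions (`_to_tensor`, `_constexpr_to_value`) disappear and scalar/constexpr promotion has to happen in the C++ dispatch layer instead. One detail of the deleted path worth recording: `atomic_cas` coerced `cmp` twice (`val = _to_tensor(cmp, _builder)`), silently replacing `val` with `cmp`. Below is a minimal sketch of the presumably intended coercion, with stub helpers standing in for the real `_to_tensor` and `semantic` machinery. While here: `load`'s docstring still reads `'type cache_modifier: str, optional` where `:type` was meant.

    # Stubs standing in for Triton's real helpers; illustration only.
    def _to_tensor(x, _builder):
        return x  # the real helper wraps Python scalars/constexpr into tl.tensor

    class semantic:
        @staticmethod
        def atomic_cas(pointer, cmp, val, _builder):
            return (pointer, cmp, val)

    def atomic_cas(pointer, cmp, val, _builder=None):
        cmp = _to_tensor(cmp, _builder)
        val = _to_tensor(val, _builder)  # the deleted wrapper passed `cmp` here
        return semantic.atomic_cas(pointer, cmp, val, _builder)

    assert atomic_cas('ptr', cmp=0, val=1)[2] == 1  # val survives the coercion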
@builtin @_add_atomic_docstr("logical or") def atomic_or(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_or(pointer, val, mask, _builder) + return frontend.atomic_or(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical xor") def atomic_xor(pointer, val, mask=None, _builder=None): - val = _to_tensor(val, _builder) - return semantic.atomic_xor(pointer, val, mask, _builder) + return frontend.atomic_xor(pointer, val, mask, _builder) # ----------------------- @@ -906,7 +674,7 @@ def atomic_xor(pointer, val, mask=None, _builder=None): @builtin def where(condition, x, y, _builder=None): """ - Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. @@ -920,10 +688,7 @@ def where(condition, x, y, _builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. """ - condition = _to_tensor(condition, _builder) - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) - return semantic.where(condition, x, y, _builder) + return frontend.where(condition, x, y, _builder) # ----------------------- @@ -932,15 +697,12 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) - return semantic.umulhi(x, y, _builder) + return frontend.umulhi(x, y, _builder) @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): - ieee_rounding = _constexpr_to_value(ieee_rounding) - return semantic.fdiv(x, y, ieee_rounding, _builder) + return frontend.fdiv(x, y, ieee_rounding, _builder) def _add_math_1arg_docstr(name): @@ -961,31 +723,31 @@ def _add_math_1arg_docstr(name): @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): - return semantic.exp(x, _builder) + return frontend.exp(x, _builder) @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): - return semantic.log(x, _builder) + return frontend.log(x, _builder) @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): - return semantic.cos(x, _builder) + return frontend.cos(x, _builder) @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): - return semantic.sin(x, _builder) + return frontend.sin(x, _builder) @builtin @_add_math_1arg_docstr("square root") def sqrt(x, _builder=None): - return semantic.sqrt(x, _builder) + return frontend.sqrt(x, _builder) # ----------------------- @@ -996,7 +758,7 @@ def _add_reduction_docstr(name): def _decorator(func): docstr = """ - Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis` :param input: the input values :param axis: the dimension along which the reduction should be done @@ -1010,29 +772,25 @@ def _add_reduction_docstr(name): @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.max(input, axis, _builder) + return frontend.max(input, axis, _builder) @builtin @_add_reduction_docstr("minimum") def min(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.min(input, axis, _builder) + return frontend.min(input, axis, _builder) @builtin 
@_add_reduction_docstr("sum") def sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.sum(input, axis, _builder) + return frontend.sum(input, axis, _builder) @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): - axis = _constexpr_to_value(axis) - return semantic.xor_sum(input, axis, _builder) + return frontend.xor_sum(input, axis, _builder) # ----------------------- @@ -1042,7 +800,7 @@ def xor_sum(input, axis, _builder=None): @builtin def debug_barrier(_builder=None): - return semantic.debug_barrier(_builder) + return frontend.debug_barrier(_builder) @builtin @@ -1050,8 +808,7 @@ def multiple_of(input, value, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ - value = _constexpr_to_value(value) - return semantic.multiple_of(input, value) + return frontend.multiple_of(input, value, _builder) @builtin @@ -1059,8 +816,7 @@ def max_contiguous(input, value, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - value = _constexpr_to_value(value) - return semantic.max_contiguous(input, value) + return frontend.max_contiguous(input, value, _builder) # ----------------------- @@ -1090,9 +846,9 @@ def minimum(x, y): """ Computes the element-wise minimum of :code:`x` and :code:`y`. - :param input: the first input tensor + :param input: the first input block :type input: Block - :param other: the second input tensor + :param other: the second input block :type other: Block """ return triton.language.where(x < y, x, y) @@ -1103,9 +859,9 @@ def maximum(x, y): """ Computes the element-wise maximum of :code:`x` and :code:`y`. - :param input: the first input tensor + :param input: the first input block :type input: Block - :param other: the second input tensor + :param other: the second input block :type other: Block """ return triton.language.where(x > y, x, y) @@ -1131,7 +887,7 @@ def ravel(x): """ Returns a contiguous flattened view of :code:`x` - :param x: the input tensor + :param x: the input block :type x: Block """ return triton.language.reshape(x, [x.numel]) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py deleted file mode 100644 index 4063b86fc..000000000 --- a/python/triton/language/semantic.py +++ /dev/null @@ -1,1037 +0,0 @@ -from __future__ import annotations # remove after python 3.11 - -from typing import List, Optional, Tuple - -from . 
import core as tl -from triton._C.libtriton.triton import ir - - -# Create custom exception that prints message "hello" -class IncompatibleTypeErrorimpl(Exception): - def __init__(self, type_a, type_b): - self.type_a = type_a - self.type_b = type_b - self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() - super(IncompatibleTypeErrorimpl, self).__init__(self.message) - - -# ===----------------------------------------------------------------------===## -# Programming Model -# ===----------------------------------------------------------------------===## - -def program_id(axis: int, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_get_program_id(axis), tl.int32) - - -def num_programs(axis: int, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_get_num_programs(axis), tl.int32) - -# ===----------------------------------------------------------------------===// -# Implicit Casting Utilities -# ===----------------------------------------------------------------------===// - - -def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: - a_rank = a_ty.int_bitwidth - b_rank = b_ty.int_bitwidth - a_sn = a_ty.int_signedness - b_sn = b_ty.int_signedness - # Rules for signedness taken from "Usual arithmetic conversions" on - # https://en.cppreference.com/w/c/language/conversion. - if a_sn == b_sn: - return a_ty if a_rank > b_rank else b_ty - elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: - return a_ty if a_rank >= b_rank else b_ty - elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: - return b_ty if b_rank >= a_rank else a_ty - assert False - - -def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: - # 1) if one operand is double, the other is implicitly - # converted to double - if a_ty.is_fp64() or b_ty.is_fp64(): - return tl.float64 - # 2) if one operand is float, the other is implicitly - # converted to float - if a_ty.is_fp32() or b_ty.is_fp32(): - return tl.float32 - # 3 ) if one operand is half, the other is implicitly converted to half - # unless we're doing / or %, which do not exist natively in PTX for fp16. - if a_ty.is_fp16() or b_ty.is_fp16(): - if div_or_mod: - return tl.float32 - else: - return tl.float16 - if not a_ty.is_int() or not b_ty.is_int(): - assert False - # 4 ) both operands are integer and undergo - # integer promotion - if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: - raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" - "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") - return integer_promote_impl(a_ty, b_ty) - -# ===----------------------------------------------------------------------===// -# Binary Operators -# ===----------------------------------------------------------------------===// - - -def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: - if type_a.is_ptr(): - if not allow_ptr_a: - raise IncompatibleTypeErrorimpl(type_a, type_b) - # T* + U* with T != U - if type_b.is_ptr() and (type_a != type_b): - raise IncompatibleTypeErrorimpl(type_a, type_b) - # T* + float - if type_b.is_floating(): - raise IncompatibleTypeErrorimpl(type_a, type_b) - - -def binary_op_type_checking_impl(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder, - allow_lhs_ptr=False, allow_rhs_ptr=False, - arithmetic_check=True, div_or_mod=False - ) -> Tuple[tl.tensor, tl.tensor]: - # implicit broadcasting - lhs, rhs = broadcast_impl_value(lhs, rhs, builder) - # implicit typecasting - lhs_sca_ty = lhs.type.scalar - rhs_sca_ty = rhs.type.scalar - check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) - check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) - if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): - ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) - lhs = cast(lhs, ret_sca_ty, builder) - rhs = cast(rhs, ret_sca_ty, builder) - return lhs, rhs - - -def add(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, True, True) - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - - # offset + ptr - # ptr + offset - if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): - input, other = other, input - if input_scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) - # float + float - elif input_scalar_ty.is_floating(): - return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) - # int + int - elif input_scalar_ty.is_int(): - return tl.tensor(builder.create_add(input.handle, other.handle), input.type) - assert False - - -def sub(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, True, False) - scalar_ty = input.type.scalar - # ptr - offset - if scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), - input.type) - # float - float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) - # int - int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) - assert False - - -def mul(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float * float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) - # * int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) - assert False - - -def truediv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - # float / int - if input_scalar_ty.is_floating() and 
other_scalar_ty.is_int(): - other = cast(other, input_scalar_ty, builder) - # int / float - elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): - input = cast(input, other_scalar_ty, builder) - # int / int (cast to tl.float32) - elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): - input = cast(input, tl.float32, builder) - other = cast(other, tl.float32, builder) - # float / float (cast to highest exponent type) - elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): - if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: - other = cast(other, input_scalar_ty, builder) - else: - input = cast(input, other_scalar_ty, builder) - # unreachable - else: - assert False - return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) - - -def floordiv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - if input_scalar_ty.is_int() and other_scalar_ty.is_int(): - ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) - input = cast(input, ret_ty, builder) - other = cast(other, ret_ty, builder) - if ret_ty.is_int_signed(): - return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) - else: - return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) - assert False - - -def fdiv(input: tl.tensor, - other: tl.tensor, - ieee_rounding: bool, - builder: ir.builder) -> tl.tensor: - input_scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): - raise ValueError("both operands of fdiv must have floating poscalar type") - input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) - ret = builder.create_fdiv(input.handle, other.handle) - ret.set_fdiv_ieee_rounding(ieee_rounding) - return tl.tensor(ret, input.type) - - -def mod(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) - scalar_ty = input.type.scalar - other_scalar_ty = other.type.scalar - # float % float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_frem(input.handle, other.handle), input.type) - # % int - elif scalar_ty.is_int(): - if scalar_ty.int_signedness != other_scalar_ty.int_signedness: - raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " - "because they have different signedness;" - "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) - else: - return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) - assert False - -############## -# bitwise ops -############## - - -def bitwise_op_type_checking_impl(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) - input_sca_ty = input.type.scalar - other_sca_ty = other.type.scalar - if not input_sca_ty.is_int() or not other_sca_ty.is_int(): - raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty) - ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) - if ret_sca_ty != input_sca_ty: - input = cast(input, ret_sca_ty, builder) - if ret_sca_ty != other_sca_ty: - other = cast(other, ret_sca_ty, builder) - return input, other - - -def and_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_and(input.handle, other.handle), input.type) - - -def or_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_or(input.handle, other.handle), input.type) - - -def xor_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) - - -def lshr(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) - - -def shl(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = bitwise_op_type_checking_impl(input, other, builder) - return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) - -# ===----------------------------------------------------------------------===// -# Unary Operators -# ===----------------------------------------------------------------------===// - - -def plus(input: tl.tensor) -> tl.tensor: - return input - - -def minus(input: tl.tensor, - builder: ir.builder) -> tl.tensor: - input_sca_ty = input.type.scalar - if input_sca_ty.is_ptr(): - raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") - _0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) - return sub(_0, input, builder) - - -def invert(input: tl.tensor, - builder: tl.tensor) -> tl.tensor: - input_sca_ty = input.type.scalar - if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): - raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") - _1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) - return xor_(input, _1, builder) - - -# ===----------------------------------------------------------------------===// -# Comparison Operators -# ===----------------------------------------------------------------------===// -def _bool_like(v: tl.tensor) -> tl.block_type: - if not v.type.is_block(): - return tl.int1 - shape = v.type.shape - return tl.block_type(tl.int1, shape) - - -def greater_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> 
tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float > float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) - # > int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) - assert False - - -def greater_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float >= float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) - # >= int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) - assert False - - -def less_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float < float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) - # < int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) - assert False - - -def less_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float < float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) - # < int - elif scalar_ty.is_int(): - if scalar_ty.is_int_signed(): - return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) - else: - return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) - assert False - - -def equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float == float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) - # == int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) - assert False - - -def not_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: - input, other = binary_op_type_checking_impl(input, other, builder) - scalar_ty = input.type.scalar - # float == float - if scalar_ty.is_floating(): - return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) - # == int - elif scalar_ty.is_int(): - return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) - assert False - -# ===----------------------------------------------------------------------===// -# Block Creation -# ===----------------------------------------------------------------------===// - - -def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: - 
shape = [end - start] - ret_ty = tl.block_type(tl.int32, shape) - return tl.tensor(builder.get_range(start, end), ret_ty) - - -def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - _0 = ir.constant.get_null_value(dtype.to_ir(builder)) - ret_ty = tl.block_type(dtype, shape) - return tl.tensor(builder.create_splat(_0, shape), ret_ty) - -# ===----------------------------------------------------------------------===// -# Shape Manipulation -# ===----------------------------------------------------------------------===// - - -def reshape(input: tl.tensor, - dst_shape: List[int], - builder: ir.builder) -> tl.tensor: - numel = 1 - for s in dst_shape: - numel *= s - if input.type.numel != numel: - raise ValueError("cannot reshape block of different shape") - ret_ty = tl.block_type(input.type.scalar, dst_shape) - return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty) - - -def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: - # TODO: check types - return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type) - - -def broadcast_impl_shape(input: tl.tensor, - shape: List[int], - builder: ir.builder) -> tl.tensor: - if not input.type.is_block(): - ret_ty = tl.block_type(input.type, shape) - return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) - src_shape = input.type.get_block_shapes() - if len(src_shape) != len(shape): - raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") - if shape == src_shape: - return input - ret_ty = tl.block_type(input.type.scalar, shape) - return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) - - -def broadcast_impl_value(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder) -> tl.tensor: - lhs_ty = lhs.type - rhs_ty = rhs.type - - # make_shape_compatible(block, scalar) - if lhs_ty.is_block() and not rhs_ty.is_block(): - rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) - rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) - # make_shape_compatible(scalar, block) - elif not lhs_ty.is_block() and rhs_ty.is_block(): - lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) - lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) - # make_shape_compatible(block, block) - elif lhs_ty.is_block() and rhs_ty.is_block(): - lhs_shape = lhs_ty.get_block_shapes() - rhs_shape = rhs_ty.get_block_shapes() - if len(lhs_shape) != len(rhs_shape): - raise ValueError("Cannot make_shape_compatible: blocks must have the same rank") - ret_shape = [] - for i in range(len(lhs_shape)): - left = lhs_shape[i] - right = rhs_shape[i] - if left == 1: - ret_shape.append(right) - elif right == 1: - ret_shape.append(left) - elif left == right: - ret_shape.append(left) - else: - raise ValueError("Cannot make_shape_compatible: incompatible dimensions " - "at index " + str(i) + ": " + str(left) + " and " + str(right)) - if lhs_shape != ret_shape: - ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) - lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) - if rhs_shape != ret_shape: - ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) - rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) - # (scalar, scalar) => returns original blocks - return lhs, rhs - -####### -# cast -####### - - -def bitcast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: - src_ty = input.type - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) - if 
src_ty == dst_ty: - return input - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): - return cast(input, dst_ty, builder) - # Bitcast - src_bits = src_sca_ty.primitive_bitwidth - dst_bits = dst_sca_ty.primitive_bitwidth - if src_bits != dst_bits: - raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " - "data-type of size " + str(dst_bits)) - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), - dst_ty) - - -def cast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: - src_ty = input.type - if src_ty.is_block(): - dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) - if src_ty == dst_ty: - return input - src_sca_ty = src_ty.scalar - dst_sca_ty = dst_ty.scalar - - # bf16 <=> (not fp32) - if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ - (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): - return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) - - # FP Truncation - truncate_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width - if truncate_fp: - return tl.tensor(builder.create_fp_trunc(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # FP Extension - ext_fp = src_sca_ty.is_floating() and \ - dst_sca_ty.is_floating() and \ - src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width - if ext_fp: - return tl.tensor(builder.create_fp_ext(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # Int cast - if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ - (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): - sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() - return tl.tensor(builder.create_int_cast(input.handle, - dst_ty.to_ir(builder), sign_extend), - dst_ty) - - # Float to Int - if src_sca_ty.is_floating() and dst_sca_ty.is_int(): - # TODO: is this correct? - if dst_sca_ty.is_bool(): - return tl.tensor(builder.create_fp_to_ui(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - else: - return tl.tensor(builder.create_fp_to_si(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # int => float - if src_sca_ty.is_int() and dst_sca_ty.is_floating(): - if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): - return tl.tensor(builder.create_ui_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - else: - return tl.tensor(builder.create_si_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) - - # ptr => int - if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): - bitwidth = dst_sca_ty.int_bitwidth - if bitwidth == 64: - return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), - dst_ty) - if bitwidth == 1: - return not_equal(cast(input, tl.int64, builder), - tl.tensor(builder.get_int64(0), tl.int64), - builder) - - if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) - # Ptr . Ptr - if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) - # * . 
Bool - if dst_sca_ty.is_bool(): - if src_sca_ty.is_ptr(): - input = cast(input, tl.int64, builder) - other = builder.get_int64(0) - if src_ty.is_bool(): - other = builder.create_splat(other, src_ty.get_block_shapes()) - return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty) - assert False, f'cannot cast {input} to {dst_ty}' - -# ===----------------------------------------------------------------------===// -# Memory Operators -# ===----------------------------------------------------------------------===// - - -def load(ptr: tl.tensor, - mask: Optional[tl.tensor], - other: Optional[tl.tensor], - cache_modifier: str, - eviction_policy: str, - is_volatile: bool, - builder: ir.builder) -> tl.tensor: - if not ptr.type.scalar.is_ptr(): - raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__()) - if ptr.type.is_block(): - if mask: - mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) - if other: - other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) - - if other: - other = cast(other, ptr.type.scalar.element_ty, builder) - ptr_ty = ptr.type.scalar - elt_ty = ptr_ty.element_ty - # treat bool* as tl.int8* - if elt_ty == tl.int1: - elt_ty = tl.int8 - ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) - ptr = cast(ptr, ptr_ty, builder) - - # cache modifier - cache = ir.CACHE_MODIFIER.NONE # default - if cache_modifier: - if cache_modifier == ".ca": - cache = ir.CACHE_MODIFIER.CA - elif cache_modifier == ".cg": - cache = ir.CACHE_MODIFIER.CG - else: - raise ValueError(f"Cache modifier {cache_modifier} not supported") - - # eviction policy - eviction = ir.EVICTION_POLICY.NORMAL # default - if eviction_policy: - if eviction_policy == "evict_last": - eviction = ir.EVICTION_POLICY.EVICT_LAST - elif eviction_policy == "evict_first": - eviction = ir.EVICTION_POLICY.EVICT_FIRST - else: - raise ValueError(f"Eviction policy {eviction_policy} not supported") - - if ptr.type.is_block(): - shape = ptr.type.get_block_shapes() - dst_ty = tl.block_type(elt_ty, shape) - else: - dst_ty = elt_ty - - if not mask and not other: - return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), - dst_ty) - if not mask: - raise ValueError("`other` cannot be provided without `mask`") - - if not other: - other_ir = ir.undef.get(elt_ty.to_ir(builder)) - if ptr.type.is_block(): - other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes()) - other = tl.tensor(other_ir, dst_ty) - - return tl.tensor(builder.create_masked_load(ptr.handle, - mask.handle, - other.handle, - cache, eviction, is_volatile), - dst_ty) - - -def store(ptr: tl.tensor, - val: tl.tensor, - mask: Optional[tl.tensor], - builder: ir.builder) -> tl.tensor: - if not ptr.type.scalar.is_ptr(): - raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) - if ptr.type.is_block(): - val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) - if mask: - mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) - ptr_ty = ptr.type.scalar - elt_ty = ptr_ty.element_ty - # treat bool* as tl.int8* - if elt_ty == tl.int1: - elt_ty = tl.int8 - ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) - ptr = cast(ptr, ptr_ty, builder) - - # cast to target data-type - val = cast(val, elt_ty, builder) - if not mask: - return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void) - if not mask.type.scalar.is_bool(): - raise ValueError("Mask must have boolean scalar type") - return 
tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) - -######### -# atomic -######### - - -def atomic_cas(ptr: tl.tensor, - cmp: tl.tensor, - val: tl.tensor, - builder: ir.builder) -> tl.tensor: - # TODO: type checking - return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) - - -def atom_red_typechecking_impl(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: - if not ptr.type.scalar.is_ptr(): - raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) - if ptr.type.is_block(): - if mask: - mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) - if val: - val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) - val = cast(val, ptr.type.scalar.element_ty, builder) - if not mask: - mask_ir = builder.get_int1(True) - mask_ty = tl.int1 - if ptr.type.is_block(): - mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) - mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) - mask = tl.tensor(mask_ir, mask_ty) - return ptr, val, mask - - -def atomic_max(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - sca_ty = val.type.scalar - # direct call to atomic_max for integers - if sca_ty.is_int(): - if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, - ptr.handle, - val.handle, - mask.handle), - val.type) - else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - ptr.handle, - val.handle, - mask.handle), - val.type) - # for float - # return atomic_smax(i_ptr, i_val) if val >= 0 - # return atomic_umin(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type) - return where(pos, pos_ret, neg_ret, builder) - - -def atomic_min(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - sca_ty = val.type.scalar - # direct call to atomic_min for integers - if sca_ty.is_int(): - if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - ptr.handle, - val.handle, - mask.handle), - val.type) - else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, - ptr.handle, - val.handle, - mask.handle), - val.type) - # for float - # return atomic_smin(i_ptr, i_val) if val >= 0 - # return atomic_umax(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - i_ptr.handle, - i_val.handle, - and_(mask, pos, 
builder).handle), - i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - i_ptr.handle, - i_val.handle, - and_(mask, neg, builder).handle), - i_val.type) - return where(pos, pos_ret, neg_ret, builder) - - -def atomic_add(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - sca_ty = val.type.scalar - op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD - return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_and(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_or(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_xor(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) - - -def atomic_xchg(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - builder: ir.builder) -> tl.tensor: - ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) - -# ===----------------------------------------------------------------------===// -# Linear Algebra -# ===----------------------------------------------------------------------===// - - -def dot(lhs: tl.tensor, - rhs: tl.tensor, - allow_tf32: bool, - builder: ir.builder) -> tl.tensor: - assert lhs.type.is_block() and rhs.type.is_block() - if lhs.type.scalar.is_int(): - _0 = builder.get_int32(0) - ret_scalar_ty = tl.int32 - else: - _0 = builder.get_float32(0) - ret_scalar_ty = tl.float32 - M = lhs.type.shape[0] - N = rhs.type.shape[1] - _0 = builder.create_splat(_0, [M, N]) - ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), - ret_ty) - - -# ===----------------------------------------------------------------------===// -# Indexing -# ===----------------------------------------------------------------------===// - -def where(condition: tl.tensor, - x: tl.tensor, - y: tl.tensor, - builder: ir.builder) -> tl.tensor: - condition = cast(condition, tl.int1, builder) - if condition.type.is_block(): - x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) - y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) - - # TODO: we need to check x's and y's shape? 
- x_ty = x.type.scalar - y_ty = y.type.scalar - ty = computation_type_impl(x_ty, y_ty, div_or_mod=False) - x = cast(x, ty, builder) - y = cast(y, ty, builder) - if x.type.is_block(): - ret_ty = tl.block_type(ty, x.type.shape) - else: - ret_ty = ty - return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) - - -# ===----------------------------------------------------------------------===// -# Reductions -# ===----------------------------------------------------------------------=== - -def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, - FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: - scalar_ty = input.type.scalar - # input is extended to 32-bits if necessary - # this increases numerical accuracy and can be done pretty much for free - # on GPUs - if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: - input = cast(input, tl.int32, builder) - - # get result type - shape = input.type.shape - ret_shape = [] - for i, s in enumerate(shape): - if i != axis: - ret_shape.append(s) - if len(ret_shape) == 0: - res_ty = scalar_ty - else: - res_ty = tl.block_type(scalar_ty, ret_shape) - - if scalar_ty.is_floating(): - return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) - elif scalar_ty.is_int(): - return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty) - assert False - - -def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) - - -def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) - - -def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) - - -def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - scalar_ty = input.type.scalar - if not scalar_ty.is_int(): - raise ValueError("xor_sum only supported for integers") - return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) - - -# ===----------------------------------------------------------------------=== -# Math -# ===----------------------------------------------------------------------=== - -def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: - x, y = binary_op_type_checking_impl(x, y, builder) - return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type) - - -def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_exp(x.handle), x.type) - - -def log(x: tl.tensor, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_log(x.handle), x.type) - - -def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_cos(x.handle), x.type) - - -def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_sin(x.handle), x.type) - - -def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_sqrt(x.handle), x.type) - - -## - -def multiple_of(x: tl.tensor, value: int) -> tl.tensor: - x.handle.multiple_of(value) - return x - - -def max_contiguous(x: tl.tensor, value: int) -> tl.tensor: - x.handle.max_contiguous(value) - return x - - -def debug_barrier(builder: ir.builder) -> tl.tensor: - return tl.tensor(builder.create_barrier(''), tl.void)
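

[Review note, not part of the patch] The reworked `dtype` at the top of this diff is a thin wrapper around a C++ init function and derives its `name` by stripping the `get_` prefix from `init.__name__`. A self-contained sketch of that behavior; `get_uint32` below is a stand-in for the real `ir.type.get_uint32`:

    class dtype:
        def __init__(self, init):
            self.init = init

        @property
        def name(self) -> str:
            # init functions are named like 'get_uint32'; strip the prefix
            prefix = 'get_'
            assert self.init.__name__.startswith(prefix)
            return self.init.__name__[len(prefix):]

    def get_uint32(ctx):  # stand-in for ir.type.get_uint32
        pass

    assert dtype(get_uint32).name == 'uint32'

Note also that the unchanged `pointer_dtype.__init__` context line raises `TypeError('element_ty is a {type(element_ty).__name__}.')` without the `f` prefix, so the braces are printed literally; that predates this patch but is a one-character follow-up fix.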
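
[Review note, not part of the patch] The deleted `integer_promote_impl` encodes C's "usual arithmetic conversions" over the new signed/unsigned dtypes: equal signedness picks the wider type, and on a signedness mismatch the unsigned type wins at equal or greater width. A hedged re-statement over plain `(bitwidth, signed)` tuples, just to document the rule:

    def promote(a, b):
        (a_bits, a_signed), (b_bits, b_signed) = a, b
        if a_signed == b_signed:
            return a if a_bits > b_bits else b
        if not a_signed:                         # a unsigned, b signed
            return a if a_bits >= b_bits else b
        return b if b_bits >= a_bits else a      # b unsigned, a signed

    assert promote((32, True), (64, True)) == (64, True)    # int32 + int64 -> int64
    assert promote((32, False), (32, True)) == (32, False)  # uint32 + int32 -> uint32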
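
[Review note, not part of the patch] The deleted `cast()` chooses sign- vs zero-extension from the source signedness (`sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()`), which is exactly what makes the new uint types round-trip: widening a uint8 must not replicate the top bit. A toy 8-to-32-bit widening sketch of the same rule:

    def widen8_to_32(bits8, src_signed):
        # replicate the sign bit only for signed sources
        if src_signed and (bits8 & 0x80):
            return bits8 | 0xffffff00   # sign-extend
        return bits8                    # zero-extend

    assert widen8_to_32(0xff, src_signed=True)  == 0xffffffff  # int8 -1 -> int32 -1
    assert widen8_to_32(0xff, src_signed=False) == 0x000000ff  # uint8 255 -> 255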
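
[Review note, not part of the patch] The deleted float paths of `atomic_max`/`atomic_min` lean on an IEEE-754 ordering trick: for non-negative floats the raw bit patterns compare like signed ints (so an integer MAX is correct there), while for negative floats the unsigned view is reverse-ordered (so an unsigned MIN computes the float max). That is why the code bitcast to int32, split on the sign with the `pos`/`neg` masks, and issued MAX on one half and UMIN on the other. A small sketch verifying both ordering facts:

    import struct

    def f32_bits(x):
        # reinterpret a float32 as its raw 32-bit pattern (signed int view)
        return struct.unpack('<i', struct.pack('<f', x))[0]

    # non-negative floats: bit patterns order like signed ints
    assert f32_bits(1.0) < f32_bits(2.0)
    # negative floats: the unsigned view reverses the order
    assert (f32_bits(-2.0) & 0xffffffff) > (f32_bits(-1.0) & 0xffffffff)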
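
[Review note, not part of the patch] Two pre-existing docstring nits survive this diff and could ride along in a follow-up: `multiple_of` and `max_contiguous` both read "Let the compiler knows that ..." ("know" was meant), and the `zeros` type-check message drops the closing backtick after `constexpr[{type(d.value)}]`. Neither affects behavior.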