From 539961072c277221e4e1dfda81e718daf9bbc1c7 Mon Sep 17 00:00:00 2001
From: daadaada
Date: Thu, 17 Mar 2022 12:25:30 +0800
Subject: [PATCH] [FRONTEND] Semantic analysis refactor (#473)

Moved dispatch.cc to semantic.py
Integer signedness now moved from C++ to python
Cleaner frontend type

Co-authored-by: Phil Tillet
---
 include/triton/ir/builder.h                |   31 +-
 include/triton/ir/context_impl.h           |    1 -
 include/triton/ir/dispatch.h               |  113 ---
 include/triton/ir/module.h                 |   33 +-
 include/triton/ir/type.h                   |   20 +-
 lib/ir/builder.cc                          |   53 +-
 lib/ir/context.cc                          |   18 +-
 lib/ir/dispatch.cc                         |  882 -----------------
 lib/ir/instructions.cc                     |    4 +-
 lib/ir/module.cc                           |  140 ---
 lib/ir/type.cc                             |   14 -
 python/src/triton.cc                       |  392 +++++---
 python/test/regression/test_performance.py |    2 +-
 python/test/unit/language/test_core.py     |   23 -
 python/test/unit/runtime/test_cache.py     |   28 +
 python/triton/__init__.py                  |    3 +-
 python/triton/code_gen.py                  |  314 ++++--
 python/triton/language/core.py             |  606 ++++++++----
 python/triton/language/semantic.py         | 1037 ++++++++++++++++++++
 19 files changed, 2044 insertions(+), 1670 deletions(-)
 delete mode 100644 include/triton/ir/dispatch.h
 delete mode 100644 lib/ir/dispatch.cc
 create mode 100644 python/triton/language/semantic.py

diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 2b6bc6ab3..fe85be947 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -38,10 +38,8 @@ public:
   iterator get_insert_point() { return insert_point_;}
   // Constants
   value *get_int1(bool val);
-  value *get_int32(int32_t val);
-  value *get_int64(int64_t val);
-  value *get_uint32(uint32_t val);
-  value *get_uint64(uint64_t val);
+  value *get_int32(uint32_t val);
+  value *get_int64(uint64_t val);
   value *get_float16(float val);
   value *get_float32(float val);
   value *get_range(int32_t lo, int32_t hi);
@@ -52,11 +50,9 @@ public:
   type *get_int16_ty();
   type *get_int32_ty();
   type *get_int64_ty();
-  type *get_uint8_ty();
-  type *get_uint16_ty();
-  type *get_uint32_ty();
-  type *get_uint64_ty();
+  type *get_fp8_ty();
   type *get_half_ty();
+  type *get_bf16_ty();
   type *get_float_ty();
   type *get_double_ty();
   // Insert
@@ -74,7 +70,9 @@ public:
   value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
   value* create_ret_void();
   // Cast instructions
+  value* create_bitcast(value *src, type *dest_ty);
   value *create_cast(cast_op_t op, value *v, type *dst_ty);
+  value* create_int_to_ptr(value *src, type *dst_ty);
   value* create_ptr_to_int(value *src, type *dst_ty);
   value* create_si_to_fp(value *src, type *dst_ty);
   value* create_ui_to_fp(value *src, type *dst_ty);
@@ -93,11 +91,11 @@ public:
   value *create_frem(value *lhs, value *rhs);
   value *create_fadd(value *lhs, value *rhs);
   value *create_fsub(value *lhs, value *rhs);
-  value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_sdiv(value *lhs, value *rhs);
   value *create_udiv(value *lhs, value *rhs);
   value *create_srem(value *lhs, value *rhs);
   value *create_urem(value *lhs, value *rhs);
+  value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -145,11 +143,22 @@ public:
   value *create_reshape(value *arg, const type::block_shapes_t &shapes);
   value *create_cat(value *lhs, value *rhs);
   value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
+  // Atomic instruction
+  value *create_atomic_cas(value *ptr, value *cmp, value *val);
+  value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
+  value *create_atomic_max(value *ptr, value *val, value *msk);
+  value *create_atomic_umax(value *ptr, value *val, value *msk);
+  value *create_atomic_min(value *ptr, value *val, value *msk);
+  value *create_atomic_umin(value *ptr, value *val, value *msk);
+  value *create_atomic_fadd(value *ptr, value *val, value *msk);
+  value *create_atomic_add(value *ptr, value *val, value *msk);
+  value *create_atomic_and(value *ptr, value *val, value *msk);
+  value *create_atomic_or(value *ptr, value *val, value *msk);
+  value *create_atomic_xor(value *ptr, value *val, value *msk);
+  value *create_atomic_xchg(value *ptr, value *val, value *msk);
   // Built-in instruction
   value *create_get_program_id(unsigned axis);
   value *create_get_num_programs(unsigned axis);
-  value *create_atomic_cas(value *ptr, value *cmp, value *val);
-  value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
   value *create_exp(value* arg);
   value *create_cos(value* arg);
   value *create_sin(value* arg);
diff --git a/include/triton/ir/context_impl.h b/include/triton/ir/context_impl.h
index 081ea249d..ef20af6b7 100644
--- a/include/triton/ir/context_impl.h
+++ b/include/triton/ir/context_impl.h
@@ -26,7 +26,6 @@ public:
   type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
   // integer types
   integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
-  integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
   // Pointer types
   std::map, std::unique_ptr> ptr_tys;
   // Block types
diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h
deleted file mode 100644
index ef14043dd..000000000
--- a/include/triton/ir/dispatch.h
+++ /dev/null
@@ -1,113 +0,0 @@
-#pragma once
-
-#ifndef _TRITON_IR_DISPATCH_H_
-#define _TRITON_IR_DISPATCH_H_
-
-#include "triton/ir/builder.h"
-#include
-
-namespace triton{
-namespace ir{
-
-
-/*----------------------------------------------
-  higher level functions that follow the likely
-  semantics of most expected frontends
- ----------------------------------------------*/
-
-struct semantic_error: public std::runtime_error {
-  semantic_error(const std::string& msg):
-    std::runtime_error(msg) { }
-};
-
-struct dispatch{
-  typedef ir::type::block_shapes_t shape_t;
-
-
-  // programming model
-  static ir::value *program_id(int axis, ir::builder *builder);
-  static ir::value *num_programs(int axis, ir::builder *builder);
-
-  // binary operators
-  static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
-  static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
-  static ir::value *lshr(ir::value
*input, ir::value *other, ir::builder *builder); - static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder); - - // unary operators - static ir::value *plus(ir::value *input, ir::builder *builder); - static ir::value *minus(ir::value *input, ir::builder *builder); - static ir::value *invert(ir::value *input, ir::builder *builder); - - // comparison operators - static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder); - static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder); - - // block creation - static ir::value* arange(int start, int end, ir::builder *builder); - static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder); - - - // casting ops - static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder); - static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder); - static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder); - static std::tuple broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder); - static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder); - static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder); - - // memory operators - static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, - const std::string& eviction_policy, int is_volatile, ir::builder *builder); - static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder); - static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder); - static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder); - - // linear algebra - static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder); - - // indexing - static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder); - - // reduction - static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); - static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder); - - // math - static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); - static ir::value *exp(ir::value *x, ir::builder *builder); - static ir::value *log(ir::value *x, ir::builder *builder); - 
static ir::value *cos(ir::value *x, ir::builder *builder); - static ir::value *sin(ir::value *x, ir::builder *builder); - static ir::value *sqrt(ir::value *x, ir::builder *builder); - - // internal (debug/optimization) - static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder); - static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder); - static ir::value *debug_barrier(ir::builder *builder); -}; - -} -} - -#endif diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 30881fd49..ea64dfc6e 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -57,26 +57,10 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - module(const std::string &name, builder& builder); - builder& get_builder(); - // Setters - void set_value(const std::string& name, basic_block* block, value *x); - void set_value(const std::string& name, value* x); - void set_const(const std::string& name); - void set_continue_fn(std::function fn); - // Getters - const std::map& get_values() { return values_; } - const std::map& get_types() { return types_; } - void set_values(const std::map& values) { values_ = values; } - void set_types(const std::map& types) { types_ = types; } + module(const std::string &name, builder &builder): name_(name), builder_(builder) {} + builder &get_builder() { return builder_; }; + const std::string& get_name() { return name_; }; - value *get_value(const std::string& name, basic_block* block); - value *get_value(const std::string& name); - void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } - const std::string& get_name(); - std::function get_continue_fn(); - // Seal block -- no more predecessors will be added - void seal_block(basic_block *block); // Functions const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } @@ -89,21 +73,14 @@ public: const std::map& globals() const { return globals_; } // Metadata void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } - + const std::map &get_metadatas() const { return metadatas_; } void print(std::ostream &os); private: std::string name_; - builder& builder_; - std::map values_; - std::map types_; - std::set const_; - std::set sealed_blocks_; - std::map> incomplete_phis_; + builder &builder_; functions_list_t functions_; symbols_map_t symbols_; - std::function continue_fn_; - std::map current_phi_; std::vector allocs_; std::map globals_; std::map metadatas_; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index 47c9b5f85..b1ef1ad22 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -16,8 +16,6 @@ class value; class integer_type; class constant_int; -enum class signedness { SIGNED, UNSIGNED }; - /* Type */ class type { public: @@ -61,8 +59,6 @@ public: // type attributes unsigned get_fp_mantissa_width() const; unsigned get_integer_bitwidth() const; - signedness get_integer_signedness() const; - bool is_integer_signed() const; unsigned get_tile_bitwidth() const; unsigned get_primitive_size_in_bits() const; type *get_scalar_ty() const; @@ -85,9 +81,6 @@ public: bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_token_ty() const { return id_ == TokenTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; } - bool is_integer_ty(unsigned bitwidth, signedness sn) { - return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn; - 
} bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_block_ty() const { return id_ == BlockTyID; } @@ -115,10 +108,6 @@ public: static integer_type *get_int32_ty(context &ctx); static integer_type *get_int64_ty(context &ctx); static integer_type *get_int128_ty(context &ctx); - static integer_type *get_uint8_ty(context &ctx); - static integer_type *get_uint16_ty(context &ctx); - static integer_type *get_uint32_ty(context &ctx); - static integer_type *get_uint64_ty(context &ctx); // repr std::string tile_repr() const { @@ -145,7 +134,7 @@ public: case LabelTyID: return "label"; case MetadataTyID: return "md"; case TokenTyID: return "tok"; - case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth()); + case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth()); case FunctionTyID: return "fn"; case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case StructTyID: return "struct"; @@ -168,21 +157,18 @@ class integer_type: public type { private: // constructors - integer_type(context &ctx, unsigned bitwidth, signedness sn) - : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ } + integer_type(context &ctx, unsigned bitwidth) + : type(ctx, IntegerTyID), bitwidth_(bitwidth) {} public: // accessors unsigned get_bitwidth() const { return bitwidth_; } - signedness get_signedness() const { return signedness_; } - // factory methods static integer_type* get(context &ctx, unsigned width); private: unsigned bitwidth_; - signedness signedness_; }; class composite_type: public type{ diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index fff73e665..9b8a2a45e 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -48,18 +48,12 @@ void builder::set_insert_point(basic_block *block){ value *builder::get_int1(bool val) { return constant_int::get(type::get_int1_ty(ctx_), val); } -value *builder::get_int32(int32_t val) +value *builder::get_int32(uint32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} -value *builder::get_uint32(uint32_t val) -{ return constant_int::get(type::get_uint32_ty(ctx_), val);} - -value *builder::get_int64(int64_t val) +value *builder::get_int64(uint64_t val) { return constant_int::get(type::get_int64_ty(ctx_), val);} -value *builder::get_uint64(uint64_t val) -{ return constant_int::get(type::get_uint64_ty(ctx_), val);} - value *builder::get_float16(float val) { return constant_fp::get(type::get_fp16_ty(ctx_), val); } @@ -90,21 +84,15 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } -type *builder::get_uint8_ty() -{ return type::get_uint8_ty(ctx_); } - -type *builder::get_uint16_ty() -{ return type::get_uint16_ty(ctx_); } - -type *builder::get_uint32_ty() -{ return type::get_uint32_ty(ctx_); } - -type *builder::get_uint64_ty() -{ return type::get_uint64_ty(ctx_); } +type *builder::get_fp8_ty() +{ return type::get_fp8_ty(ctx_); } type *builder::get_half_ty() { return type::get_fp16_ty(ctx_); } +type *builder::get_bf16_ty() +{ return type::get_bf16_ty(ctx_); } + type *builder::get_float_ty() { return type::get_fp32_ty(ctx_); } @@ -139,6 +127,8 @@ value *builder::create_ret_void() { return create_cast(OPCODE, src, dst_ty);\ } +DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast) +DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr) DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) @@ -331,6 
+321,28 @@ value *builder::create_downcast(value *arg) { return insert(downcast_inst::create(arg)); } +// + +value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ + return insert(atomic_rmw_inst::create(op, ptr, val, msk)); +} + +#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\ + value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\ + return create_atomic_rmw(OPCODE, ptr, val, mask);\ + } + +DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max) +DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax) +DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min) +DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin) +DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd) +DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) +DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And) +DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or) +DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor) +DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg) + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// @@ -347,9 +359,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){ return insert(atomic_cas_inst::create(ptr, cmp, val)); } -value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){ - return insert(atomic_rmw_inst::create(op, ptr, val, msk)); -} value *builder::create_exp(value *arg){ return insert(exp_inst::create(arg)); diff --git a/lib/ir/context.cc b/lib/ir/context.cc index 90b109b9b..0fc65ddc2 100644 --- a/lib/ir/context.cc +++ b/lib/ir/context.cc @@ -19,18 +19,12 @@ context_impl::context_impl(context &ctx) fp32_ty(ctx, type::FP32TyID), fp64_ty(ctx, type::FP64TyID), // integers - int1_ty(ctx, 1, signedness::SIGNED), - int8_ty(ctx, 8, signedness::SIGNED), - int16_ty(ctx, 16, signedness::SIGNED), - int32_ty(ctx, 32, signedness::SIGNED), - int64_ty(ctx, 64, signedness::SIGNED), - int128_ty(ctx, 128, signedness::SIGNED), - uint8_ty(ctx, 8, signedness::UNSIGNED), - uint16_ty(ctx, 16, signedness::UNSIGNED), - uint32_ty(ctx, 32, signedness::UNSIGNED), - uint64_ty(ctx, 64, signedness::UNSIGNED){ - -} + int1_ty(ctx, 1), + int8_ty(ctx, 8), + int16_ty(ctx, 16), + int32_ty(ctx, 32), + int64_ty(ctx, 64), + int128_ty(ctx, 128) {} //===----------------------------------------------------------------------===// // context diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc deleted file mode 100644 index 664fbb983..000000000 --- a/lib/ir/dispatch.cc +++ /dev/null @@ -1,882 +0,0 @@ -#include "triton/ir/dispatch.h" - -namespace triton { -namespace ir { - - -[[ noreturn ]] void throw_unreachable(std::string key) { - throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. 
" - "This is likely a bug on our side."); -} - -//===----------------------------------------------------------------------===// -// Programming Model -//===----------------------------------------------------------------------===// - -ir::value *dispatch::program_id(int axis, ir::builder *builder) { - return builder->create_get_program_id(axis); -} - -ir::value *dispatch::num_programs(int axis, ir::builder *builder) { - return builder->create_get_num_programs(axis); -} - -//===----------------------------------------------------------------------===// -// Implicit Casting Utilities -//===----------------------------------------------------------------------===// - -ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){ - int a_rank = a_ty->get_integer_bitwidth(); - int b_rank = b_ty->get_integer_bitwidth(); - auto a_sn = a_ty->get_integer_signedness(); - auto b_sn = b_ty->get_integer_signedness(); - // Rules for signedness taken from "Usual arithmetic conversions" on - // https://en.cppreference.com/w/c/language/conversion. - if (a_sn == b_sn) { - return a_rank > b_rank ? a_ty : b_ty; - } else if (a_sn == signedness::UNSIGNED) { - return a_rank >= b_rank ? a_ty : b_ty; - } else if (b_sn == signedness::UNSIGNED) { - return b_rank >= a_rank ? b_ty : a_ty; - } else { - throw_unreachable("integer_promote"); - } -} - -enum class DivOrMod { NO, YES }; - -ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) { - context &ctx = a_ty->get_context(); - // 1) if one operand is double, the other is implicitly - // converted to double - if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty()) - return type::get_fp64_ty(ctx); - // 2) if one operand is float, the other is implicitly - // converted to float - if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty()) - return type::get_fp32_ty(ctx); - // 3 ) if one operand is half, the other is implicitly converted to half - // unless we're doing / or %, which do not exist natively in PTX for fp16. - if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) { - if (div_or_mod == DivOrMod::YES) { - return type::get_fp32_ty(ctx); - } else { - return type::get_fp16_ty(ctx); - } - } - if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty()) - throw_unreachable("computation_type"); - // 4 ) both operands are integer and undergo - // integer promotion - if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) { - throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. 
Cast them to the same signedness."); - } - return integer_promote(a_ty, b_ty); -} - -//===----------------------------------------------------------------------===// -// Binary Operators -//===----------------------------------------------------------------------===// - -void throw_incompatible_types(ir::type* type_a, ir::type* type_b) { - throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr()); -} - -void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){ - - if(type_a->is_pointer_ty()){ - if(!allow_ptr_a) - throw_incompatible_types(type_a, type_b); - // T* + U* with T != U - if(type_b->is_pointer_ty() && (type_a != type_b)) - throw_incompatible_types(type_a, type_b); - // T* + float - if(type_b->is_floating_point_ty()) - throw_incompatible_types(type_a, type_b); - } -} - -void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder, - bool allow_lhs_ptr = false, bool allow_rhs_ptr = false, - bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) { - // implicit broadcasting - std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder); - // implicit typecasting - ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty(); - ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty(); - check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr); - check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr); - if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) { - ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod); - lhs = dispatch::cast(lhs, ret_sca_ty, builder); - rhs = dispatch::cast(rhs, ret_sca_ty, builder); - } -} - -ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, true, true); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // offset + ptr - // ptr + offset - if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty()) - std::swap(input, other); - if (input_scalar_ty->is_pointer_ty()) - return builder->create_gep(input, {other}); - // float + float - else if (input_scalar_ty->is_floating_point_ty()) - return builder->create_fadd(input, other); - // int + int - else if (input_scalar_ty->is_integer_ty()) - return builder->create_add(input, other); - throw_unreachable("add"); -} - -ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, true, false); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // ptr - offset - if (scalar_ty->is_pointer_ty()) - return builder->create_gep(input, {dispatch::minus(other, builder)}); - // float + float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fsub(input, other); - // int + int - else if (scalar_ty->is_integer_ty()) - return builder->create_sub(input, other); - throw_unreachable("sub"); -} - -ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float * float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fmul(input, other); - // int * int - else if (scalar_ty->is_integer_ty()) - return builder->create_mul(input, other); - throw_unreachable("mul"); -} - -ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) { - 
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // float / int - if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty()) - other = cast(other, input_scalar_ty, builder); - // int / float - else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty()) - input = cast(input, other_scalar_ty, builder); - // int / int (cast to float32) - else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ - input = cast(input, builder->get_float_ty(), builder); - other = cast(other, builder->get_float_ty(), builder); - } - // float / float (cast to highest exponent type) - else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){ - if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width()) - other = cast(other, input_scalar_ty, builder); - else - input = cast(input, other_scalar_ty, builder); - } - // unreachable - else - throw_unreachable("div"); - return builder->create_fdiv(input, other); -} - -ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){ - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){ - ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty); - input = dispatch::cast(input, ret_ty, builder); - other = dispatch::cast(other, ret_ty, builder); - if (ret_ty->is_integer_signed()) { - return builder->create_sdiv(input, other); - } else { - return builder->create_udiv(input, other); - } - } - throw_unreachable("floordiv"); -} - -ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){ - ir::type *input_scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty()) - throw semantic_error("both operands of fdiv must have floating point scalar type"); - binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES); - ir::value* ret = builder->create_fdiv(input, other); - if(ir::binary_operator* binop = dynamic_cast(ret)) - binop->set_fdiv_ieee_rounding(ieee_rounding->get_value()); - return ret; -} - -ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - ir::type *other_scalar_ty = other->get_type()->get_scalar_ty(); - // float % int - if (scalar_ty->is_floating_point_ty()) - return builder->create_frem(input, other); - // int % int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) { - throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. 
Cast them to the same signedness."); - } - if (scalar_ty->is_integer_signed()) { - return builder->create_srem(input, other); - } else { - return builder->create_urem(input, other); - } - } - throw_unreachable("mod"); -} - - -void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) { - binary_op_type_checking(input, other, builder, false, false, false); - ir::type *input_sca_ty = input->get_type()->get_scalar_ty(); - ir::type *other_sca_ty = other->get_type()->get_scalar_ty(); - if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty()) - throw_incompatible_types(input_sca_ty, other_sca_ty); - ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty); - if (ret_sca_ty != input_sca_ty) - input = dispatch::cast(input, ret_sca_ty, builder); - if (ret_sca_ty != other_sca_ty) - other = dispatch::cast(other, ret_sca_ty, builder); -} - -ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_and(input, other); -} - -ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_or(input, other); -} - - -ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_xor(input, other); -} - - -ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_lshr(input, other); -} - - -ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) { - bitwise_op_type_checking(input, other, builder); - return builder->create_shl(input, other); -} - -//===----------------------------------------------------------------------===// -// Unary Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::plus(ir::value *input, ir::builder *) { - return input; -} - -ir::value *dispatch::minus(ir::value *input, ir::builder *builder) { - ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); - if(input_sca_ty->is_pointer_ty()) - throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")"); - ir::value *_0 = ir::constant::get_null_value(input_sca_ty); - return dispatch::sub(_0, input, builder); -} - -ir::value *dispatch::invert(ir::value *input, ir::builder *builder) { - ir::type* input_sca_ty = input->get_type()->get_scalar_ty(); - if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty()) - throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")"); - ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty); - return dispatch::xor_(input, _1, builder); -} - - -//===----------------------------------------------------------------------===// -// Comparison Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float > float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGT(input, other); - // int > int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSGT(input, other); - } else { 
- return builder->create_icmpUGT(input, other); - } - } - throw_unreachable("greater_than"); -} - -ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float >= float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGE(input, other); - // int >= int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSGE(input, other); - } else { - return builder->create_icmpUGE(input, other); - } - } - throw_unreachable("greater_equal"); -} - -ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLT(input, other); - // int < int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSLT(input, other); - } else { - return builder->create_icmpULT(input, other); - } - } - throw_unreachable("less_than"); -} - -ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLE(input, other); - // int < int - else if (scalar_ty->is_integer_ty()) { - if (scalar_ty->is_integer_signed()) { - return builder->create_icmpSLE(input, other); - } else { - return builder->create_icmpULE(input, other); - } - } - throw_unreachable("less_equal"); -} - -ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOEQ(input, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpEQ(input, other); - throw_unreachable("equal"); -} - -ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) { - binary_op_type_checking(input, other, builder); - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpUNE(input, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpNE(input, other); - throw_unreachable("equal"); -} - -//===----------------------------------------------------------------------===// -// Block Creation -//===----------------------------------------------------------------------===// - -ir::value* dispatch::arange(int start, int end, ir::builder *builder) { - return builder->get_range(start, end); -} - -ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) { - ir::value *_0 = ir::constant::get_null_value(dtype); - return builder->create_splat(_0, shape); -} - -//===----------------------------------------------------------------------===// -// Shape Manipulation -//===----------------------------------------------------------------------===// - - -ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) { - unsigned numel = 1; - for(unsigned s: dst_shape) numel *= s; - 
if(input->get_type()->get_tile_num_elements() != numel) - throw semantic_error("cannot reshape block of different shape"); - return builder->create_reshape(input, dst_shape); -} - -ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) { - return builder->create_cat(lhs, rhs); -} - -ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) { - if (!input->get_type()->is_block_ty()) - return builder->create_splat(input, shape); - auto src_shape = input->get_type()->get_block_shapes(); - if (src_shape.size() != shape.size()) - throw std::runtime_error("Cannot broadcast"); - if(shape == src_shape) - return input; - return builder->create_broadcast(input, shape); -} - -std::tuple dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) { - ir::type *lhs_ty = lhs->get_type(); - ir::type *rhs_ty = rhs->get_type(); - - // make_shape_compatible(block, scalar) - if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) - rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); - // make_shape_compatible(scalar, block) - else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) - lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); - // make_shape_compatible(block, block) - else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { - auto lhs_shape = lhs_ty->get_block_shapes(); - auto rhs_shape = rhs_ty->get_block_shapes(); - if (lhs_shape.size() != rhs_shape.size()) - throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); - ir::type::block_shapes_t ret_shape; - for (size_t i = 0; i < lhs_shape.size(); ++i) { - unsigned left = lhs_shape[i]; - unsigned right = rhs_shape[i]; - if (left == 1) - ret_shape.push_back(right); - else if (right == 1) - ret_shape.push_back(left); - else if (left == right) - ret_shape.push_back(left); - else - throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + - ": " + std::to_string(left) + " and " + std::to_string(right)); - } - if (lhs_shape != ret_shape) - lhs = builder->create_broadcast(lhs, ret_shape); - if (rhs_shape != ret_shape) - rhs = builder->create_broadcast(rhs, ret_shape); - } - return std::make_tuple(lhs, rhs); -} - -ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){ - ir::type *src_ty = input->get_type(); - if (src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - if(src_ty == dst_ty) - return input; - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty()) - return cast(input, dst_ty, builder); - // Bitcast - int src_bits = src_sca_ty->get_primitive_size_in_bits(); - int dst_bits = dst_sca_ty->get_primitive_size_in_bits(); - if( src_bits!= dst_bits) - throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) + - "to data-type of size " + std::to_string(dst_bits)); - return builder->create_cast(ir::BitCast, input, dst_ty); -} - -ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) { - ir::type *src_ty = input->get_type(); - if (src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - if(src_ty == dst_ty) - return input; - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - // - if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) || - 
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){ - return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder); - } - // FP Truncation - bool truncate_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); - if (truncate_fp) - return builder->create_fp_trunc(input, dst_ty); - // FP Extension - bool ext_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); - if (ext_fp) - return builder->create_fp_ext(input, dst_ty); - // Int cast - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && - (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() || - src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) { - bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty(); - return builder->create_int_cast(input, dst_ty, sign_extend); - } - // Float -> Int - if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){ - if(dst_sca_ty->is_bool_ty()) - return builder->create_fp_to_ui(input, dst_ty); - else - return builder->create_fp_to_si(input, dst_ty); - } - // int -> Float - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){ - if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed()) - return builder->create_ui_to_fp(input, dst_ty); - else - return builder->create_si_to_fp(input, dst_ty); - } - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){ - int bitwidth = dst_sca_ty->get_integer_bitwidth(); - if(bitwidth == 64) - return builder->create_cast(ir::PtrToInt, input, dst_ty); - if(bitwidth == 1) - return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder), - builder->get_int64(0), - builder); - } - if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::IntToPtr, input, dst_ty); - // Ptr -> Ptr - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::BitCast, input, dst_ty); - // * -> Bool - if (dst_sca_ty->is_bool_ty()) { - if (src_sca_ty->is_pointer_ty()) - input = cast(input, builder->get_int64_ty(), builder); - ir::value *other = builder->get_int64(0); - if (src_ty->is_bool_ty()) - other = builder->create_splat(other, src_ty->get_block_shapes()); - return builder->create_icmpNE(input, other); - } - throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); -} - -//===----------------------------------------------------------------------===// -// Memory Operators -//===----------------------------------------------------------------------===// - -ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) { - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()){ - if(mask) - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - if(other) - other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder); - } - if(other) - other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); - ir::type 
*elt_ty = ptr_ty->get_pointer_element_ty(); - // treat bool* as int8* - if(elt_ty == builder->get_int1_ty()){ - elt_ty = builder->get_int8_ty(); - ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); - ptr = dispatch::cast(ptr, ptr_ty, builder); - } - // cache modifier - load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default - if (!cache_modifier.empty()) { - if (cache_modifier == ".ca") - cache = load_inst::CA; - else if (cache_modifier == ".cg") - cache = load_inst::CG; - else - throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported"); - } - // eviction policy - load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default - if(!eviction_policy.empty()){ - if (eviction_policy == "evict_last") - eviction = load_inst::EVICT_LAST; - else if(eviction_policy == "evict_first") - eviction = load_inst::EVICT_FIRST; - else - throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported"); - } - - - if (!mask && !other) - return builder->create_load(ptr, cache, eviction, is_volatile); - if (!mask) - throw std::runtime_error("`other` cannot be provided without `mask`"); - auto shape = ptr->get_type()->get_block_shapes(); - if(!other){ - other = ir::undef_value::get(elt_ty); - if(ptr->get_type()->is_block_ty()) - other = builder->create_splat(other, ptr->get_type()->get_block_shapes()); - } - return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile); -} - -ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) { - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()) - val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); - if(mask) - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - ir::type *ptr_ty = ptr->get_type()->get_scalar_ty(); - ir::type *elt_ty = ptr_ty->get_pointer_element_ty(); - // treat bool* as int8* - if(elt_ty == builder->get_int1_ty()){ - elt_ty = builder->get_int8_ty(); - ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space()); - ptr = dispatch::cast(ptr, ptr_ty, builder); - } - // cast to target data-type - val = dispatch::cast(val, elt_ty, builder); - if (!mask) - return builder->create_store(ptr, val); - if(!mask->get_type()->get_scalar_ty()->is_bool_ty()) - throw semantic_error("Mask must have boolean scalar type"); - return builder->create_masked_store(ptr, val, mask); -} - -ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){ - return builder->create_atomic_cas(ptr, cmp, val); -} - -void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){ - if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty()) - throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr()); - if(ptr->get_type()->is_block_ty()){ - if(mask){ - mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder); - } - if(val){ - val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder); - } - } - val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder); - if(!mask){ - mask = builder->get_int1(true); - if(ptr->get_type()->is_block_ty()) - mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes()); - } -} - -ir::value 
*dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_max for integers - if(sca_ty->is_integer_ty()) { - if (sca_ty->is_integer_signed()) { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask); - } else { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask); - } - } - // for float - // return atomic_smax(i_ptr, i_val) if val >= 0 - // return atomic_umin(i_ptr, i_val) if val < 0 - ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); - ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); - ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); - ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); - ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder)); - ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder)); - return where(pos, pos_ret, neg_ret, builder); -} - -ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - // direct call to atomic_min for integers - if(sca_ty->is_integer_ty()) { - if (sca_ty->is_integer_signed()) { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask); - } else { - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask); - } - } - // for float - // return atomic_smin(i_ptr, i_val) if val >= 0 - // return atomic_umax(i_ptr, i_val) if val < 0 - ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder); - ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder); - ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder); - ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder); - ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder)); - ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder)); - return where(pos, pos_ret, neg_ret, builder); -} - -ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - auto op = sca_ty->is_floating_point_ty() ? 
ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add; - return builder->create_atomic_rmw(op, ptr, val, mask); -} - -ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask); -} - -ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask); -} - -ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask); -} - -ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){ - atom_red_typechecking(ptr, val, mask, builder); - ir::type* sca_ty = val->get_type()->get_scalar_ty(); - return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask); -} - -//===----------------------------------------------------------------------===// -// Linear Algebra -//===----------------------------------------------------------------------===// - -ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) { - ir::value *_0 = nullptr; - if (lhs->get_type()->is_int_or_tileint_ty()) - _0 = builder->get_int32(0); - else - _0 = builder->get_float32(0); - unsigned M = lhs->get_type()->get_block_shapes()[0]; - unsigned N = rhs->get_type()->get_block_shapes()[1]; - _0 = builder->create_splat(_0, {M, N}); - bool _allow_tf32 = allow_tf32->get_value() != 0; - return builder->create_dot(lhs, rhs, _0, _allow_tf32); -} - - -//===----------------------------------------------------------------------===// -// Indexing -//===----------------------------------------------------------------------===// - -ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){ - condition = dispatch::cast(condition, builder->get_int1_ty(), builder); - if(condition->get_type()->is_block_ty()){ - x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder); - y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder); - } - ir::type* x_ty = x->get_type()->get_scalar_ty(); - ir::type* y_ty = y->get_type()->get_scalar_ty(); - ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO); - x = dispatch::cast(x, ty, builder); - y = dispatch::cast(y, ty, builder); - return builder->create_select(condition, x, y); -} - - -//===----------------------------------------------------------------------===// -// Reductions -//===----------------------------------------------------------------------===// - -ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, - ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - // input is extended to 32-bits if necessary - // this increases numerical accuracy and can be done pretty much for free - // on GPUs - if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32) - input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder); - if (scalar_ty->is_floating_point_ty()) - return builder->create_reduce(input, FLOAT_OP, axis); - else if (scalar_ty->is_integer_ty()) - 
return builder->create_reduce(input, INT_OP, axis); - throw_unreachable(name); -} - -ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); -} - -ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); -} - -ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); -} - -ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - if (!scalar_ty->is_integer_ty()) - throw semantic_error("xor_sum only supported for integers"); - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR); -} - - -//===----------------------------------------------------------------------===// -// Math -//===----------------------------------------------------------------------===// - -ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) { - binary_op_type_checking(x, y, builder); - return builder->insert(umulhi_inst::create(x, y)); -} - -ir::value *dispatch::exp(ir::value *x, ir::builder *builder) { - return builder->create_exp(x); -} - -ir::value *dispatch::log(ir::value *x, ir::builder *builder) { - return builder->create_log(x); -} - -ir::value *dispatch::cos(ir::value *x, ir::builder *builder) { - return builder->create_cos(x); -} - -ir::value *dispatch::sin(ir::value *x, ir::builder *builder) { - return builder->create_sin(x); -} - -ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) { - return builder->create_sqrt(x); -} - - -// - -ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){ - ir::instruction* i = dynamic_cast(x); - if(!i) - throw_unreachable("multiple_of"); - i->set_metadata(ir::metadata::multiple_of, value); - return i; -} - -ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){ - ir::instruction* i = dynamic_cast(x); - if(!i) - throw_unreachable("max_contiguous"); - i->set_metadata(ir::metadata::max_contiguous, value); - return i; -} - -ir::value *dispatch::debug_barrier(ir::builder *builder) { - return builder->create_barrier(); -} - - -} -} diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index c225b315f..39bd945bc 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -312,8 +312,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed, unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast : - (arg_bits > dst_bits ? cast_op_t::Trunc : - (is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); + (arg_bits > dst_bits ? cast_op_t::Trunc : + (is_signed ? 
cast_op_t::SExt : cast_op_t::ZExt))); return create(op, arg, ty, name, next); } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 33b39de3a..a37d3048f 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -9,146 +9,6 @@ namespace triton{ namespace ir{ -/* Module */ -module::module(const std::string &name, builder &builder) - : name_(name), builder_(builder) { - sealed_blocks_.insert(nullptr); -} - -ir::builder& module::get_builder() { - return builder_; -} - -void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){ - values_[val_key_t{name, block}] = value; - auto it = metadatas_.find(name); - if(auto *x = dynamic_cast(value)) - if(it != metadatas_.end()){ - x->set_metadata(it->second.first, it->second.second); - } -// value->set_name(name); -} - -void module::set_value(const std::string& name, ir::value *value){ - return set_value(name, builder_.get_insert_block(), value); -} - -void module::set_const(const std::string& name){ - const_.insert(name); -} - -void module::set_continue_fn(std::function fn) { - continue_fn_ = fn; -} - -std::function module::get_continue_fn() { - return continue_fn_; -} - -ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){ - basic_block::iterator insert = block->get_first_non_phi(); - if(insert != block->end()){ - builder_.set_insert_point(insert); - } - ir::phi_node *res = builder_.create_phi(ty, num_values); - if(insert != block->end()) - builder_.set_insert_point(block); - return res; -} - -ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ - // find non-self references - std::set non_self_ref; - std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()), - [phi](ir::value* op){ return op != phi && op; }); - // non-trivial - if(non_self_ref.size() != 1) - return phi; - // unique value or self-reference - ir::value *same = *non_self_ref.begin(); - assert(same != nullptr); - phi->replace_all_uses_with(same); - phi->erase_from_parent(); - std::set users = phi->get_users(); - for(ir::user* u: users) - if(auto *uphi = dynamic_cast(u)) - if(uphi != phi) - try_remove_trivial_phis(uphi); - return same; -} - - -ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){ - // already initialized - if(phi->get_num_operands()) - return phi; - ir::basic_block *block = phi->get_parent(); - for(ir::basic_block *pred: block->get_predecessors()){ - ir::value *value = get_value(name, pred); - phi->add_incoming(value, pred); - } - return phi; -} - -ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) { - ir::value *result; - bool is_const = const_.find(name) != const_.end(); - auto &preds = block->get_predecessors(); - ir::type *ty = types_.at(name); - if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ - incomplete_phis_[block][name] = make_phi(ty, 1, block); - result = (ir::value*)incomplete_phis_[block][name]; - } - else if(preds.size() <= 1){ - bool has_pred = preds.size(); - result = get_value(name, has_pred?preds.front():nullptr); - } - else{ - ir::phi_node* phi = make_phi(ty, 1, block); - set_value(name, block, phi); - result = add_phi_operands(name, phi); - if(auto *phi = dynamic_cast(result)) - result = try_remove_trivial_phis(phi); - } - if(auto *phi = dynamic_cast(result)){ - result = try_remove_trivial_phis(phi); - } - set_value(name, block, result); - return result; -} - -ir::value *module::get_value(const std::string& name, ir::basic_block *block) { 
- ir::basic_block* save_block = builder_.get_insert_block(); - ir::basic_block::iterator save_pt = builder_.get_insert_point(); - val_key_t key(name, block); - if(values_.find(key) != values_.end()){ - return values_.at(key); - } - ir::value *result = get_value_recursive(name, block); - builder_.set_insert_point(save_block); - if(save_pt != save_block->end()) - builder_.set_insert_point(save_pt); - return result; -} - -ir::value *module::get_value(const std::string& name) { - return get_value(name, builder_.get_insert_block()); -} - -const std::string& module::get_name() { - return name_; -} - -void module::seal_block(ir::basic_block *block){ - for(auto &x: incomplete_phis_[block]){ - add_phi_operands(x.first, x.second); - if(get_value(x.first) == x.second) - set_value(x.first, try_remove_trivial_phis(x.second)); - } - sealed_blocks_.insert(block); - incomplete_phis_[block].clear(); -} - /* functions */ function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *&fn = (function*&)symbols_[name]; diff --git a/lib/ir/type.cc b/lib/ir/type.cc index 7e4e4e5d7..056ae99e6 100644 --- a/lib/ir/type.cc +++ b/lib/ir/type.cc @@ -36,16 +36,6 @@ unsigned type::get_primitive_size_in_bits() const { unsigned type::get_integer_bitwidth() const { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } -signedness type::get_integer_signedness() const -{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); } - -bool type::is_integer_signed() const { - if (id_ != IntegerTyID) { - throw std::logic_error("type is " + repr() + ", not integer"); - } - return ((integer_type*)(this))->get_signedness() == signedness::SIGNED; -} - unsigned type::get_tile_bitwidth() const { return ((block_type*)(this))->get_bitwidth(); } @@ -145,10 +135,6 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } -integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; } -integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; } -integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; } -integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; } diff --git a/python/src/triton.cc b/python/src/triton.cc index 9e53cc341..b66761ec3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -3,7 +3,6 @@ #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" -#include "triton/ir/dispatch.h" #include "triton/ir/enums.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -12,10 +11,12 @@ #include #include #include +#include #include #include "Python.h" #include #include +#include #include #include "llvm/IR/Module.h" #include "llvm/IR/LegacyPassManager.h" @@ -541,84 +542,6 @@ void init_triton_codegen(py::module &&m) { }, py::return_value_policy::take_ownership); } -/*****************************************************************************/ -/* User-facing language features */ -/*****************************************************************************/ - -void init_triton_frontend(py::module &&m) { - using ret = py::return_value_policy; - - // programming model - m.def("program_id", &ir::dispatch::program_id, ret::reference); - 
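
The dispatch entries deleted here re-appear as plain Python functions in the new python/triton/language/semantic.py. As a hedged sketch of the pattern, not a verbatim excerpt from that file: since the type.cc change above removes signedness from the IR itself, the frontend dtype now decides between signed and unsigned instructions at lowering time. The function name and exact checks below are illustrative assumptions; only the builder bindings and the tensor/dtype helpers are taken from elsewhere in this patch.

    # Hedged sketch of the semantic.py lowering pattern (names assumed).
    import triton.language

    def integer_div(input, other, builder):
        # input, other: triton.language.tensor values with the same integer
        # dtype; the IR is now signless, so the frontend dtype decides.
        if input.dtype.is_int_signed():
            handle = builder.create_sdiv(input.handle, other.handle)
        else:
            handle = builder.create_udiv(input.handle, other.handle)
        return triton.language.tensor(handle, input.dtype)
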
m.def("num_programs", &ir::dispatch::num_programs, ret::reference); - // binary - m.def("add", &ir::dispatch::add, ret::reference); - m.def("sub", &ir::dispatch::sub, ret::reference); - m.def("mul", &ir::dispatch::mul, ret::reference); - m.def("truediv", &ir::dispatch::truediv, ret::reference); - m.def("floordiv", &ir::dispatch::floordiv, ret::reference); - m.def("fdiv", &ir::dispatch::fdiv, ret::reference); - m.def("mod", &ir::dispatch::mod, ret::reference); - m.def("and_", &ir::dispatch::and_, ret::reference); - m.def("or_", &ir::dispatch::or_, ret::reference); - m.def("xor_", &ir::dispatch::xor_, ret::reference); - m.def("lshr", &ir::dispatch::lshr, ret::reference); - m.def("shl", &ir::dispatch::shl, ret::reference); - // unary - m.def("plus", &ir::dispatch::plus, ret::reference); - m.def("minus", &ir::dispatch::minus, ret::reference); - m.def("invert", &ir::dispatch::invert, ret::reference); - // comparison - m.def("greater_than", &ir::dispatch::greater_than, ret::reference); - m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference); - m.def("less_than", &ir::dispatch::less_than, ret::reference); - m.def("less_equal", &ir::dispatch::less_equal, ret::reference); - m.def("equal", &ir::dispatch::equal, ret::reference); - m.def("not_equal", &ir::dispatch::not_equal, ret::reference); - // block creation - m.def("arange", &ir::dispatch::arange, ret::reference); - m.def("zeros", &ir::dispatch::zeros, ret::reference); - // type manipuatation - m.def("cat", &ir::dispatch::cat, ret::reference); - m.def("reshape", &ir::dispatch::reshape, ret::reference); - typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); - typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); - m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference); - m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference); - m.def("bitcast", &ir::dispatch::bitcast, ret::reference); - m.def("cast", &ir::dispatch::cast, ret::reference); - // memory - m.def("load", &ir::dispatch::load, ret::reference); - m.def("store", &ir::dispatch::store, ret::reference); - m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference); - m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference); - m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference); - m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference); - m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference); - m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference); - m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference); - m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference); - // linear algebra - m.def("dot", &ir::dispatch::dot, ret::reference); - // indexing - m.def("where", &ir::dispatch::where, ret::reference); - // reduction - m.def("min", &ir::dispatch::min, ret::reference); - m.def("max", &ir::dispatch::max, ret::reference); - m.def("sum", &ir::dispatch::sum, ret::reference); - m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference); - // math - m.def("umulhi", &ir::dispatch::umulhi, ret::reference); - m.def("exp", &ir::dispatch::exp, ret::reference); - m.def("log", &ir::dispatch::log, ret::reference); - m.def("cos", &ir::dispatch::cos, ret::reference); - m.def("sin", &ir::dispatch::sin, ret::reference); - m.def("sqrt", &ir::dispatch::sqrt, ret::reference); - // internal (debugging only) - m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference); - m.def("max_contiguous", &ir::dispatch::max_contiguous, 
ret::reference); - m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference); -} /*****************************************************************************/ /* Python bindings for triton::ir */ @@ -628,16 +551,86 @@ void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; + py::enum_(m, "CACHE_MODIFIER") + .value("NONE", ir::load_inst::NONE) + .value("CA", ir::load_inst::CA) + .value("CG", ir::load_inst::CG) + .export_values(); + + py::enum_(m, "EVICTION_POLICY") + .value("NORMAL", ir::load_inst::NORMAL) + .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST) + .value("EVICT_LAST", ir::load_inst::EVICT_LAST) + .export_values(); + + py::enum_(m, "REDUCE_OP") + .value("ADD", ir::reduce_inst::ADD) + .value("FADD", ir::reduce_inst::FADD) + .value("MIN", ir::reduce_inst::MIN) + .value("MAX", ir::reduce_inst::MAX) + .value("FMIN", ir::reduce_inst::FMIN) + .value("FMAX", ir::reduce_inst::FMAX) + .value("XOR", ir::reduce_inst::XOR); + + py::enum_(m, "ATOMIC_OP") + .value("ADD", ir::atomic_rmw_op_t::Add) + .value("FADD", ir::atomic_rmw_op_t::FAdd) + .value("AND", ir::atomic_rmw_op_t::And) + .value("OR", ir::atomic_rmw_op_t::Or) + .value("XOR", ir::atomic_rmw_op_t::Xor) + .value("XCHG", ir::atomic_rmw_op_t::Xchg) + .value("MAX", ir::atomic_rmw_op_t::Max) + .value("MIN", ir::atomic_rmw_op_t::Min) + .value("UMIN", ir::atomic_rmw_op_t::UMin) + .value("UMAX", ir::atomic_rmw_op_t::UMax); + py::class_(m, "context") .def(py::init<>()); - auto value = py::class_(m, "value"); - value.def_property("name", &ir::value::get_name, &ir::value::set_name); - value.def_property_readonly("type", &ir::value::get_type); + py::class_(m, "value") + .def("multiple_of", [](ir::value *self, int val) { + if (auto *instr = dynamic_cast(self)) { + instr->set_metadata(ir::metadata::multiple_of, val); + } else + throw std::runtime_error("multiple_of"); + }) + .def("max_contiguous", [](ir::value *self, int val) { + if (auto *instr = dynamic_cast(self)) { + instr->set_metadata(ir::metadata::max_contiguous, val); + } else + throw std::runtime_error("max_contiguous"); + }) + .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) { + if (auto *instr = dynamic_cast(self)) + instr->set_fdiv_ieee_rounding(val); + else + throw std::runtime_error("set_fdiv_ieee_rounding"); + }) + .def("is_phi", [](ir::value *self) { + if (auto *pn = dynamic_cast(self)) + return true; + return false; + }) + .def("ops", [](ir::value *self) { + if (auto *instr = dynamic_cast(self)) { + return instr->ops(); + } + throw std::runtime_error("cannot use ops()"); + }) + .def("replace_all_uses_with", &ir::value::replace_all_uses_with) + .def("erase_from_parent", [](ir::value *self) { + if (auto *instr = dynamic_cast(self)) + return instr->erase_from_parent(); + throw std::runtime_error("cannot use erase_from_parent"); + }) + .def_property("name", &ir::value::get_name, &ir::value::set_name) + .def_property_readonly("type", &ir::value::get_type); py::class_(m, "user"); - py::class_(m, "constant"); + py::class_(m, "constant") + .def("get_null_value", &ir::constant::get_null_value, ret::reference) + .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference); py::class_(m, "undef") .def("get", &ir::undef_value::get, ret::reference); @@ -648,16 +641,17 @@ void init_triton_ir(py::module &&m) { .def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); py::class_(m, "constant_float") - .def_property_readonly("value", &ir::constant_fp::get_value); + 
.def_property_readonly("value", &ir::constant_fp::get_value) + .def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference); - py::class_(m, "instruction"); - py::class_(m, "phi_node"); + py::class_(m, "instruction") + .def("get_parent", [](ir::instruction *self) { + return self->get_parent(); + }, ret::reference); + py::class_(m, "phi_node") + .def("add_incoming", &ir::phi_node::add_incoming); py::class_(m, "type") - .def("is_ptr", &ir::type::is_pointer_ty) - .def("is_int", static_cast(&ir::type::is_integer_ty)) - .def("is_floating", &ir::type::is_floating_point_ty) - .def("is_block", &ir::type::is_block_ty) .def("make_ptr", &ir::pointer_type::get, ret::reference) .def("make_function", &ir::function_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference) @@ -672,34 +666,38 @@ void init_triton_ir(py::module &&m) { .def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference) - .def("get_uint8", &ir::type::get_uint8_ty, ret::reference) - .def("get_uint16", &ir::type::get_uint16_ty, ret::reference) - .def("get_uint32", &ir::type::get_uint32_ty, ret::reference) - .def("get_uint64", &ir::type::get_uint64_ty, ret::reference) + .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference) + .def("get_block_shapes", &ir::type::get_block_shapes) + + .def("is_ptr", &ir::type::is_pointer_ty) + .def("is_int", static_cast(&ir::type::is_integer_ty)) + .def("is_floating", &ir::type::is_floating_point_ty) + .def("is_block", &ir::type::is_block_ty) .def("is_void", &ir::type::is_void_ty) + .def("is_bool", &ir::type::is_bool_ty) .def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp16", &ir::type::is_fp16_ty) .def("is_bf16", &ir::type::is_bf16_ty) .def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp64", &ir::type::is_fp64_ty) - .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) - .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) - .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) - .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) - .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) - .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) - .def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); }) - .def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); }) - .def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); }) + .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) + .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) + .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) + .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) + .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) + .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty) .def("repr", &ir::type::repr) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("scalar", &ir::type::get_scalar_ty) - .def_property_readonly("context", &ir::type::get_context, ret::reference); + .def_property_readonly("context", 
&ir::type::get_context, ret::reference) + .def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth) + .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits); py::class_(m, "pointer_type") - .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); + .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) + .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference); py::class_(m, "function_type"); py::class_(m, "integer_type"); @@ -709,16 +707,15 @@ void init_triton_ir(py::module &&m) { py::class_(m, "module") .def(py::init()) - .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) - .def("seal_block", &ir::module::seal_block) - .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) - .def("set_type", &ir::module::set_type) - .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) - .def("get_values", &ir::module::get_values, ret::reference) - .def("set_values", &ir::module::set_values) - .def("get_types", &ir::module::get_types, ret::reference) - .def("set_types", &ir::module::set_types) - .def_property_readonly("builder", &ir::module::get_builder, ret::reference); + .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { + const auto metadatas = self->get_metadatas(); + auto it = metadatas.find(name); + if (it != metadatas.end()) + if (auto *instr = dynamic_cast(value)) { + instr->set_metadata(it->second.first, it->second.second); + } + }) + .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); using eattr = ir::attribute_kind_t; py::enum_(m, "attribute_kind") @@ -742,6 +739,13 @@ void init_triton_ir(py::module &&m) { py::class_(m, "basic_block") .def("create", &ir::basic_block::create, ret::reference) + .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) + .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* { + ir::basic_block::iterator it = self->get_first_non_phi(); + if (it == self->end()) + return nullptr; + return *it; + }, ret::reference) .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); py::class_(m, "builder", py::dynamic_attr()) @@ -752,17 +756,162 @@ void init_triton_ir(py::module &&m) { .def("br", &ir::builder::create_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference) + // insertion block/point, insert points are represented as (*bb, *instr) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) - // constants + .def("get_insert_point", [](ir::builder *self) { + ir::basic_block *bb = self->get_insert_block(); + ir::basic_block::iterator it = self->get_insert_point(); + ir::instruction *instr = it == bb->end() ? 
nullptr : *it; + return std::make_pair(bb, instr); + }, ret::reference) + .def("set_insert_point", [](ir::builder *self, std::pair pt) { + ir::basic_block *bb = pt.first; + ir::instruction *instr = pt.second; + if (instr) { + if (bb != instr->get_parent()) + throw std::runtime_error("invalid insertion point, instr not in bb"); + self->set_insert_point(instr); + } else { + assert(bb); + self->set_insert_point(bb); + } + }) + // Constants .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", &ir::builder::get_int32, ret::reference) - .def("get_int64", &ir::builder::get_int64, ret::reference) - .def("get_uint32", &ir::builder::get_uint32, ret::reference) - .def("get_uint64", &ir::builder::get_uint64, ret::reference) + .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) + .def("get_uint32", &ir::builder::get_int32, ret::reference) + .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) + .def("get_uint64", &ir::builder::get_int64, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference) - .def("get_range", &ir::builder::get_range, ret::reference); + .def("get_range", &ir::builder::get_range, ret::reference) + // Types + .def("get_void_ty", &ir::builder::get_void_ty, ret::reference) + .def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference) + .def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference) + .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference) + .def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference) + .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference) + .def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference) + .def("get_half_ty", &ir::builder::get_half_ty, ret::reference) + .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference) + .def("get_float_ty", &ir::builder::get_float_ty, ret::reference) + .def("get_double_ty", &ir::builder::get_double_ty, ret::reference) + // terminator instructions + .def("create_br", &ir::builder::create_br, ret::reference) + .def("create_cond_br", &ir::builder::create_cond_br, ret::reference) + .def("create_ret_void", &ir::builder::create_ret_void, ret::reference) + // Cast instructions + .def("create_bitcast", &ir::builder::create_bitcast, ret::reference) + .def("create_cast", &ir::builder::create_cast, ret::reference) + .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference) + .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference) + .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference) + .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference) + .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference) + .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference) + .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference) + .def("create_int_cast", &ir::builder::create_int_cast, ret::reference) + .def("create_downcast", &ir::builder::create_downcast, ret::reference) + // phi + .def("create_phi", &ir::builder::create_phi, ret::reference) + // Binary instructions + .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference) + .def("create_fmul", &ir::builder::create_fmul, ret::reference) + .def("create_fdiv", &ir::builder::create_fdiv, ret::reference) + .def("create_frem", &ir::builder::create_frem, ret::reference) + .def("create_fadd", 
&ir::builder::create_fadd, ret::reference) + .def("create_fsub", &ir::builder::create_fsub, ret::reference) + .def("create_mul", &ir::builder::create_mul, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_sdiv", &ir::builder::create_sdiv, ret::reference) + .def("create_udiv", &ir::builder::create_udiv, ret::reference) + .def("create_srem", &ir::builder::create_srem, ret::reference) + .def("create_urem", &ir::builder::create_urem, ret::reference) + .def("create_add", &ir::builder::create_add, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_sub", &ir::builder::create_sub, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_shl", &ir::builder::create_shl, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_lshr", &ir::builder::create_lshr, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_ashr", &ir::builder::create_ashr, ret::reference, + py::arg("lhs"), py::arg("rhs"), + py::arg("has_nuw")=false, py::arg("has_nsw")=false) + // GEP + .def("create_gep", &ir::builder::create_gep, ret::reference) + // Comparison (int) + .def("create_icmp", &ir::builder::create_icmp, ret::reference) + .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference) + .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference) + .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference) + .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference) + .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference) + .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference) + .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference) + .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference) + .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference) + .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference) + // Comparison (float) + .def("create_fcmp", &ir::builder::create_fcmp, ret::reference) + .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference) + .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference) + .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference) + .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference) + .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference) + .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference) + .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference) + .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference) + .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference) + .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference) + .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference) + .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference) + // Logical + .def("create_and", &ir::builder::create_and, ret::reference) + .def("create_xor", &ir::builder::create_xor, ret::reference) + .def("create_or", &ir::builder::create_or, ret::reference) + // Input/Output + .def("create_load", &ir::builder::create_load, ret::reference) + .def("create_store", &ir::builder::create_store, ret::reference) + .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) + .def("create_masked_store", 
&ir::builder::create_masked_store, ret::reference) + // Block instruction + .def("create_splat", &ir::builder::create_splat, ret::reference) + .def("create_reshape", &ir::builder::create_reshape, ret::reference) + .def("create_cat", &ir::builder::create_cat, ret::reference) + .def("create_broadcast", &ir::builder::create_broadcast, ret::reference) + // atomic + .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) + .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference) + + // Built-in instruction + .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) + .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) + .def("create_exp", &ir::builder::create_exp, ret::reference) + .def("create_cos", &ir::builder::create_cos, ret::reference) + .def("create_sin", &ir::builder::create_sin, ret::reference) + .def("create_log", &ir::builder::create_log, ret::reference) + .def("create_dot", &ir::builder::create_dot, ret::reference) + .def("create_trans", &ir::builder::create_trans, ret::reference) + .def("create_sqrt", &ir::builder::create_sqrt, ret::reference) + .def("create_reduce", &ir::builder::create_reduce, ret::reference) + .def("create_select", &ir::builder::create_select, ret::reference) + // Intrinsics + // These have no place in the IR, and hopefully they can be removed at some point + .def("create_umulhi", &ir::builder::create_umulhi, ret::reference) + .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference) + .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference) + .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference) + .def("create_barrier", &ir::builder::create_barrier, ret::reference) + .def("create_async_wait", &ir::builder::create_async_wait, ret::reference) + .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference); } void init_triton(py::module &m) { @@ -770,5 +919,4 @@ void init_triton(py::module &m) { init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_ir(std::move(subm.def_submodule("ir"))); - init_triton_frontend(std::move(subm.def_submodule("frontend"))); } diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index 1df3a0b49..f30b203bb 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -37,7 +37,7 @@ matmul_data = { (256, 256, 256): {'float16': 0.027}, (512, 512, 512): {'float16': 0.158}, (1024, 1024, 1024): {'float16': 0.466}, - (2048, 2048, 2048): {'float16': 0.680}, + (2048, 2048, 2048): {'float16': 0.695}, (4096, 4096, 4096): {'float16': 0.831}, (8192, 8192, 8192): {'float16': 0.849}, # tall-skinny diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a49b47585..3561f7af4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,5 +1,4 @@ # flake8: noqa: F821,F841 -import copy import itertools import re from typing import Optional, Union @@ -585,7 +584,6 @@ def test_f8_f16_roundtrip(): f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - print(f16.dtype, f8_output.dtype) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) assert torch.all(f8_tensor == f8_output_tensor) @@ -993,27 +991,6 @@ def 
test_noop(device='cuda'): kernel[(1, )](x) -@pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), - (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), - (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') -]) -def test_value_specialization(value: int, value_type: str, device='cuda') -> None: - - @triton.jit - def kernel(VALUE, X): - pass - - x = torch.tensor([3.14159], device='cuda') - pgm = kernel[(1, )](value, x) - - # Parse out the type of the 'VALUE' parameter from the Triton IR. - triton_ir = pgm.asm['ttir'] - ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir) - ir_value_type = None if ir_value_match is None else ir_value_match.group(1) - assert ir_value_type == value_type - - @pytest.mark.parametrize( "value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 8ac01bcc8..d866d6983 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,4 +1,5 @@ import os +import re import shutil import pytest @@ -102,3 +103,30 @@ def test_specialize(mode): for i in [1, 2, 4, 8, 16, 32]: function[(1,)](x, i, BLOCK=512) assert counter == target + + +@pytest.mark.parametrize("value, value_type", [ + (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), + (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), + (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64') +]) +def test_value_specialization(value: int, value_type: str, device='cuda') -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + cache_str = None + + def get_cache_str(*args, **kwargs): + nonlocal cache_str + cache_str = kwargs['key'].split('-') + triton.code_gen.JITFunction.cache_hook = get_cache_str + reset_tmp_dir() + x = torch.tensor([3.14159], device='cuda') + kernel[(1, )](value, x) + triton.code_gen.JITFunction.cache_hook = None + + cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) + spec_type = None if cache_str_match is None else cache_str_match.group(1) + assert spec_type == value_type diff --git a/python/triton/__init__.py b/python/triton/__init__.py index f9982939c..37ba46efc 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,7 +6,8 @@ __version__ = '2.0.0' # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret +from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ + JITFunction, Config, Autotuner, reinterpret from . import language from . import code_gen from . 
import testing diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 09254c967..a253e2c4c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast import builtins import functools @@ -11,7 +13,7 @@ import tempfile import textwrap import time import warnings -from typing import Dict +from typing import Dict, Optional, Set, Tuple, Union import torch from filelock import FileLock @@ -22,48 +24,13 @@ from .tools.disasm import extract class CodeGenerator(ast.NodeVisitor): - def get_value(self, name): - # search node.id in local scope - ret = None - if name in self.lscope: - ret = self.lscope[name] - # search node.id in global scope - elif name in self.gscope: - ret = self.gscope[name] - # search node.id in builtins - elif name in self.builtins: - ret = self.builtins[name] - else: - raise ValueError(f'{name} is not defined') - if isinstance(ret, triton.language.block): - handle = self.module.get_value(name) - return triton.language.block(handle) - return ret - - def set_value(self, name, value): - if isinstance(value, _triton.ir.value): - value = triton.language.block(value) - if isinstance(value, triton.language.block): - self.module.set_value(name, value.handle) - self.module.set_type(name, value.handle.type) - self.lscope[name] = value - - def is_triton_object(self, value): - return isinstance(value, triton.language.block) - - def visit_compound_statement(self, stmts): - for stmt in stmts: - self.last_ret = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) - def __init__(self, context, prototype, gscope, attributes, constants, kwargs): self.builder = _triton.ir.builder(context) self.module = _triton.ir.module('', self.builder) self.prototype = prototype self.gscope = gscope self.lscope = dict() + self.is_arg_lscope = dict() # name => is_arg: {str: bool} self.attributes = attributes self.constants = constants self.kwargs = kwargs @@ -77,6 +44,146 @@ class CodeGenerator(ast.NodeVisitor): 'isinstance': isinstance, 'getattr': getattr, } + # SSA-construction + # [name, bb] => triton.language.tensor + self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {} + # bb => {name => phi} + self.incomplete_phis = {} + self.sealed_blocks: Set[_triton.ir.basic_block] = set() + + def get_value(self, name): + ''' This function: + 1. make sure `name` is defined + 2. if `name` is triton.language.tensor, get stored tensor by calling + `self._get_tensor()` + ''' + # search node.id in local scope + ret = None + if name in self.lscope: + ret = self.lscope[name] + # search node.id in global scope + elif name in self.gscope: + ret = self.gscope[name] + # search node.id in builtins + elif name in self.builtins: + ret = self.builtins[name] + else: + raise ValueError(f'{name} is not defined') + if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]: + return self._get_tensor(name) + return ret + + def set_value(self, name: str, + value: Union[triton.language.tensor, triton.language.constexpr], + is_arg: bool = False) -> None: + ''' This function: + called by visit_Assign() & visit_FuncDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. 
store the tensor in self.lvalues
+        '''
+        self.lscope[name] = value
+        # if this value is an argument, we don't need to create phis for it
+        self.is_arg_lscope[name] = is_arg
+        if isinstance(value, triton.language.tensor) and not is_arg:
+            self._set_value(name, self.builder.get_insert_block(), value)
+
+    #
+    # SSA-construction
+    #
+    def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
+        if not bb:
+            bb = self.builder.get_insert_block()
+        # local value numbering
+        if (name, bb) in self.lvalues:
+            return self.lvalues[(name, bb)]
+        # global value numbering
+        saved_insert_point = self.builder.get_insert_point()
+        result = self._get_tensor_recursive(name, bb)
+        self.builder.set_insert_point(saved_insert_point)
+        return result
+
+    def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
+        preds = bb.get_predecessors()
+        type = self.lscope[name].type
+        # some preds haven't been filled, create a phi as a proxy of the value
+        if bb not in self.sealed_blocks:
+            result = self._make_phi(type, len(preds), bb)
+            if bb in self.incomplete_phis:
+                self.incomplete_phis[bb][name] = result
+            else:
+                self.incomplete_phis[bb] = {name: result}
+        elif len(preds) == 1:
+            # one predecessor: no phi needed, try to get the value from the pred
+            result = self._get_tensor(name, preds[0])
+        else:  # multiple preds
+            assert len(preds) > 1, f'{name} is an undefined name (cannot be found in the entry block)'
+            phi = self._make_phi(type, len(preds), bb)
+            self._set_value(name, bb, phi)
+            result = self._add_phi_operands(name, phi)
+        self._set_value(name, bb, result)
+        return result
+
+    # returns a new phi tensor, which encapsulates an ir.phi_node
+    def _make_phi(self,
+                  type: triton.language.dtype,
+                  num_values: int,
+                  bb: _triton.ir.basic_block) -> triton.language.tensor:
+        instr = bb.get_first_non_phi()
+        self.builder.set_insert_point((bb, instr))
+        ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
+        if instr:
+            self.builder.set_insert_block(bb)
+        return triton.language.tensor(ir_phi, type)
+
+    # complete a phi node. (TODO: rename this as _complete_phis?)
+    # Note: since we try to remove trivial phis, the returned tensor might not be a phi
+    def _add_phi_operands(self, name: str,
+                          phi: triton.language.tensor) -> triton.language.tensor:
+        bb = phi.handle.get_parent()
+        for pred in bb.get_predecessors():
+            v = self._get_tensor(name, pred)
+            phi.handle.add_incoming(v.handle, pred)
+        phi = self._try_remove_trivial_phi(phi)
+        return phi
+
+    def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
+        self.lvalues[(name, bb)] = value
+        # TODO: why do we need this?
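
    # (A likely answer to the TODO above, inferred from the new
    #  module.set_instr_metadata binding added in python/src/triton.cc:
    #  the module keeps a name -> metadata map populated by
    #  multiple_of/max_contiguous-style hints, and re-applying it on every
    #  rebinding keeps that metadata attached to whichever instruction
    #  currently defines `name`.)
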
+ self.module.set_instr_metadata(name, value.handle) + + def _seal_block(self, bb: _triton.ir.basic_block): + # complete all incomplete phis + if bb in self.incomplete_phis: + for name, phi in self.incomplete_phis[bb].items(): + result = self._add_phi_operands(name, phi) + # it's possible that this phi is trivial + if self._get_tensor(name, bb).handle == phi.handle: + self._set_value(name, bb, result) + del self.incomplete_phis[bb] + self.sealed_blocks.add(bb) + + def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor: + unique_handles = {op for op in phi.handle.ops() if op != phi.handle} + if len(unique_handles) != 1: # non-trivial phi + return phi + v = unique_handles.pop() + phi.handle.replace_all_uses_with(v) + phi.handle.erase_from_parent() + # TODO: remove trivial phis recursively + return triton.language.tensor(v, phi.type) + + def is_triton_tensor(self, value): + return isinstance(value, triton.language.tensor) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + for stmt in stmts: + self.last_ret = self.visit(stmt) + if isinstance(stmt, ast.Return): + break + return stmts and isinstance(stmt, ast.Return) def visit_Module(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -113,7 +220,7 @@ class CodeGenerator(ast.NodeVisitor): if inline: pass else: - fn = self.module.get_or_insert_function(node.name, self.prototype) + fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder)) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -130,17 +237,17 @@ class CodeGenerator(ast.NodeVisitor): attr = _triton.ir.attribute(attr, self.attributes[i]) fn.add_attr(idx + 1, attr) fn.args[idx].name = arg_name - arg_values.append(fn.args[idx]) + arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) idx += 1 for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) + self.set_value(arg_name, arg_value, is_arg=True) if inline: self.visit_compound_statement(node.body) return self.last_ret else: entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) - self.module.seal_block(entry) + self._seal_block(entry) self.builder.set_insert_block(entry) # visit function body self.visit_compound_statement(node.body) @@ -187,11 +294,12 @@ class CodeGenerator(ast.NodeVisitor): if not isinstance(values, tuple): values = [values] for name, value in zip(names, values): + # TODO: can we store constexpr here to support constant folding? 
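
As an illustration of this assignment path, consider a hypothetical kernel (not part of this patch); the literal-range rules come from _to_tensor in language/core.py further down in this diff:

    import triton
    import triton.language as tl

    @triton.jit
    def assign_demo(out_ptr):
        x = 4                # visited as constexpr(4), unwrapped, then
                             # materialized via builder.get_int32 -> int32 tensor
        y = 2 ** 31          # unused, shown only for the range rule:
                             # falls in [2**31, 2**32), so it becomes uint32
        z = tl.arange(0, 8)  # already a tensor; registered for SSA tracking as-is
        tl.store(out_ptr + z, z + x)
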
# by default, constexpr are assigned into python variable if isinstance(value, triton.language.constexpr): value = value.value - if not isinstance(value, triton.language.block): - value = triton.language.core._to_ir(value, self.builder) + if not isinstance(value, triton.language.tensor): + value = triton.language.core._to_tensor(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -220,9 +328,9 @@ class CodeGenerator(ast.NodeVisitor): def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) - if isinstance(lhs, triton.language.core.constexpr): + if isinstance(lhs, triton.language.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.core.constexpr): + if isinstance(rhs, triton.language.constexpr): rhs = rhs.value fn = { ast.Add: '__add__', @@ -238,9 +346,9 @@ class CodeGenerator(ast.NodeVisitor): ast.BitOr: '__or__', ast.BitXor: '__xor__', }[type(node.op)] - if self.is_triton_object(lhs): + if self.is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_object(rhs): + elif self.is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -248,15 +356,15 @@ class CodeGenerator(ast.NodeVisitor): def visit_If(self, node): cond = self.visit(node.test) - if isinstance(cond, triton.language.block): + if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) current_bb = self.builder.get_insert_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - self.module.seal_block(then_bb) + self._seal_block(then_bb) if else_bb: - self.module.seal_block(else_bb) + self._seal_block(else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb) else: self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -271,7 +379,7 @@ class CodeGenerator(ast.NodeVisitor): # TODO: last statement is a terminator? 
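
A hypothetical kernel (not from this patch) showing the control flow this visitor builds: each branch gets its own basic block, the "then" and "else" blocks are sealed immediately, both branch into "endif", and "endif" is sealed only once both terminators are in place, so a read after the conditional can complete its phi:

    import triton
    import triton.language as tl

    @triton.jit
    def select_demo(out_ptr, flag):
        if flag:              # cond_br into the "then" / "else" blocks
            x = 1.0
        else:
            x = 2.0
        tl.store(out_ptr, x)  # in "endif": `x` resolves to a phi over both branches
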
if not is_terminator: self.builder.br(endif_bb) - self.module.seal_block(endif_bb) + self._seal_block(endif_bb) self.builder.set_insert_block(endif_bb) else: if isinstance(cond, triton.language.constexpr): @@ -296,9 +404,9 @@ class CodeGenerator(ast.NodeVisitor): assert len(node.ops) == 1 lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.core.constexpr): + if isinstance(lhs, triton.language.constexpr): lhs = lhs.value - if isinstance(rhs, triton.language.core.constexpr): + if isinstance(rhs, triton.language.constexpr): rhs = rhs.value if type(node.ops[0]) == ast.Is: return triton.language.constexpr(lhs is rhs) @@ -312,9 +420,9 @@ class CodeGenerator(ast.NodeVisitor): ast.Gt: '__gt__', ast.GtE: '__ge__', }[type(node.ops[0])] - if self.is_triton_object(lhs): + if self.is_triton_tensor(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_object(rhs): + elif self.is_triton_tensor(rhs): fn = fn[:2] + 'r' + fn[2:] return getattr(rhs, fn)(lhs, _builder=self.builder) else: @@ -325,21 +433,21 @@ class CodeGenerator(ast.NodeVisitor): if type(node.op) == ast.Not: assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" return triton.language.constexpr(not op) - if isinstance(op, triton.language.core.constexpr): + if isinstance(op, triton.language.constexpr): op = op.value fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Invert: '__invert__', }[type(node.op)] - if self.is_triton_object(op): + if self.is_triton_tensor(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() def visit_While(self, node): current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) def continue_fn(): cond = self.visit(node.test) @@ -350,9 +458,9 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self._seal_block(stop_bb) + self._seal_block(loop_bb) + self._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -362,7 +470,7 @@ class CodeGenerator(ast.NodeVisitor): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) - if self.is_triton_object(lhs): + if self.is_triton_tensor(lhs): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] @@ -405,8 +513,8 @@ class CodeGenerator(ast.NodeVisitor): step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() - loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) - next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent) + loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) + next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) def continue_fn(): self.visit(step_node) @@ -421,9 +529,9 @@ class CodeGenerator(ast.NodeVisitor): # TODO: handle case where 
body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() - self.module.seal_block(stop_bb) - self.module.seal_block(loop_bb) - self.module.seal_block(next_bb) + self._seal_block(stop_bb) + self._seal_block(loop_bb) + self._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -451,7 +559,7 @@ class CodeGenerator(ast.NodeVisitor): args = [self.visit(arg) for arg in node.args] if isinstance(fn, JITFunction): return fn(*args, generator=self, **kws) - if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \ + if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ sys.modules[fn.__module__] is triton.language.core: return fn(*args, _builder=self.builder, **kws) if fn in self.builtins.values(): @@ -581,7 +689,7 @@ class Kernel: } if hasattr(obj, 'data_ptr'): return type_names[obj.dtype] - if isinstance(obj, triton.language.core.constexpr): + if isinstance(obj, triton.language.constexpr): obj = obj.value if isinstance(obj, int): if -2**31 <= obj < 2**31: @@ -613,34 +721,34 @@ class Kernel: return 'scalar', name @staticmethod - def _to_triton_ir(context, obj): + def _to_triton_ir(obj): which, name = obj type_map = { - 'I': _triton.ir.type.get_int32, - 'L': _triton.ir.type.get_int64, - 'f': _triton.ir.type.get_fp32, - 'B': _triton.ir.type.get_int1, - 'f8': _triton.ir.type.get_fp8, - 'f16': _triton.ir.type.get_fp16, - 'bf16': _triton.ir.type.get_bf16, - 'f32': _triton.ir.type.get_fp32, - 'f64': _triton.ir.type.get_fp64, - 'i1': _triton.ir.type.get_int1, - 'i8': _triton.ir.type.get_int8, - 'i16': _triton.ir.type.get_int16, - 'i32': _triton.ir.type.get_int32, - 'i64': _triton.ir.type.get_int64, - 'u8': _triton.ir.type.get_uint8, - 'u16': _triton.ir.type.get_uint16, - 'u32': _triton.ir.type.get_uint32, - 'u64': _triton.ir.type.get_uint64, + 'I': triton.language.int32, + 'L': triton.language.int64, + 'f': triton.language.float32, + 'B': triton.language.int1, + 'f8': triton.language.float8, + 'f16': triton.language.float16, + 'bf16': triton.language.bfloat16, + 'f32': triton.language.float32, + 'f64': triton.language.float64, + 'i1': triton.language.int1, + 'i8': triton.language.int8, + 'i16': triton.language.int16, + 'i32': triton.language.int32, + 'i64': triton.language.int64, + 'u8': triton.language.uint8, + 'u16': triton.language.uint16, + 'u32': triton.language.uint32, + 'u64': triton.language.uint64, } # convert torch.Tensor to Triton IR pointers if which == 'ptr': - elt_ty = type_map[name](context) - return _triton.ir.type.make_ptr(elt_ty, 1) + elt_ty = type_map[name] + return triton.language.pointer_type(elt_ty, 1) # default path returns triton.ir.type directly - return type_map[name](context) + return type_map[name] @staticmethod def pow2_divisor(N): @@ -920,25 +1028,31 @@ class JITFunction: assert isinstance(tree.body[0], ast.FunctionDef) return tree + # Called by CodeGenerator.visit_Call() def __call__(self, *args, generator: CodeGenerator, **kwargs): try: from inspect import getcallargs arg_values = getcallargs(self.fn, *args, **kwargs) arg_values = [arg_values[name] for name in self.arg_names] - arg_values = [arg if isinstance(arg, triton.language.block) + arg_values = [arg if isinstance(arg, triton.language.tensor) else triton.language.constexpr(arg) for arg in arg_values] + # Record values in the caller (parent scope) gscope = generator.gscope.copy() lscope = generator.lscope.copy() - values = generator.module.get_values().copy() - types = generator.module.get_types().copy() + + # TODO: clear values 
other than args + lvalues = generator.lvalues.copy() + # types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope - generator.module.set_values(values) - generator.module.set_types(types) + + generator.lvalues = lvalues + # generator.module.set_types(types) + return ret except Exception as e: node = generator.last_node @@ -1023,9 +1137,9 @@ class JITFunction: # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types] - ret_type = _triton.ir.type.get_void(context) - prototype = _triton.ir.type.make_function(ret_type, arg_types) + arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] + ret_type = triton.language.void + prototype = triton.language.function_type(ret_type, arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index df25e59fb..81b9fe790 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1,63 +1,36 @@ +from __future__ import annotations + +from enum import Enum from functools import wraps +from typing import List import triton -from triton._C.libtriton.triton import frontend, ir +from . import semantic +from triton._C.libtriton.triton import ir -# convert block/dtype to ir values -def _to_ir(x, builder): +def _to_tensor(x, builder): if isinstance(x, bool): - return builder.get_int1(x) + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values elif isinstance(x, int): if -2**31 <= x < 2**31: - return builder.get_int32(x) + return tensor(builder.get_int32(x), int32) elif 2**31 <= x < 2**32: - return builder.get_uint32(x) + return tensor(builder.get_uint32(x), uint32) elif -2**63 <= x < 2**63: - return builder.get_int64(x) + return tensor(builder.get_int64(x), int64) elif 2**63 <= x < 2**64: - return builder.get_uint64(x) + return tensor(builder.get_uint64(x), uint64) else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - return builder.get_float32(x) + return tensor(builder.get_float32(x), float32) elif isinstance(x, constexpr): - return _to_ir(x.value, builder) - elif isinstance(x, block): - return x.handle - elif isinstance(x, dtype): - return x.handle(builder) - return x - - -def _patch(fn): - def _from_ir(x): - if isinstance(x, ir.value): - if x.type.is_void(): - return None - return block(x) + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): return x - - def wrapper(*args, **kwargs): - builder = args[-1] - assert isinstance(builder, ir.builder) - args = [_to_ir(x, builder) for x in args] - # for i, arg in enumerate(args): - # if arg is None: - # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") - kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} - ret = fn(*args, **kwargs) - if isinstance(ret, tuple): - return map(_from_ir, ret) - return _from_ir(ret) - - return wrapper - - -for name in dir(frontend): - fn = getattr(frontend, name) - if callable(fn): - setattr(frontend, name, _patch(fn)) + assert False, f'cannot convert {x} to tensor' def builtin(fn): @@ -72,20 +45,147 @@ def builtin(fn): class dtype: - def __init__(self, init): - 
self.init = init + SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return self.name == 'fp8' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def is_void(self): + raise RuntimeError("Not implemented") + + def is_block(self): + return False + + def is_ptr(self): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name,)) @property - def name(self) -> str: - # The init functions are named something like 'get_int8'. Strip the prefix. 
-        nom = self.init.__name__
-        prefix = 'get_'
-        assert nom.startswith(prefix)
-        return nom[len(prefix):]
+    def scalar(self):
+        return self
 
-    def handle(self, builder):
-        ctx = builder.context
-        return self.init(ctx)
+    def to_ir(self, builder: ir.builder) -> ir.type:
+        if self.name == 'void':
+            return builder.get_void_ty()
+        elif self.name == 'int1':
+            return builder.get_int1_ty()
+        elif self.name == 'int8' or self.name == 'uint8':
+            return builder.get_int8_ty()
+        elif self.name == 'int16' or self.name == 'uint16':
+            return builder.get_int16_ty()
+        elif self.name == 'int32' or self.name == 'uint32':
+            return builder.get_int32_ty()
+        elif self.name == 'int64' or self.name == 'uint64':
+            return builder.get_int64_ty()
+        elif self.name == 'fp8':
+            return builder.get_fp8_ty()
+        elif self.name == 'fp16':
+            return builder.get_half_ty()
+        elif self.name == 'bf16':
+            return builder.get_bf16_ty()
+        elif self.name == 'fp32':
+            return builder.get_float_ty()
+        elif self.name == 'fp64':
+            return builder.get_double_ty()
+        raise ValueError(f'failed to convert {self} to ir type')
 
     def __str__(self):
         return self.name
@@ -99,36 +199,112 @@ class dtype:
         return f'triton.language.{self.name}'
 
 
-class pointer_dtype:
-    def __init__(self, element_ty):
+class pointer_type(dtype):
+    def __init__(self, element_ty: dtype, address_space: int = 1):
         if not isinstance(element_ty, dtype):
             raise TypeError(f'element_ty is a {type(element_ty).__name__}.')
         self.element_ty = element_ty
+        self.address_space = address_space
 
-    def handle(self, builder):
-        return ir.type.make_ptr(self.element_ty.handle(builder), 1)
+        self.name = self.__str__()
+
+    def to_ir(self, builder: ir.builder) -> ir.pointer_type:
+        return ir.type.make_ptr(self.element_ty.to_ir(builder), 1)
 
     def __str__(self):
         return f'pointer<{self.element_ty}>'
 
+    def __repr__(self):
+        return self.__str__()
+
+    def is_ptr(self):
+        return True
+
+    def __eq__(self, other: pointer_type) -> bool:
+        if not isinstance(other, pointer_type):
+            return False
+        return self.element_ty == other.element_ty and self.address_space == other.address_space
+
+    def __ne__(self, other: pointer_type) -> bool:
+        return not self.__eq__(other)
+
+    @property
+    def scalar(self):
+        return self
+
+
+class block_type(dtype):
+    def __init__(self, element_ty: dtype, shape: List[int]):
+        self.element_ty = element_ty
+        # FIXME:
+        # block_type's shape is a list of int
+        # while tensor's shape is a list of constexpr
+        self.shape = shape
+        self.numel = 1
+        for s in self.shape:
+            self.numel *= s
+
+        self.name = self.__str__()
+
+    def to_ir(self, builder: ir.builder) -> ir.block_type:
+        return ir.type.make_block(self.element_ty.to_ir(builder), self.shape)
+
+    def __str__(self):
+        return f'<{self.shape}, {self.element_ty}>'
+
+    def __repr__(self):
+        return self.__str__()
+
+    def is_block(self):
+        return True
+
+    def get_block_shapes(self) -> List[int]:
+        return self.shape
+
+    def __eq__(self, other: block_type) -> bool:
+        if not isinstance(other, block_type):
+            return False
+        return self.element_ty == other.element_ty and self.shape == other.shape
+
+    def __ne__(self, other: block_type) -> bool:
+        return not self.__eq__(other)
+
+    @property
+    def scalar(self):
+        return self.element_ty
+
+
+class function_type(dtype):
+    def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None:
+        self.ret_type = ret_type
+        self.param_types = param_types
+
+    def __str__(self):
+        return f'fn ({self.param_types}) -> {self.ret_type}'
+
+    def to_ir(self, builder: ir.builder):
+        ir_param_types = [ty.to_ir(builder) for ty in
self.param_types] + return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types) + # scalar types -int1 = dtype(ir.type.get_int1) -int8 = dtype(ir.type.get_int8) -int16 = dtype(ir.type.get_int16) -int32 = dtype(ir.type.get_int32) -int64 = dtype(ir.type.get_int64) -uint8 = dtype(ir.type.get_uint8) -uint16 = dtype(ir.type.get_uint16) -uint32 = dtype(ir.type.get_uint32) -uint64 = dtype(ir.type.get_uint64) -float8 = dtype(ir.type.get_fp8) -float16 = dtype(ir.type.get_fp16) -bfloat16 = dtype(ir.type.get_bf16) -float32 = dtype(ir.type.get_fp32) -float64 = dtype(ir.type.get_fp64) +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8 = dtype('fp8') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') # pointer types -pi32_t = pointer_dtype(int32) +pi32_t = pointer_type(int32) # ----------------------- # constexpr @@ -149,7 +325,6 @@ class constexpr: def __repr__(self) -> str: return f"constexpr[{self.value}]" - # def __add__(self, other): return self.value + other.value @@ -219,31 +394,33 @@ class constexpr: return self.value(*args, **kwds) -class block: +class tensor: + # infer dtype from ir type @staticmethod - def _init_dtype(ir_type): + def _to_dtype(ir_type): + # block type + if ir_type.is_block(): + scalar_ty = tensor._to_dtype(ir_type.scalar) + return block_type(scalar_ty, ir_type.get_block_shapes()) + # pointer type + if ir_type.is_ptr(): + element_ty = tensor._to_dtype(ir_type.element) + return pointer_type(element_ty) # primitive type + if ir_type.is_void(): return void if ir_type.is_int1(): return int1 if ir_type.is_int8(): return int8 if ir_type.is_int16(): return int16 if ir_type.is_int32(): return int32 if ir_type.is_int64(): return int64 - if ir_type.is_uint8(): return uint8 - if ir_type.is_uint16(): return uint16 - if ir_type.is_uint32(): return uint32 - if ir_type.is_uint64(): return uint64 if ir_type.is_fp8(): return float8 if ir_type.is_fp16(): return float16 if ir_type.is_bf16(): return bfloat16 if ir_type.is_fp32(): return float32 if ir_type.is_fp64(): return float64 - # pointer type - if ir_type.is_ptr(): - element_ty = block._init_dtype(ir_type.element) - return pointer_dtype(element_ty) - raise ValueError(f"Unsupported type {ir_type}") + raise ValueError(f"Unsupported type {ir_type.repr()}") - def __init__(self, handle): + def __init__(self, handle, type: dtype): # IR handle self.handle = handle # Block shape @@ -254,9 +431,9 @@ class block: for s in self.shape: self.numel *= s self.numel = constexpr(self.numel) - # Data-type wrapper - self.dtype = block._init_dtype(self.handle.type.scalar) - # Shape is a constexpr + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: @@ -265,116 +442,139 @@ class block: @builtin def __add__(self, other, _builder=None): - return frontend.add(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - return frontend.sub(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) def __rsub__(self, other, 
_builder=None):
-        return frontend.sub(other, self, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.sub(other, self, _builder)
 
     @builtin
     def __mul__(self, other, _builder=None):
-        return frontend.mul(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.mul(self, other, _builder)
 
     def __rmul__(self, other, _builder=None):
         return self.__mul__(other, _builder=_builder)
 
     @builtin
     def __truediv__(self, other, _builder=None):
-        return frontend.truediv(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.truediv(self, other, _builder)
 
     def __rtruediv__(self, other, _builder=None):
-        return frontend.truediv(other, self, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.truediv(other, self, _builder)
 
     @builtin
     def __floordiv__(self, other, _builder=None):
-        return frontend.floordiv(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.floordiv(self, other, _builder)
 
     @builtin
     def __mod__(self, other, _builder=None):
-        return frontend.mod(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.mod(self, other, _builder)
 
     # unary operators
     @builtin
     def __neg__(self, _builder=None):
-        return frontend.minus(self, _builder)
+        return semantic.minus(self, _builder)
 
     @builtin
     def __invert__(self, _builder=None):
-        return frontend.invert(self, _builder)
+        return semantic.invert(self, _builder)
 
     # bitwise operators
     @builtin
     def __and__(self, other, _builder=None):
-        return frontend.and_(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.and_(self, other, _builder)
 
     @builtin
     def __or__(self, other, _builder=None):
-        return frontend.or_(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.or_(self, other, _builder)
 
     @builtin
     def __xor__(self, other, _builder=None):
-        return frontend.xor_(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.xor_(self, other, _builder)
 
     @builtin
     def __lshift__(self, other, _builder=None):
-        return frontend.shl(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.shl(self, other, _builder)
 
     @builtin
     def __rshift__(self, other, _builder=None):
-        return frontend.lshr(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.lshr(self, other, _builder)
 
     # comparison operators
     # >
     @builtin
     def __gt__(self, other, _builder=None):
-        return frontend.greater_than(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.greater_than(self, other, _builder)
 
     @builtin
     def __rgt__(self, other, _builder=None):
-        return frontend.greater_than(other, self, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.greater_than(other, self, _builder)
 
     # >=
     @builtin
     def __ge__(self, other, _builder=None):
-        return frontend.greater_equal(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.greater_equal(self, other, _builder)
+
     @builtin
     def __rge__(self, other, _builder=None):
-        return frontend.greater_equal(other, self, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.greater_equal(other, self, _builder)
 
     # <
     @builtin
     def __lt__(self, other, _builder=None):
-        return frontend.less_than(self, other, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.less_than(self, other, _builder)
 
     @builtin
     def __rlt__(self, other, _builder=None):
-        return frontend.less_than(other, self, _builder)
+        other = _to_tensor(other, _builder)
+        return semantic.less_than(other, self, _builder)
 
     # <=
     @builtin
     def __le__(self, other,
_builder=None): - return frontend.less_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) @builtin def __rle__(self, other, _builder=None): - return frontend.less_equal(other, self, _builder) + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) # == @builtin def __eq__(self, other, _builder=None): - return frontend.equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) @builtin def __ne__(self, other, _builder=None): - return frontend.not_equal(self, other, _builder) + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) @builtin def __getitem__(self, slices, _builder=None): @@ -389,20 +589,25 @@ class block: elif sl == slice(None, None, None): dst_shape.append(src_shape[curr].value) curr += 1 - ret = frontend.reshape(self, dst_shape, _builder) + ret = semantic.reshape(self, dst_shape, _builder) return ret @builtin def to(self, dtype, bitcast=False, _builder=None): - dtype = dtype.handle(_builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value if bitcast: - return frontend.bitcast(self, dtype, _builder) - return frontend.cast(self, dtype, _builder) + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder) # ----------------------- # SPMD Programming Model # ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v @builtin @@ -414,13 +619,14 @@ def program_id(axis, _builder=None): :type axis: int """ # if axis == -1: - # pid0 = frontend.program_id(0, _builder) - # pid1 = frontend.program_id(1, _builder) - # pid2 = frontend.program_id(2, _builder) - # npg0 = frontend.num_programs(0, _builder) - # npg1 = frontend.num_programs(0, _builder) + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) # return pid0 + pid1*npg0 + pid2*npg0*npg1 - return frontend.program_id(axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) @builtin @@ -431,7 +637,8 @@ def num_programs(axis, _builder=None): :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. :type axis: int """ - return frontend.num_programs(axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) # ----------------------- @@ -449,13 +656,15 @@ def arange(start, end, _builder=None): :param stop: End of the interval. Must be a power of two >= start. :type stop: int """ - return frontend.arange(start, end, _builder) + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) @builtin def zeros(shape, dtype, _builder=None): """ - Returns a block filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. 
:param shape: Shape of the new array, e.g., (8, 16) or (8, ) :type shape: tuple of ints @@ -468,7 +677,8 @@ def zeros(shape, dtype, _builder=None): if not isinstance(d.value, int): raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") shape = [x.value for x in shape] - return frontend.zeros(shape, dtype, _builder) + dtype = _constexpr_to_value(dtype) + return semantic.zeros(shape, dtype, _builder) # ----------------------- @@ -481,25 +691,25 @@ def broadcast(input, other, _builder=None): """ Tries to broadcast the two given blocks to a common compatible shape. - :param input: The first input block. + :param input: The first input tensor. :type input: Block - :param other: The second input block. + :param other: The second input tensor. :type other: Block """ - return frontend.broadcast(input, other, _builder) + return semantic.broadcast_impl_value(input, other, _builder) @builtin def broadcast_to(input, shape, _builder=None): """ - Tries to broadcast the given block to a new :code:`shape`. + Tries to broadcast the given tensor to a new :code:`shape`. - :param input: The input block. + :param input: The input tensor. :type input: Block :param shape: The desired shape. :type shape: Tuple[int] """ - return frontend.broadcast_to(input, shape, _builder) + return semantic.broadcast_impl_shape(input, shape, _builder) @builtin @@ -507,27 +717,27 @@ def cat(input, other, _builder=None): """ Concatenate the given blocks - :param input: The first input block. + :param input: The first input tensor. :type input: - :param other: The second input block. + :param other: The second input tensor. :type other: """ - return frontend.cat(input, other, _builder) + return semantic.cat(input, other, _builder) @builtin def reshape(input, shape, _builder=None): """ - Tries to reshape the given block to a new shape. + Tries to reshape the given tensor to a new shape. - :param input: The input block. + :param input: The input tensor. :type input: :param shape: The desired shape. :type shape: Tuple[int] """ shape = [x.value for x in shape] - return frontend.reshape(input, shape, _builder) + return semantic.reshape(input, shape, _builder) # ----------------------- @@ -542,12 +752,13 @@ def dot(input, other, allow_tf32=True, _builder=None): The two blocks must be two dimensionals and have compatible inner dimensions. - :param input: The first block to be multiplied. - :type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} - :param other: The second block to be multiplied. - :type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} """ - return frontend.dot(input, other, allow_tf32, _builder) + allow_tf32 = _constexpr_to_value(allow_tf32) + return semantic.dot(input, other, allow_tf32, _builder) # ----------------------- @@ -558,7 +769,7 @@ def dot(input, other, allow_tf32=True, _builder=None): @builtin def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", volatile=False, _builder=None): """ - Return a block of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`. 
+    Return a tensor of data whose values are, elementwise, loaded from memory at location defined by :code:`pointer`.
 
     :code:`mask` and :code:`other` are implicitly broadcast to :code:`pointer.shape`.
 
@@ -573,24 +784,36 @@
     :param cache_modifier: changes cache option in nvidia ptx
     :type cache_modifier: str, optional
     """
-    return frontend.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
+    # mask, other can be constexpr
+    if mask is not None:
+        mask = _to_tensor(mask, _builder)
+    if other is not None:
+        other = _to_tensor(other, _builder)
+    cache_modifier = _constexpr_to_value(cache_modifier)
+    eviction_policy = _constexpr_to_value(eviction_policy)
+    volatile = _constexpr_to_value(volatile)
+    return semantic.load(pointer, mask, other, cache_modifier, eviction_policy, volatile, _builder)
 
 
 @builtin
 def store(pointer, value, mask=None, _builder=None):
     """
-    Stores :code:`value` block of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
+    Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
 
     :code:`value` is implicitly broadcast to :code:`pointer.shape` and typecast to :code:`pointer.dtype.element_ty`.
 
     :param pointer: The memory locations where the elements of :code:`value` are stored.
     :type pointer: Block of dtype=triton.PointerDType
-    :param value: The block of elements to be stored.
+    :param value: The tensor of elements to be stored.
     :type value: Block
     :param mask: If mask[idx] is false, do not store :code:`value[idx]` at :code:`pointer[idx]`.
     :type mask: Block of triton.int1, optional
     """
-    return frontend.store(pointer, value, mask, _builder)
+    # value can be constexpr
+    value = _to_tensor(value, _builder)
+    if mask is not None:
+        mask = _to_tensor(mask, _builder)
+    return semantic.store(pointer, value, mask, _builder)
 
 
 # -----------------------
@@ -621,49 +844,58 @@ def _add_atomic_docstr(name):
 
 @builtin
 @_add_atomic_docstr("compare-and-swap")
 def atomic_cas(pointer, cmp, val, _builder=None):
-    return frontend.atomic_cas(pointer, cmp, val, _builder)
+    cmp = _to_tensor(cmp, _builder)
+    val = _to_tensor(val, _builder)
+    return semantic.atomic_cas(pointer, cmp, val, _builder)
 
 
 @builtin
 @_add_atomic_docstr("exchange")
 def atomic_xchg(pointer, val, mask=None, _builder=None):
-    return frontend.atomic_xchg(pointer, val, mask, _builder)
+    val = _to_tensor(val, _builder)
+    return semantic.atomic_xchg(pointer, val, mask, _builder)
 
 
 @builtin
 @_add_atomic_docstr("add")
 def atomic_add(pointer, val, mask=None, _builder=None):
-    return frontend.atomic_add(pointer, val, mask, _builder)
+    val = _to_tensor(val, _builder)
+    return semantic.atomic_add(pointer, val, mask, _builder)
 
 
 @builtin
 @_add_atomic_docstr("max")
 def atomic_max(pointer, val, mask=None, _builder=None):
-    return frontend.atomic_max(pointer, val, mask, _builder)
+    val = _to_tensor(val, _builder)
+    return semantic.atomic_max(pointer, val, mask, _builder)
 
 
 @builtin
 @_add_atomic_docstr("min")
 def atomic_min(pointer, val, mask=None, _builder=None):
-    return frontend.atomic_min(pointer, val, mask, _builder)
+    val = _to_tensor(val, _builder)
+    return semantic.atomic_min(pointer, val, mask, _builder)
 
 
 @builtin
 @_add_atomic_docstr("logical and")
 def atomic_and(pointer, val, mask=None, _builder=None):
-    return frontend.atomic_and(pointer, val, mask, _builder)
+    val = _to_tensor(val, _builder)
+    return semantic.atomic_and(pointer, val, mask, _builder)
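A note on the conversion shared by all of the wrappers above: every Python-scalar argument is routed through `_to_tensor` before reaching `semantic`. Below is a minimal pure-Python sketch of the integer-literal rules `_to_tensor` applies; `literal_int_dtype` is a hypothetical helper used only for illustration and is not part of the patch.

```python
# Illustration only (not part of the patch): the dtype that _to_tensor picks
# for a Python integer literal, mirroring the range checks added in core.py.
def literal_int_dtype(x: int) -> str:
    if -2**31 <= x < 2**31:
        return 'int32'
    if 2**31 <= x < 2**32:
        return 'uint32'  # too big for int32, but still fits in 32 unsigned bits
    if -2**63 <= x < 2**63:
        return 'int64'
    if 2**63 <= x < 2**64:
        return 'uint64'
    raise RuntimeError(f'Nonrepresentable integer {x}.')

assert literal_int_dtype(7) == 'int32'
assert literal_int_dtype(2**31) == 'uint32'
assert literal_int_dtype(-2**40) == 'int64'
assert literal_int_dtype(2**63) == 'uint64'
```

Literals in [2**31, 2**32) and [2**63, 2**64) deliberately land on unsigned types so that large compile-time constants remain representable without widening.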
@builtin @_add_atomic_docstr("logical or") def atomic_or(pointer, val, mask=None, _builder=None): - return frontend.atomic_or(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_or(pointer, val, mask, _builder) @builtin @_add_atomic_docstr("logical xor") def atomic_xor(pointer, val, mask=None, _builder=None): - return frontend.atomic_xor(pointer, val, mask, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_xor(pointer, val, mask, _builder) # ----------------------- @@ -674,7 +906,7 @@ def atomic_xor(pointer, val, mask=None, _builder=None): @builtin def where(condition, x, y, _builder=None): """ - Returns a block of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. @@ -688,7 +920,10 @@ def where(condition, x, y, _builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. """ - return frontend.where(condition, x, y, _builder) + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) # ----------------------- @@ -697,12 +932,15 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): - return frontend.umulhi(x, y, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.umulhi(x, y, _builder) @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): - return frontend.fdiv(x, y, ieee_rounding, _builder) + ieee_rounding = _constexpr_to_value(ieee_rounding) + return semantic.fdiv(x, y, ieee_rounding, _builder) def _add_math_1arg_docstr(name): @@ -723,31 +961,31 @@ def _add_math_1arg_docstr(name): @builtin @_add_math_1arg_docstr("exponential") def exp(x, _builder=None): - return frontend.exp(x, _builder) + return semantic.exp(x, _builder) @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): - return frontend.log(x, _builder) + return semantic.log(x, _builder) @builtin @_add_math_1arg_docstr("cosine") def cos(x, _builder=None): - return frontend.cos(x, _builder) + return semantic.cos(x, _builder) @builtin @_add_math_1arg_docstr("sine") def sin(x, _builder=None): - return frontend.sin(x, _builder) + return semantic.sin(x, _builder) @builtin @_add_math_1arg_docstr("square root") def sqrt(x, _builder=None): - return frontend.sqrt(x, _builder) + return semantic.sqrt(x, _builder) # ----------------------- @@ -758,7 +996,7 @@ def _add_reduction_docstr(name): def _decorator(func): docstr = """ - Returns the {name} of all elements in the :code:`input` block along the provided :code:`axis` + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` :param input: the input values :param axis: the dimension along which the reduction should be done @@ -772,25 +1010,29 @@ def _add_reduction_docstr(name): @builtin @_add_reduction_docstr("maximum") def max(input, axis, _builder=None): - return frontend.max(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.max(input, axis, _builder) @builtin @_add_reduction_docstr("minimum") def min(input, axis, _builder=None): - return frontend.min(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.min(input, axis, _builder) @builtin 
@_add_reduction_docstr("sum") def sum(input, axis, _builder=None): - return frontend.sum(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.sum(input, axis, _builder) @builtin @_add_reduction_docstr("xor sum") def xor_sum(input, axis, _builder=None): - return frontend.xor_sum(input, axis, _builder) + axis = _constexpr_to_value(axis) + return semantic.xor_sum(input, axis, _builder) # ----------------------- @@ -800,7 +1042,7 @@ def xor_sum(input, axis, _builder=None): @builtin def debug_barrier(_builder=None): - return frontend.debug_barrier(_builder) + return semantic.debug_barrier(_builder) @builtin @@ -808,7 +1050,8 @@ def multiple_of(input, value, _builder=None): """ Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. """ - return frontend.multiple_of(input, value, _builder) + value = _constexpr_to_value(value) + return semantic.multiple_of(input, value) @builtin @@ -816,7 +1059,8 @@ def max_contiguous(input, value, _builder=None): """ Let the compiler knows that the `value` first values in :code:`input` are contiguous. """ - return frontend.max_contiguous(input, value, _builder) + value = _constexpr_to_value(value) + return semantic.max_contiguous(input, value) # ----------------------- @@ -846,9 +1090,9 @@ def minimum(x, y): """ Computes the element-wise minimum of :code:`x` and :code:`y`. - :param input: the first input block + :param input: the first input tensor :type input: Block - :param other: the second input block + :param other: the second input tensor :type other: Block """ return triton.language.where(x < y, x, y) @@ -859,9 +1103,9 @@ def maximum(x, y): """ Computes the element-wise maximum of :code:`x` and :code:`y`. - :param input: the first input block + :param input: the first input tensor :type input: Block - :param other: the second input block + :param other: the second input tensor :type other: Block """ return triton.language.where(x > y, x, y) @@ -887,7 +1131,7 @@ def ravel(x): """ Returns a contiguous flattened view of :code:`x` - :param x: the input block + :param x: the input tensor :type x: Block """ return triton.language.reshape(x, [x.numel]) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py new file mode 100644 index 000000000..4063b86fc --- /dev/null +++ b/python/triton/language/semantic.py @@ -0,0 +1,1037 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Tuple + +from . 
import core as tl
+from triton._C.libtriton.triton import ir
+
+
+# Custom exception raised when an operator receives operands of incompatible types
+class IncompatibleTypeErrorimpl(Exception):
+    def __init__(self, type_a, type_b):
+        self.type_a = type_a
+        self.type_b = type_b
+        self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
+        super(IncompatibleTypeErrorimpl, self).__init__(self.message)
+
+
+# ===----------------------------------------------------------------------===##
+# Programming Model
+# ===----------------------------------------------------------------------===##
+
+def program_id(axis: int, builder: ir.builder) -> tl.tensor:
+    return tl.tensor(builder.create_get_program_id(axis), tl.int32)
+
+
+def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
+    return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
+
+# ===----------------------------------------------------------------------===//
+# Implicit Casting Utilities
+# ===----------------------------------------------------------------------===//
+
+
+def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
+    a_rank = a_ty.int_bitwidth
+    b_rank = b_ty.int_bitwidth
+    a_sn = a_ty.int_signedness
+    b_sn = b_ty.int_signedness
+    # Rules for signedness taken from "Usual arithmetic conversions" on
+    # https://en.cppreference.com/w/c/language/conversion.
+    if a_sn == b_sn:
+        return a_ty if a_rank > b_rank else b_ty
+    elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
+        return a_ty if a_rank >= b_rank else b_ty
+    elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
+        return b_ty if b_rank >= a_rank else a_ty
+    assert False
+
+
+def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype:
+    # 1) if one operand is double, the other is implicitly
+    #    converted to double
+    if a_ty.is_fp64() or b_ty.is_fp64():
+        return tl.float64
+    # 2) if one operand is float, the other is implicitly
+    #    converted to float
+    if a_ty.is_fp32() or b_ty.is_fp32():
+        return tl.float32
+    # 3) if one operand is half, the other is implicitly converted to half,
+    #    unless we're doing / or %, which do not exist natively in PTX for fp16.
+    if a_ty.is_fp16() or b_ty.is_fp16():
+        if div_or_mod:
+            return tl.float32
+        else:
+            return tl.float16
+    if not a_ty.is_int() or not b_ty.is_int():
+        assert False
+    # 4) both operands are integer and undergo
+    #    integer promotion
+    if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
+        raise ValueError("Cannot use /, //, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;"
+                         "this is unlikely to result in a useful answer.
Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorimpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorimpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorimpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, + arithmetic_check=True, div_or_mod=False + ) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + assert False + + +def sub(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), + input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + assert False + + +def mul(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + assert False + + +def truediv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and 
other_scalar_ty.is_int():
+        other = cast(other, input_scalar_ty, builder)
+    # int / float
+    elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
+        input = cast(input, other_scalar_ty, builder)
+    # int / int (cast to tl.float32)
+    elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
+        input = cast(input, tl.float32, builder)
+        other = cast(other, tl.float32, builder)
+    # float / float (cast to the type with the wider mantissa)
+    elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
+        if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
+            other = cast(other, input_scalar_ty, builder)
+        else:
+            input = cast(input, other_scalar_ty, builder)
+    # unreachable
+    else:
+        assert False
+    return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
+
+
+def floordiv(input: tl.tensor,
+             other: tl.tensor,
+             builder: ir.builder) -> tl.tensor:
+    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
+    input_scalar_ty = input.type.scalar
+    other_scalar_ty = other.type.scalar
+    if input_scalar_ty.is_int() and other_scalar_ty.is_int():
+        ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
+        input = cast(input, ret_ty, builder)
+        other = cast(other, ret_ty, builder)
+        if ret_ty.is_int_signed():
+            return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
+        else:
+            return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
+    assert False
+
+
+def fdiv(input: tl.tensor,
+         other: tl.tensor,
+         ieee_rounding: bool,
+         builder: ir.builder) -> tl.tensor:
+    input_scalar_ty = input.type.scalar
+    other_scalar_ty = other.type.scalar
+    if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
+        raise ValueError("both operands of fdiv must have a floating point scalar type")
+    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
+    ret = builder.create_fdiv(input.handle, other.handle)
+    ret.set_fdiv_ieee_rounding(ieee_rounding)
+    return tl.tensor(ret, input.type)
+
+
+def mod(input: tl.tensor,
+        other: tl.tensor,
+        builder: ir.builder) -> tl.tensor:
+    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
+    scalar_ty = input.type.scalar
+    other_scalar_ty = other.type.scalar
+    # float % float
+    if scalar_ty.is_floating():
+        return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
+    # int % int
+    elif scalar_ty.is_int():
+        if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
+            raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
+                             "because they have different signedness;"
+                             "this is unlikely to result in a useful answer.
Cast them to the same signedness.")
+        if scalar_ty.is_int_signed():
+            return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
+        else:
+            return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
+    assert False
+
+##############
+# bitwise ops
+##############
+
+
+def bitwise_op_type_checking_impl(input: tl.tensor,
+                                  other: tl.tensor,
+                                  builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
+    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
+    input_sca_ty = input.type.scalar
+    other_sca_ty = other.type.scalar
+    if not input_sca_ty.is_int() or not other_sca_ty.is_int():
+        raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
+    ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
+    if ret_sca_ty != input_sca_ty:
+        input = cast(input, ret_sca_ty, builder)
+    if ret_sca_ty != other_sca_ty:
+        other = cast(other, ret_sca_ty, builder)
+    return input, other
+
+
+def and_(input: tl.tensor,
+         other: tl.tensor,
+         builder: ir.builder) -> tl.tensor:
+    input, other = bitwise_op_type_checking_impl(input, other, builder)
+    return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
+
+
+def or_(input: tl.tensor,
+        other: tl.tensor,
+        builder: ir.builder) -> tl.tensor:
+    input, other = bitwise_op_type_checking_impl(input, other, builder)
+    return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
+
+
+def xor_(input: tl.tensor,
+         other: tl.tensor,
+         builder: ir.builder) -> tl.tensor:
+    input, other = bitwise_op_type_checking_impl(input, other, builder)
+    return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
+
+
+def lshr(input: tl.tensor,
+         other: tl.tensor,
+         builder: ir.builder) -> tl.tensor:
+    input, other = bitwise_op_type_checking_impl(input, other, builder)
+    return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
+
+
+def shl(input: tl.tensor,
+        other: tl.tensor,
+        builder: ir.builder) -> tl.tensor:
+    input, other = bitwise_op_type_checking_impl(input, other, builder)
+    return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
+
+# ===----------------------------------------------------------------------===//
+# Unary Operators
+# ===----------------------------------------------------------------------===//
+
+
+def plus(input: tl.tensor) -> tl.tensor:
+    return input
+
+
+def minus(input: tl.tensor,
+          builder: ir.builder) -> tl.tensor:
+    input_sca_ty = input.type.scalar
+    if input_sca_ty.is_ptr():
+        raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
+    _0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
+    return sub(_0, input, builder)
+
+
+def invert(input: tl.tensor,
+           builder: ir.builder) -> tl.tensor:
+    input_sca_ty = input.type.scalar
+    if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
+        raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
+    _1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
+    return xor_(input, _1, builder)
+
+
+# ===----------------------------------------------------------------------===//
+# Comparison Operators
+# ===----------------------------------------------------------------------===//
+def _bool_like(v: tl.tensor) -> tl.block_type:
+    if not v.type.is_block():
+        return tl.int1
+    shape = v.type.shape
+    return tl.block_type(tl.int1, shape)
+
+
+def greater_than(input: tl.tensor,
+                 other: tl.tensor,
+                 builder: ir.builder) ->
tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + assert False + + +def greater_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + assert False + + +def equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + assert False + + +def not_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + assert False + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + 
shape = [end - start] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.get_range(start, end), ret_ty) + + +def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + _0 = ir.constant.get_null_value(dtype.to_ir(builder)) + ret_ty = tl.block_type(dtype, shape) + return tl.tensor(builder.create_splat(_0, shape), ret_ty) + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def reshape(input: tl.tensor, + dst_shape: List[int], + builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("cannot reshape block of different shape") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + # TODO: check types + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type) + + +def broadcast_impl_shape(input: tl.tensor, + shape: List[int], + builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + if len(lhs_shape) != len(rhs_shape): + raise ValueError("Cannot make_shape_compatible: blocks must have the same rank") + ret_shape = [] + for i in range(len(lhs_shape)): + left = lhs_shape[i] + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif right == 1: + ret_shape.append(left) + elif left == right: + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + +####### +# cast +####### + + +def bitcast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) + if 
src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), + dst_ty) + + +def cast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # bf16 <=> (not fp32) + if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \ + (dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # FP Truncation + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # FP Extension + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Int cast + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + return tl.tensor(builder.create_int_cast(input.handle, + dst_ty.to_ir(builder), sign_extend), + dst_ty) + + # Float to Int + if src_sca_ty.is_floating() and dst_sca_ty.is_int(): + # TODO: is this correct? + if dst_sca_ty.is_bool(): + return tl.tensor(builder.create_fp_to_ui(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_fp_to_si(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # int => float + if src_sca_ty.is_int() and dst_sca_ty.is_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # ptr => int + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_cast(ir.PtrToInt, input.handle, dst_ty.to_ir(builder)), + dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), + tl.tensor(builder.get_int64(0), tl.int64), + builder) + + if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + # Ptr . Ptr + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + # * . 
Bool
+    if dst_sca_ty.is_bool():
+        if src_sca_ty.is_ptr():
+            input = cast(input, tl.int64, builder)
+        other = builder.get_int64(0)
+        if src_ty.is_block():
+            other = builder.create_splat(other, src_ty.get_block_shapes())
+        return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)
+    assert False, f'cannot cast {input} to {dst_ty}'
+
+# ===----------------------------------------------------------------------===//
+# Memory Operators
+# ===----------------------------------------------------------------------===//
+
+
+def load(ptr: tl.tensor,
+         mask: Optional[tl.tensor],
+         other: Optional[tl.tensor],
+         cache_modifier: str,
+         eviction_policy: str,
+         is_volatile: bool,
+         builder: ir.builder) -> tl.tensor:
+    if not ptr.type.scalar.is_ptr():
+        raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__())
+    if ptr.type.is_block():
+        if mask:
+            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
+        if other:
+            other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
+
+    if other:
+        other = cast(other, ptr.type.scalar.element_ty, builder)
+    ptr_ty = ptr.type.scalar
+    elt_ty = ptr_ty.element_ty
+    # treat bool* as tl.int8*
+    if elt_ty == tl.int1:
+        elt_ty = tl.int8
+        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+        ptr = cast(ptr, ptr_ty, builder)
+
+    # cache modifier
+    cache = ir.CACHE_MODIFIER.NONE  # default
+    if cache_modifier:
+        if cache_modifier == ".ca":
+            cache = ir.CACHE_MODIFIER.CA
+        elif cache_modifier == ".cg":
+            cache = ir.CACHE_MODIFIER.CG
+        else:
+            raise ValueError(f"Cache modifier {cache_modifier} not supported")
+
+    # eviction policy
+    eviction = ir.EVICTION_POLICY.NORMAL  # default
+    if eviction_policy:
+        if eviction_policy == "evict_last":
+            eviction = ir.EVICTION_POLICY.EVICT_LAST
+        elif eviction_policy == "evict_first":
+            eviction = ir.EVICTION_POLICY.EVICT_FIRST
+        else:
+            raise ValueError(f"Eviction policy {eviction_policy} not supported")
+
+    if ptr.type.is_block():
+        shape = ptr.type.get_block_shapes()
+        dst_ty = tl.block_type(elt_ty, shape)
+    else:
+        dst_ty = elt_ty
+
+    if not mask and not other:
+        return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
+                         dst_ty)
+    if not mask:
+        raise ValueError("`other` cannot be provided without `mask`")
+
+    if not other:
+        other_ir = ir.undef.get(elt_ty.to_ir(builder))
+        if ptr.type.is_block():
+            other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
+        other = tl.tensor(other_ir, dst_ty)
+
+    return tl.tensor(builder.create_masked_load(ptr.handle,
+                                                mask.handle,
+                                                other.handle,
+                                                cache, eviction, is_volatile),
+                     dst_ty)
+
+
+def store(ptr: tl.tensor,
+          val: tl.tensor,
+          mask: Optional[tl.tensor],
+          builder: ir.builder) -> tl.tensor:
+    if not ptr.type.scalar.is_ptr():
+        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
+    if ptr.type.is_block():
+        val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
+        if mask:
+            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
+    ptr_ty = ptr.type.scalar
+    elt_ty = ptr_ty.element_ty
+    # treat bool* as tl.int8*
+    if elt_ty == tl.int1:
+        elt_ty = tl.int8
+        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+        ptr = cast(ptr, ptr_ty, builder)
+
+    # cast to target data-type
+    val = cast(val, elt_ty, builder)
+    if not mask:
+        return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
+    if not mask.type.scalar.is_bool():
+        raise ValueError("Mask must have boolean scalar type")
+    return
tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, + cmp: tl.tensor, + val: tl.tensor, + builder: ir.builder) -> tl.tensor: + # TODO: type checking + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_block(): + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_min(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + i_ptr.handle, + i_val.handle, + and_(mask, pos, 
builder).handle), + i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + i_ptr.handle, + i_val.handle, + and_(mask, neg, builder).handle), + i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_add(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_and(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_or(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xor(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xchg(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def dot(lhs: tl.tensor, + rhs: tl.tensor, + allow_tf32: bool, + builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + if lhs.type.scalar.is_int(): + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + else: + _0 = builder.get_float32(0) + ret_scalar_ty = tl.float32 + M = lhs.type.shape[0] + N = rhs.type.shape[1] + _0 = builder.create_splat(_0, [M, N]) + ret_ty = tl.block_type(ret_scalar_ty, [M, N]) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + +def where(condition: tl.tensor, + x: tl.tensor, + y: tl.tensor, + builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + x = broadcast_impl_shape(x, condition.type.get_block_shapes(), builder) + y = broadcast_impl_shape(y, condition.type.get_block_shapes(), builder) + + # TODO: we need to check x's and y's shape? 
+ x_ty = x.type.scalar + y_ty = y.type.scalar + ty = computation_type_impl(x_ty, y_ty, div_or_mod=False) + x = cast(x, ty, builder) + y = cast(y, ty, builder) + if x.type.is_block(): + ret_ty = tl.block_type(ty, x.type.shape) + else: + ret_ty = ty + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reductions +# ===----------------------------------------------------------------------=== + +def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str, + FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor: + scalar_ty = input.type.scalar + # input is extended to 32-bits if necessary + # this increases numerical accuracy and can be done pretty much for free + # on GPUs + if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32: + input = cast(input, tl.int32, builder) + + # get result type + shape = input.type.shape + ret_shape = [] + for i, s in enumerate(shape): + if i != axis: + ret_shape.append(s) + if len(ret_shape) == 0: + res_ty = scalar_ty + else: + res_ty = tl.block_type(scalar_ty, ret_shape) + + if scalar_ty.is_floating(): + return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty) + elif scalar_ty.is_int(): + return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty) + assert False + + +def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN) + + +def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX) + + +def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD) + + +def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR) + + +# ===----------------------------------------------------------------------=== +# Math +# ===----------------------------------------------------------------------=== + +def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + x, y = binary_op_type_checking_impl(x, y, builder) + return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type) + + +def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_exp(x.handle), x.type) + + +def log(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_log(x.handle), x.type) + + +def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_cos(x.handle), x.type) + + +def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sin(x.handle), x.type) + + +def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sqrt(x.handle), x.type) + + +## + +def multiple_of(x: tl.tensor, value: int) -> tl.tensor: + x.handle.multiple_of(value) + return x + + +def max_contiguous(x: tl.tensor, value: int) -> tl.tensor: + x.handle.max_contiguous(value) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(''), tl.void)
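For readers of `semantic.py` above, the signedness rules in `integer_promote_impl` can be sketched in plain Python. The `promote` helper and the `(bitwidth, is_signed)` encoding below are illustrative stand-ins, not part of the patch.

```python
# Illustration only (not part of the patch): the rank/signedness rules that
# integer_promote_impl implements, demonstrated on (bitwidth, is_signed) pairs
# instead of tl.dtype objects.
def promote(a, b):
    (a_rank, a_signed), (b_rank, b_signed) = a, b
    if a_signed == b_signed:
        return a if a_rank > b_rank else b
    if not a_signed:                      # a unsigned, b signed
        return a if a_rank >= b_rank else b
    return b if b_rank >= a_rank else a  # a signed, b unsigned

int32, uint32, int64 = (32, True), (32, False), (64, True)
assert promote(int32, uint32) == uint32  # equal rank: the unsigned type wins
assert promote(uint32, int64) == int64   # otherwise the wider type wins
assert promote(int64, uint32) == int64
```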
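Likewise, the per-dimension rule that `broadcast_impl_value` enforces for two blocks of equal rank is easy to state in isolation; `broadcast_shapes` below is a hypothetical helper used only for illustration.

```python
# Illustration only (not part of the patch): the shape rule applied by
# broadcast_impl_value when both operands are blocks of the same rank.
def broadcast_shapes(lhs, rhs):
    if len(lhs) != len(rhs):
        raise ValueError("blocks must have the same rank")
    out = []
    for i, (l, r) in enumerate(zip(lhs, rhs)):
        if l == 1:
            out.append(r)           # size-1 dims stretch to the other side
        elif r == 1 or l == r:
            out.append(l)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {l} and {r}")
    return out

assert broadcast_shapes([128, 1], [1, 64]) == [128, 64]
assert broadcast_shapes([64, 32], [64, 32]) == [64, 32]
```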
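Finally, the integer bit-tricks in `atomic_max`/`atomic_min` above (signed MAX for non-negative operands, unsigned MIN for negative ones) rely on an ordering property of IEEE-754. A small self-contained check, for illustration only and assuming the standard little-endian float32 layout:

```python
# Illustration only (not part of the patch): for x >= 0 the float32 bit
# pattern increases monotonically with the value, so an integer MAX on the
# bits is a float max; for x < 0 the unsigned pattern *decreases* as the
# value increases, so an unsigned integer MIN selects the larger float.
import struct

def f32_bits(f: float) -> int:
    return struct.unpack('<I', struct.pack('<f', f))[0]

assert f32_bits(1.0) < f32_bits(2.0)    # order preserved for x >= 0
assert f32_bits(-1.0) < f32_bits(-2.0)  # order reversed for x < 0
```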