Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)

This reverts commit 539961072c.
This commit is contained in:
Philippe Tillet
2022-03-24 17:16:50 -07:00
committed by GitHub
parent ea6d1f1b85
commit 76a9ee50a8
19 changed files with 1670 additions and 2044 deletions

View File

@@ -38,8 +38,10 @@ public:
iterator get_insert_point() { return insert_point_;} iterator get_insert_point() { return insert_point_;}
// Constants // Constants
value *get_int1(bool val); value *get_int1(bool val);
value *get_int32(uint32_t val); value *get_int32(int32_t val);
value *get_int64(uint64_t val); value *get_int64(int64_t val);
value *get_uint32(uint32_t val);
value *get_uint64(uint64_t val);
value *get_float16(float val); value *get_float16(float val);
value *get_float32(float val); value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi); value *get_range(int32_t lo, int32_t hi);
@@ -50,9 +52,11 @@ public:
type *get_int16_ty(); type *get_int16_ty();
type *get_int32_ty(); type *get_int32_ty();
type *get_int64_ty(); type *get_int64_ty();
type *get_fp8_ty(); type *get_uint8_ty();
type *get_uint16_ty();
type *get_uint32_ty();
type *get_uint64_ty();
type *get_half_ty(); type *get_half_ty();
type *get_bf16_ty();
type *get_float_ty(); type *get_float_ty();
type *get_double_ty(); type *get_double_ty();
// Insert // Insert
@@ -70,9 +74,7 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void(); value* create_ret_void();
// Cast instructions // Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty); value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_int_to_ptr(value *src, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty); value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty); value* create_ui_to_fp(value *src, type *dst_ty);
@@ -91,11 +93,11 @@ public:
value *create_frem(value *lhs, value *rhs); value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs); value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs); value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs); value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs); value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs); value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs); value *create_urem(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -143,22 +145,11 @@ public:
value *create_reshape(value *arg, const type::block_shapes_t &shapes); value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_cat(value *lhs, value *rhs); value *create_cat(value *lhs, value *rhs);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes); value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Atomic instruction
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_atomic_max(value *ptr, value *val, value *msk);
value *create_atomic_umax(value *ptr, value *val, value *msk);
value *create_atomic_min(value *ptr, value *val, value *msk);
value *create_atomic_umin(value *ptr, value *val, value *msk);
value *create_atomic_fadd(value *ptr, value *val, value *msk);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_atomic_and(value *ptr, value *val, value *msk);
value *create_atomic_or(value *ptr, value *val, value *msk);
value *create_atomic_xor(value *ptr, value *val, value *msk);
value *create_atomic_xchg(value *ptr, value *val, value *msk);
// Built-in instruction // Built-in instruction
value *create_get_program_id(unsigned axis); value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis); value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg); value *create_exp(value* arg);
value *create_cos(value* arg); value *create_cos(value* arg);
value *create_sin(value* arg); value *create_sin(value* arg);

View File

@@ -26,6 +26,7 @@ public:
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty; type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types // integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
// Pointer types // Pointer types
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys; std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types // Block types

View File

@@ -0,0 +1,113 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
// Exception raised when user-provided frontend code fails semantic
// checking (e.g. incompatible operand types, mixed-signedness division).
// Distinct from std::runtime_error so callers can distinguish user errors
// from internal bugs (which use plain runtime_error via throw_unreachable).
struct semantic_error: public std::runtime_error {
  semantic_error(const std::string& msg):
    std::runtime_error(msg) { }
};
// Stateless collection of frontend-level IR construction helpers.
// Each method applies the implicit broadcasting / type-promotion semantics
// expected by the Python frontend and then emits the corresponding
// instruction(s) through the supplied `builder`. Invalid user input raises
// `semantic_error`; impossible internal states raise std::runtime_error.
struct dispatch{
  typedef ir::type::block_shapes_t shape_t;
  // programming model
  static ir::value *program_id(int axis, ir::builder *builder);
  static ir::value *num_programs(int axis, ir::builder *builder);
  // binary operators
  static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
  static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
  // unary operators
  static ir::value *plus(ir::value *input, ir::builder *builder);
  static ir::value *minus(ir::value *input, ir::builder *builder);
  static ir::value *invert(ir::value *input, ir::builder *builder);
  // comparison operators
  static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
  static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
  // block creation
  static ir::value* arange(int start, int end, ir::builder *builder);
  static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
  // casting ops
  static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
  static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
  static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
  // overload: broadcast two values to a common shape, returning both
  static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
  static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
  static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
  // memory operators
  static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
                         const std::string& eviction_policy, int is_volatile, ir::builder *builder);
  static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
  static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
  static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
  // linear algebra
  static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
  // indexing
  static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
  // reduction
  static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
  static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
  static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
  static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
  // math
  static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
  static ir::value *exp(ir::value *x, ir::builder *builder);
  static ir::value *log(ir::value *x, ir::builder *builder);
  static ir::value *cos(ir::value *x, ir::builder *builder);
  static ir::value *sin(ir::value *x, ir::builder *builder);
  static ir::value *sqrt(ir::value *x, ir::builder *builder);
  // internal (debug/optimization)
  static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
  static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
  static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -57,10 +57,26 @@ private:
void push_function(function *fn) { functions_.push_back(fn); } void push_function(function *fn) { functions_.push_back(fn); }
public: public:
module(const std::string &name, builder &builder): name_(name), builder_(builder) {} module(const std::string &name, builder& builder);
builder &get_builder() { return builder_; }; builder& get_builder();
const std::string& get_name() { return name_; }; // Setters
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::map<val_key_t, value*>& get_values() { return values_; }
const std::map<std::string, type*>& get_types() { return types_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
void set_types(const std::map<std::string, type*>& types) { types_ = types; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Functions // Functions
const functions_list_t &get_function_list() const { return functions_; } const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; } functions_list_t &get_function_list() { return functions_; }
@@ -73,14 +89,21 @@ public:
const std::map<std::string, ir::value*>& globals() const { return globals_; } const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata // Metadata
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
void print(std::ostream &os); void print(std::ostream &os);
private: private:
std::string name_; std::string name_;
builder &builder_; builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_; functions_list_t functions_;
symbols_map_t symbols_; symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::vector<ir::alloc_const*> allocs_; std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_; std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_; std::map<std::string, md_pair_t> metadatas_;

View File

@@ -16,6 +16,8 @@ class value;
class integer_type; class integer_type;
class constant_int; class constant_int;
enum class signedness { SIGNED, UNSIGNED };
/* Type */ /* Type */
class type { class type {
public: public:
@@ -59,6 +61,8 @@ public:
// type attributes // type attributes
unsigned get_fp_mantissa_width() const; unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const; unsigned get_integer_bitwidth() const;
signedness get_integer_signedness() const;
bool is_integer_signed() const;
unsigned get_tile_bitwidth() const; unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const; unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const; type *get_scalar_ty() const;
@@ -81,6 +85,9 @@ public:
bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; } bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth, signedness sn) {
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
}
bool is_bool_ty() const { return is_integer_ty(1); } bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; } bool is_block_ty() const { return id_ == BlockTyID; }
@@ -108,6 +115,10 @@ public:
static integer_type *get_int32_ty(context &ctx); static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx); static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx); static integer_type *get_int128_ty(context &ctx);
static integer_type *get_uint8_ty(context &ctx);
static integer_type *get_uint16_ty(context &ctx);
static integer_type *get_uint32_ty(context &ctx);
static integer_type *get_uint64_ty(context &ctx);
// repr // repr
std::string tile_repr() const { std::string tile_repr() const {
@@ -134,7 +145,7 @@ public:
case LabelTyID: return "label"; case LabelTyID: return "label";
case MetadataTyID: return "md"; case MetadataTyID: return "md";
case TokenTyID: return "tok"; case TokenTyID: return "tok";
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth()); case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn"; case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct"; case StructTyID: return "struct";
@@ -157,18 +168,21 @@ class integer_type: public type {
private: private:
// constructors // constructors
integer_type(context &ctx, unsigned bitwidth) integer_type(context &ctx, unsigned bitwidth, signedness sn)
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {} : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
public: public:
// accessors // accessors
unsigned get_bitwidth() const { return bitwidth_; } unsigned get_bitwidth() const { return bitwidth_; }
signedness get_signedness() const { return signedness_; }
// factory methods // factory methods
static integer_type* get(context &ctx, unsigned width); static integer_type* get(context &ctx, unsigned width);
private: private:
unsigned bitwidth_; unsigned bitwidth_;
signedness signedness_;
}; };
class composite_type: public type{ class composite_type: public type{

View File

@@ -48,12 +48,18 @@ void builder::set_insert_point(basic_block *block){
value *builder::get_int1(bool val) value *builder::get_int1(bool val)
{ return constant_int::get(type::get_int1_ty(ctx_), val); } { return constant_int::get(type::get_int1_ty(ctx_), val); }
value *builder::get_int32(uint32_t val) value *builder::get_int32(int32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);} { return constant_int::get(type::get_int32_ty(ctx_), val);}
value *builder::get_int64(uint64_t val) value *builder::get_uint32(uint32_t val)
{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
value *builder::get_int64(int64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);} { return constant_int::get(type::get_int64_ty(ctx_), val);}
value *builder::get_uint64(uint64_t val)
{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
value *builder::get_float16(float val) value *builder::get_float16(float val)
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); } { return constant_fp::get(type::get_fp16_ty(ctx_), val); }
@@ -84,15 +90,21 @@ type *builder::get_int32_ty()
type *builder::get_int64_ty() type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); } { return type::get_int64_ty(ctx_); }
type *builder::get_fp8_ty() type *builder::get_uint8_ty()
{ return type::get_fp8_ty(ctx_); } { return type::get_uint8_ty(ctx_); }
type *builder::get_uint16_ty()
{ return type::get_uint16_ty(ctx_); }
type *builder::get_uint32_ty()
{ return type::get_uint32_ty(ctx_); }
type *builder::get_uint64_ty()
{ return type::get_uint64_ty(ctx_); }
type *builder::get_half_ty() type *builder::get_half_ty()
{ return type::get_fp16_ty(ctx_); } { return type::get_fp16_ty(ctx_); }
type *builder::get_bf16_ty()
{ return type::get_bf16_ty(ctx_); }
type *builder::get_float_ty() type *builder::get_float_ty()
{ return type::get_fp32_ty(ctx_); } { return type::get_fp32_ty(ctx_); }
@@ -127,8 +139,6 @@ value *builder::create_ret_void() {
return create_cast(OPCODE, src, dst_ty);\ return create_cast(OPCODE, src, dst_ty);\
} }
DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
@@ -321,28 +331,6 @@ value *builder::create_downcast(value *arg) {
return insert(downcast_inst::create(arg)); return insert(downcast_inst::create(arg));
} }
//
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
}
#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
return create_atomic_rmw(OPCODE, ptr, val, mask);\
}
DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// built-in instructions // built-in instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -359,6 +347,9 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
return insert(atomic_cas_inst::create(ptr, cmp, val)); return insert(atomic_cas_inst::create(ptr, cmp, val));
} }
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
}
value *builder::create_exp(value *arg){ value *builder::create_exp(value *arg){
return insert(exp_inst::create(arg)); return insert(exp_inst::create(arg));

View File

@@ -19,12 +19,18 @@ context_impl::context_impl(context &ctx)
fp32_ty(ctx, type::FP32TyID), fp32_ty(ctx, type::FP32TyID),
fp64_ty(ctx, type::FP64TyID), fp64_ty(ctx, type::FP64TyID),
// integers // integers
int1_ty(ctx, 1), int1_ty(ctx, 1, signedness::SIGNED),
int8_ty(ctx, 8), int8_ty(ctx, 8, signedness::SIGNED),
int16_ty(ctx, 16), int16_ty(ctx, 16, signedness::SIGNED),
int32_ty(ctx, 32), int32_ty(ctx, 32, signedness::SIGNED),
int64_ty(ctx, 64), int64_ty(ctx, 64, signedness::SIGNED),
int128_ty(ctx, 128) {} int128_ty(ctx, 128, signedness::SIGNED),
uint8_ty(ctx, 8, signedness::UNSIGNED),
uint16_ty(ctx, 16, signedness::UNSIGNED),
uint32_ty(ctx, 32, signedness::UNSIGNED),
uint64_ty(ctx, 64, signedness::UNSIGNED){
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// context // context

882
lib/ir/dispatch.cc Normal file
View File

@@ -0,0 +1,882 @@
#include "triton/ir/dispatch.h"
namespace triton {
namespace ir {
// Signal an internal compiler bug: a code path that should be impossible
// given prior type checking was reached. `key` identifies the call site.
[[ noreturn ]] void throw_unreachable(std::string key) {
  std::string msg = "Encountered unimplemented code path in `" + key + "`. "
                    "This is likely a bug on our side.";
  throw std::runtime_error(msg);
}
//===----------------------------------------------------------------------===//
// Programming Model
//===----------------------------------------------------------------------===//
// Emit a get_program_id instruction for the given grid axis.
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
  return builder->create_get_program_id(axis);
}
// Emit a get_num_programs instruction for the given grid axis.
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
  return builder->create_get_num_programs(axis);
}
//===----------------------------------------------------------------------===//
// Implicit Casting Utilities
//===----------------------------------------------------------------------===//
// Pick the common integer type of `a_ty` and `b_ty` following the C
// "usual arithmetic conversions"
// (https://en.cppreference.com/w/c/language/conversion):
// same signedness -> wider type wins; mixed signedness -> the unsigned
// type wins ties in width.
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
  const int a_rank = a_ty->get_integer_bitwidth();
  const int b_rank = b_ty->get_integer_bitwidth();
  const auto a_sn = a_ty->get_integer_signedness();
  const auto b_sn = b_ty->get_integer_signedness();
  if (a_sn == b_sn)
    return a_rank > b_rank ? a_ty : b_ty;
  if (a_sn == signedness::UNSIGNED)
    return a_rank >= b_rank ? a_ty : b_ty;
  if (b_sn == signedness::UNSIGNED)
    return b_rank >= a_rank ? b_ty : a_ty;
  // signedness has exactly two values, so this cannot be reached
  throw_unreachable("integer_promote");
}
// Whether the operation is a division or modulo; these have stricter
// promotion rules (no native fp16 lowering, no mixed signedness).
enum class DivOrMod { NO, YES };
// Determine the scalar type both operands are cast to before an arithmetic
// binary op. Precedence: fp64 > fp32 > fp16 > promoted integer type.
// Division/modulo never compute in fp16 (no native PTX support) and reject
// mixed integer signedness outright.
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
  context &ctx = a_ty->get_context();
  // 1) any fp64 operand -> fp64
  if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
    return type::get_fp64_ty(ctx);
  // 2) any fp32 operand -> fp32
  if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
    return type::get_fp32_ty(ctx);
  // 3) any fp16 operand -> fp16, except / and % which fall back to fp32
  if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
    return div_or_mod == DivOrMod::YES ? type::get_fp32_ty(ctx)
                                       : type::get_fp16_ty(ctx);
  if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
    throw_unreachable("computation_type");
  // 4) both integers: reject mixed signedness for / and %, else promote
  if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness())
    throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
  return integer_promote(a_ty, b_ty);
}
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
// Raise a user-facing semantic_error for operand types that cannot be combined.
void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
  throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
}
// Validate a (possibly pointer-typed) left operand `type_a` against `type_b`.
// No-op unless `type_a` is a pointer; then rejects: pointers where they are
// not allowed, T* op U* with T != U, and T* op float.
void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
  if(!type_a->is_pointer_ty())
    return;
  if(!allow_ptr_a)
    throw_incompatible_types(type_a, type_b);
  // T* + U* with T != U
  if(type_b->is_pointer_ty() && (type_a != type_b))
    throw_incompatible_types(type_a, type_b);
  // T* + float
  if(type_b->is_floating_point_ty())
    throw_incompatible_types(type_a, type_b);
}
// Prepare two operands for a binary op. Mutates `lhs`/`rhs` IN PLACE
// (reference-to-pointer): first broadcasts both to a common shape, then —
// unless either side is a pointer or `arithmetic_check` is false — casts
// both to their common computation type. `allow_lhs_ptr`/`allow_rhs_ptr`
// permit a pointer scalar type on that side (for ptr +/- offset);
// `div_or_mod` tightens the promotion rules for / // %.
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
// implicit broadcasting
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
// implicit typecasting
ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
// pointer operands keep their type: address arithmetic goes through gep
if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
}
}
// Addition with pointer support on either side: ptr + offset and
// offset + ptr both lower to a gep; floats use fadd; integers use add.
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, true, true);
  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  // offset + ptr: canonicalize to ptr + offset
  if (other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty()) {
    std::swap(input, other);
    // BUG FIX: keep the cached scalar types in sync with the swapped values;
    // otherwise the pointer branch below tests the pre-swap (offset) type and
    // offset + ptr falls through to create_add on a pointer operand.
    std::swap(input_scalar_ty, other_scalar_ty);
  }
  // ptr + offset
  if (input_scalar_ty->is_pointer_ty())
    return builder->create_gep(input, {other});
  // float + float
  else if (input_scalar_ty->is_floating_point_ty())
    return builder->create_fadd(input, other);
  // int + int
  else if (input_scalar_ty->is_integer_ty())
    return builder->create_add(input, other);
  throw_unreachable("add");
}
// Subtraction; a pointer is only allowed on the left (ptr - offset).
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, false);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// ptr - offset: gep with the negated offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(input, {dispatch::minus(other, builder)});
// float - float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(input, other);
// int - int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
throw_unreachable("sub");
}
// Elementwise multiplication: fmul for floating point, mul for integers.
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  if (sca_ty->is_floating_point_ty())
    return builder->create_fmul(input, other);  // float * float
  if (sca_ty->is_integer_ty())
    return builder->create_mul(input, other);   // int * int
  throw_unreachable("mul");
}
// "True" division (Python's /): always produces a floating-point result.
// int / int computes in fp32; mixed int/float casts the integer side; for
// float / float the operand with the wider mantissa determines the type.
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  // float / int
  if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
    other = cast(other, input_scalar_ty, builder);
  // int / float
  else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
    input = cast(input, other_scalar_ty, builder);
  // int / int (cast to float32)
  else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
    input = cast(input, builder->get_float_ty(), builder);
    other = cast(other, builder->get_float_ty(), builder);
  }
  // float / float (cast to the type with the wider mantissa)
  else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
    if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
      other = cast(other, input_scalar_ty, builder);
    else
      input = cast(input, other_scalar_ty, builder);
  }
  // unreachable
  else
    // FIX: key was "div"; now matches the function name like every sibling,
    // so internal-bug reports point at the right entry point.
    throw_unreachable("truediv");
  return builder->create_fdiv(input, other);
}
// Integer floor division (Python's //): promotes both operands to their
// common integer type, then emits sdiv or udiv based on its signedness.
// Only integer operands are supported.
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *lhs_ty = input->get_type()->get_scalar_ty();
  ir::type *rhs_ty = other->get_type()->get_scalar_ty();
  if(!lhs_ty->is_integer_ty() || !rhs_ty->is_integer_ty())
    throw_unreachable("floordiv");
  ir::type *ret_ty = integer_promote(lhs_ty, rhs_ty);
  input = dispatch::cast(input, ret_ty, builder);
  other = dispatch::cast(other, ret_ty, builder);
  return ret_ty->is_integer_signed() ? builder->create_sdiv(input, other)
                                     : builder->create_udiv(input, other);
}
// Floating-point division with explicit rounding control. Both operands
// must already have floating-point scalar types. `ieee_rounding` is read
// via get_value() and forwarded to the instruction — presumably a 0/1
// flag selecting IEEE-compliant rounding; confirm against callers.
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
throw semantic_error("both operands of fdiv must have floating point scalar type");
binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
ir::value* ret = builder->create_fdiv(input, other);
// only set the rounding mode when the result is a plain binary operator
// (create_fdiv may return something else, e.g. a folded constant)
if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret))
binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
return ret;
}
// Remainder: frem for floats, srem/urem for signed/unsigned integers.
// After binary_op_type_checking both operands share one computation type.
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float % float (operands were promoted to a common float type above)
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(input, other);
// int % int
else if (scalar_ty->is_integer_ty()) {
// NOTE(review): computation_type already rejects mixed signedness for
// DivOrMod::YES, so this check looks redundant — kept as a safety net.
if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
if (scalar_ty->is_integer_signed()) {
return builder->create_srem(input, other);
} else {
return builder->create_urem(input, other);
}
}
throw_unreachable("mod");
}
// Shared operand checking for bitwise operators: both operands must be
// integers; each is promoted (in place) to their common integer type.
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, false);
  ir::type *lhs_sca_ty = input->get_type()->get_scalar_ty();
  ir::type *rhs_sca_ty = other->get_type()->get_scalar_ty();
  if(!(lhs_sca_ty->is_integer_ty() && rhs_sca_ty->is_integer_ty()))
    throw_incompatible_types(lhs_sca_ty, rhs_sca_ty);
  ir::type *promoted_ty = integer_promote(lhs_sca_ty, rhs_sca_ty);
  if (promoted_ty != lhs_sca_ty)
    input = dispatch::cast(input, promoted_ty, builder);
  if (promoted_ty != rhs_sca_ty)
    other = dispatch::cast(other, promoted_ty, builder);
}
// Bitwise AND; integer-only, operands promoted to a common integer type.
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_and(input, other);
}
// Bitwise OR; integer-only, operands promoted to a common integer type.
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_or(input, other);
}
// Bitwise XOR; integer-only, operands promoted to a common integer type.
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_xor(input, other);
}
// Logical (zero-filling) right shift; integer-only.
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_lshr(input, other);
}
// Left shift; integer-only.
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_shl(input, other);
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
// Unary plus is the identity: the operand is returned unchanged.
ir::value *dispatch::plus(ir::value *input, ir::builder *) {
return input;
}
// Unary minus, implemented as (0 - input). Rejects pointer operands.
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  if(sca_ty->is_pointer_ty())
    throw semantic_error("wrong type argument to unary minus (" + sca_ty->repr() + ")");
  ir::value *zero = ir::constant::get_null_value(sca_ty);
  return dispatch::sub(zero, input, builder);
}
// Bitwise NOT, implemented as (input ^ ~0). Rejects pointer and
// floating-point operands.
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  if(sca_ty->is_pointer_ty() || sca_ty->is_floating_point_ty())
    throw semantic_error("wrong type argument to unary invert (" + sca_ty->repr() + ")");
  ir::value *all_ones = ir::constant::get_all_ones_value(sca_ty);
  return dispatch::xor_(input, all_ones, builder);
}
//===----------------------------------------------------------------------===//
// Comparison Operators
//===----------------------------------------------------------------------===//
// Elementwise `>`. Floats use an ordered comparison (OGT); integers pick
// the signed or unsigned predicate from the operand's signedness.
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  if (sca_ty->is_floating_point_ty())
    return builder->create_fcmpOGT(input, other);
  if (sca_ty->is_integer_ty())
    return sca_ty->is_integer_signed() ? builder->create_icmpSGT(input, other)
                                       : builder->create_icmpUGT(input, other);
  throw_unreachable("greater_than");
}
// Elementwise `>=`. Floats use an ordered comparison (OGE); integers pick
// the signed or unsigned predicate from the operand's signedness.
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGE(input, other);
} else {
return builder->create_icmpUGE(input, other);
}
}
throw_unreachable("greater_equal");
}
// Elementwise `<`. Floats use an ordered comparison (OLT); integers pick
// the signed or unsigned predicate from the operand's signedness.
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLT(input, other);
} else {
return builder->create_icmpULT(input, other);
}
}
throw_unreachable("less_than");
}
// Elementwise `<=`. Floats use an ordered comparison (OLE); integers pick
// the signed or unsigned predicate from the operand's signedness.
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float <= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int <= int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLE(input, other);
} else {
return builder->create_icmpULE(input, other);
}
}
throw_unreachable("less_equal");
}
// Elementwise `==`. Floats use an ordered comparison (OEQ: NaN != NaN);
// integers use a plain icmp EQ regardless of signedness.
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
throw_unreachable("equal");
}
// Elementwise `!=`. Floats use an unordered comparison (UNE: NaN compares
// not-equal to everything); integers use a plain icmp NE.
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  // float != float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpUNE(input, other);
  // int != int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpNE(input, other);
  // fixed: previously reported "equal", misnaming the failing primitive
  throw_unreachable("not_equal");
}
//===----------------------------------------------------------------------===//
// Block Creation
//===----------------------------------------------------------------------===//
// 1-D block of consecutive integers in [start, end).
ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
}
// Block of the given shape filled with zeros of `dtype`, built by
// splatting a scalar null value.
ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
  ir::value *zero = ir::constant::get_null_value(dtype);
  return builder->create_splat(zero, shape);
}
//===----------------------------------------------------------------------===//
// Shape Manipulation
//===----------------------------------------------------------------------===//
// Reinterpret `input` with a new shape; the total element count must be
// preserved.
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
  unsigned dst_numel = 1;
  for(unsigned dim: dst_shape)
    dst_numel *= dim;
  if(input->get_type()->get_tile_num_elements() != dst_numel)
    throw semantic_error("cannot reshape block of different shape");
  return builder->create_reshape(input, dst_shape);
}
// Concatenate two blocks; shape/type checking is delegated to the builder.
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
return builder->create_cat(lhs, rhs);
}
// Broadcast `input` to `shape`. Scalars are splatted; blocks must already
// have the same rank and are expanded dimension-wise.
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  // a scalar is simply replicated across the target shape
  if (!src_ty->is_block_ty())
    return builder->create_splat(input, shape);
  auto src_shape = src_ty->get_block_shapes();
  if (src_shape.size() != shape.size())
    throw std::runtime_error("Cannot broadcast");
  // no-op when the shapes already agree
  if (src_shape == shape)
    return input;
  return builder->create_broadcast(input, shape);
}
// Make the shapes of `lhs` and `rhs` compatible (numpy-style broadcasting
// restricted to equal rank): scalars are splatted to the other operand's
// shape; for two blocks, each dimension pair must match or contain a 1,
// and both operands are expanded to the combined shape.
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
// a dimension of size 1 stretches to match the other operand
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
// only emit broadcast instructions for operands that actually change shape
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
// Reinterpret the bits of `input` as `dst_ty` (no value conversion).
// Block inputs keep their shape; pointer casts are delegated to the
// general `cast` logic. Both types must have the same primitive bit width.
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
  ir::type *src_ty = input->get_type();
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
  if(src_ty == dst_ty)
    return input;
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
  if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
    return cast(input, dst_ty, builder);
  // Bitcast
  int src_bits = src_sca_ty->get_primitive_size_in_bits();
  int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
  if( src_bits!= dst_bits)
    // fixed: the message was missing the space before "to", producing e.g.
    // "size 32to data-type"
    throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
                             " to data-type of size " + std::to_string(dst_bits));
  return builder->create_cast(ir::BitCast, input, dst_ty);
}
// General value cast between scalar/block types. Dispatches on the
// source/destination scalar types: fp truncation/extension, integer
// resize with sign handling, fp<->int, pointer<->int, pointer bitcast,
// and anything-to-bool. bf16 conversions are routed through fp32.
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// bf16 <-> non-fp32: go through fp32 as an intermediate
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
}
// FP Truncation: mantissa width shrinks
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension: mantissa width grows
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast: width or signedness differs; int1 is never sign-extended
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
(src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
return builder->create_int_cast(input, dst_ty, sign_extend);
}
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
return builder->create_fp_to_ui(input, dst_ty);
else
return builder->create_fp_to_si(input, dst_ty);
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
}
// Ptr -> int: only 64-bit (full address) or bool (null check) supported
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){
int bitwidth = dst_sca_ty->get_integer_bitwidth();
if(bitwidth == 64)
return builder->create_cast(ir::PtrToInt, input, dst_ty);
if(bitwidth == 1)
return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder),
builder->get_int64(0),
builder);
}
// non-ptr -> Ptr
if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::IntToPtr, input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool: compare against zero (pointers go through int64 first)
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, builder->get_int64_ty(), builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
//===----------------------------------------------------------------------===//
// Memory Operators
//===----------------------------------------------------------------------===//
// Build a (possibly masked) load. `mask`/`other` are broadcast to the
// pointer's block shape; `other` is cast to the pointee type; bool* is
// rewritten to int8* before loading. `cache_modifier` selects .ca/.cg and
// `eviction_policy` selects evict_last/evict_first; empty strings mean
// defaults.
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
    throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
  // broadcast mask/other to the pointer's block shape
  if(ptr->get_type()->is_block_ty()){
    if(mask)
      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
    if(other)
      other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
  }
  // NOTE(review): `other` is cast to the pointee type *before* the bool->int8
  // rewrite below; confirm a bool `other` is handled downstream.
  if(other)
    other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
  // treat bool* as int8*
  if(elt_ty == builder->get_int1_ty()){
    elt_ty = builder->get_int8_ty();
    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
    ptr = dispatch::cast(ptr, ptr_ty, builder);
  }
  // cache modifier
  load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
  if (!cache_modifier.empty()) {
    if (cache_modifier == ".ca")
      cache = load_inst::CA;
    else if (cache_modifier == ".cg")
      cache = load_inst::CG;
    else
      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
  }
  // eviction policy
  load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
  if(!eviction_policy.empty()){
    if (eviction_policy == "evict_last")
      eviction = load_inst::EVICT_LAST;
    else if(eviction_policy == "evict_first")
      eviction = load_inst::EVICT_FIRST;
    else
      // fixed: missing space between "policy" and its value in the message
      throw std::runtime_error(std::string("Eviction policy ") + eviction_policy + " not supported");
  }
  if (!mask && !other)
    return builder->create_load(ptr, cache, eviction, is_volatile);
  if (!mask)
    throw std::runtime_error("`other` cannot be provided without `mask`");
  // (removed an unused `shape` local that was never read)
  // masked load with no `other`: fill masked-off lanes with undef
  if(!other){
    other = ir::undef_value::get(elt_ty);
    if(ptr->get_type()->is_block_ty())
      other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
  }
  return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
}
// Build a (possibly masked) store. `val`/`mask` are broadcast to the
// pointer's block shape, bool* is rewritten to int8*, and `val` is cast to
// the pointee type before storing.
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty())
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cast to target data-type
val = dispatch::cast(val, elt_ty, builder);
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
throw semantic_error("Mask must have boolean scalar type")
return builder->create_masked_store(ptr, val, mask);
}
// Atomic compare-and-swap: *ptr = (*ptr == cmp) ? val : *ptr; returns old value.
ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
return builder->create_atomic_cas(ptr, cmp, val);
}
// Shared operand normalization for atomic RMW ops: validates the pointer,
// broadcasts `val`/`mask` to the pointer's block shape, casts `val` to the
// pointee type, and synthesizes an all-true mask when none was given.
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
// NOTE(review): message says "store instruction" but this path is used by
// atomic ops — looks copy-pasted; confirm before changing user-facing text.
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(val){
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
}
}
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
// default mask: all lanes enabled
if(!mask){
mask = builder->get_int1(true);
if(ptr->get_type()->is_block_ty())
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
}
}
// Atomic max. Integers use Max/UMax directly. Floats have no hardware
// atomic max, so the value is reinterpreted as int32 and the IEEE-754 bit
// ordering is exploited: for non-negative floats the integer ordering
// matches, for negative floats it is reversed.
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
}
}
// for float
// return atomic_smax(i_ptr, i_val) if val >= 0
// return atomic_umin(i_ptr, i_val) if val < 0
// NOTE(review): assumes a 32-bit float element type — confirm for fp16/fp64.
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
// select the result of whichever branch was actually enabled per lane
return where(pos, pos_ret, neg_ret, builder);
}
// Atomic min. Mirror image of atomic_max: integers use Min/UMin directly;
// floats are reinterpreted as int32 and use signed Min for non-negative
// values and unsigned Max for negative values (IEEE-754 bit ordering).
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_min for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
}
}
// for float
// return atomic_smin(i_ptr, i_val) if val >= 0
// return atomic_umax(i_ptr, i_val) if val < 0
// NOTE(review): assumes a 32-bit float element type — confirm for fp16/fp64.
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
// select the result of whichever branch was actually enabled per lane
return where(pos, pos_ret, neg_ret, builder);
}
// Atomic add: FAdd for floating-point element types, integer Add otherwise.
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  bool is_float = val->get_type()->get_scalar_ty()->is_floating_point_ty();
  ir::atomic_rmw_op_t op = is_float ? ir::atomic_rmw_op_t::FAdd
                                    : ir::atomic_rmw_op_t::Add;
  return builder->create_atomic_rmw(op, ptr, val, mask);
}
// Atomic bitwise AND.
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
}
// Atomic bitwise OR.
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
}
// Atomic bitwise XOR.
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
}
// Atomic exchange: stores `val` and returns the previous value.
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  // (removed an unused `sca_ty` local that was never read)
  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
}
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
// Matrix product lhs @ rhs. The accumulator is initialized to a zero block
// of shape (M, N) matching the operand element class (int vs float).
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
ir::value *_0 = nullptr;
if (lhs->get_type()->is_int_or_tileint_ty())
_0 = builder->get_int32(0);
else
_0 = builder->get_float32(0);
// M x K @ K x N -> M x N
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
bool _allow_tf32 = allow_tf32->get_value() != 0;
return builder->create_dot(lhs, rhs, _0, _allow_tf32);
}
//===----------------------------------------------------------------------===//
// Indexing
//===----------------------------------------------------------------------===//
// Elementwise select: returns x where condition is true, y elsewhere.
// The condition is coerced to bool; x and y are broadcast to its shape and
// promoted to a common computation type.
ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
  condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
  if(condition->get_type()->is_block_ty()){
    auto cond_shapes = condition->get_type()->get_block_shapes();
    x = dispatch::broadcast(x, cond_shapes, builder);
    y = dispatch::broadcast(y, cond_shapes, builder);
  }
  ir::type *common_ty = computation_type(x->get_type()->get_scalar_ty(),
                                         y->get_type()->get_scalar_ty(),
                                         DivOrMod::NO);
  x = dispatch::cast(x, common_ty, builder);
  y = dispatch::cast(y, common_ty, builder);
  return builder->create_select(condition, x, y);
}
//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
// Shared driver for axis reductions: widens small integers to int32, then
// emits the float or integer reduction op. `name` is only used in the
// unreachable diagnostic.
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// input is extended to 32-bits if necessary
// this increases numerical accuracy and can be done pretty much for free
// on GPUs
if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
throw_unreachable(name);
}
// Min-reduction along `axis` (FMIN for floats, MIN for ints).
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}
// Max-reduction along `axis` (FMAX for floats, MAX for ints).
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}
// Sum-reduction along `axis` (FADD for floats, ADD for ints).
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
}
// XOR-reduction along `axis`; only defined for integer element types.
ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  if (!scalar_ty->is_integer_ty())
    throw semantic_error("xor_sum only supported for integers");
  // fixed: pass "xor_sum" (not "sum") so the unreachable diagnostic names
  // the correct operation
  return reduce_impl(input, axis, builder, "xor_sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
// High half of the unsigned product x*y (umulhi instruction).
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
binary_op_type_checking(x, y, builder);
return builder->insert(umulhi_inst::create(x, y));
}
// Elementwise exponential.
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
return builder->create_exp(x);
}
// Elementwise natural logarithm.
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
return builder->create_log(x);
}
// Elementwise cosine.
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
return builder->create_cos(x);
}
// Elementwise sine.
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
return builder->create_sin(x);
}
// Elementwise square root.
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
// Compiler hint: mark `x` as a multiple of `value`. `x` must be an
// instruction (the metadata is attached to it).
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("multiple_of");
i->set_metadata(ir::metadata::multiple_of, value);
return i;
}
// Compiler hint: mark `x` as contiguous in chunks of `value` elements.
// `x` must be an instruction (the metadata is attached to it).
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("max_contiguous");
i->set_metadata(ir::metadata::max_contiguous, value);
return i;
}
// Emit an explicit barrier instruction (debugging aid).
ir::value *dispatch::debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
}
}

View File

@@ -312,8 +312,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth(); unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth(); unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast : cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
(arg_bits > dst_bits ? cast_op_t::Trunc : (arg_bits > dst_bits ? cast_op_t::Trunc :
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt))); (is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
return create(op, arg, ty, name, next); return create(op, arg, ty, name, next);
} }

View File

@@ -9,6 +9,146 @@
namespace triton{ namespace triton{
namespace ir{ namespace ir{
/* Module */
module::module(const std::string &name, builder &builder)
: name_(name), builder_(builder) {
sealed_blocks_.insert(nullptr);
}
ir::builder& module::get_builder() {
return builder_;
}
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
if(auto *x = dynamic_cast<ir::instruction*>(value))
if(it != metadatas_.end()){
x->set_metadata(it->second.first, it->second.second);
}
// value->set_name(name);
}
void module::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
void module::set_const(const std::string& name){
const_.insert(name);
}
void module::set_continue_fn(std::function<ir::value*()> fn) {
continue_fn_ = fn;
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
builder_.set_insert_point(insert);
}
ir::phi_node *res = builder_.create_phi(ty, num_values);
if(insert != block->end())
builder_.set_insert_point(block);
return res;
}
// SSA construction cleanup: a phi whose operands are only itself and/or a
// single other value is trivial; replace it with that value and recurse
// into phi users that may have become trivial in turn.
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
[phi](ir::value* op){ return op != phi && op; });
// non-trivial
if(non_self_ref.size() != 1)
return phi;
// unique value or self-reference
ir::value *same = *non_self_ref.begin();
assert(same != nullptr);
phi->replace_all_uses_with(same);
phi->erase_from_parent();
// NOTE(review): `phi` is still dereferenced after erase_from_parent();
// assumes erase only unlinks without freeing — confirm against ir::instruction.
std::set<ir::user*> users = phi->get_users();
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
try_remove_trivial_phis(uphi);
return same;
}
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
// already initialized
if(phi->get_num_operands())
return phi;
ir::basic_block *block = phi->get_parent();
for(ir::basic_block *pred: block->get_predecessors()){
ir::value *value = get_value(name, pred);
phi->add_incoming(value, pred);
}
return phi;
}
// Resolve `name` in `block` when no local definition exists (Braun et al.
// SSA construction): unsealed blocks get an incomplete phi to be filled in
// later; blocks with <= 1 predecessor recurse into the predecessor;
// otherwise a phi is created, its operands added, and trivial phis removed.
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *result;
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
ir::type *ty = types_.at(name);
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
// block not yet sealed: record an incomplete phi; operands are added in seal_block
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
else if(preds.size() <= 1){
// single (or no) predecessor: no phi needed, look up there
bool has_pred = preds.size();
result = get_value(name, has_pred?preds.front():nullptr);
}
else{
// multiple predecessors: build a phi and fill in its incoming values
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
result = try_remove_trivial_phis(phi);
}
// cache the resolution for this (name, block) pair
set_value(name, block, result);
return result;
}
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
ir::basic_block* save_block = builder_.get_insert_block();
ir::basic_block::iterator save_pt = builder_.get_insert_point();
val_key_t key(name, block);
if(values_.find(key) != values_.end()){
return values_.at(key);
}
ir::value *result = get_value_recursive(name, block);
builder_.set_insert_point(save_block);
if(save_pt != save_block->end())
builder_.set_insert_point(save_pt);
return result;
}
ir::value *module::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
const std::string& module::get_name() {
return name_;
}
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
if(get_value(x.first) == x.second)
set_value(x.first, try_remove_trivial_phis(x.second));
}
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();
}
/* functions */ /* functions */
function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *module::get_or_insert_function(const std::string &name, function_type *ty) {
function *&fn = (function*&)symbols_[name]; function *&fn = (function*&)symbols_[name];

View File

@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
unsigned type::get_integer_bitwidth() const unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
signedness type::get_integer_signedness() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
bool type::is_integer_signed() const {
if (id_ != IntegerTyID) {
throw std::logic_error("type is " + repr() + ", not integer");
}
return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
}
unsigned type::get_tile_bitwidth() const unsigned type::get_tile_bitwidth() const
{ return ((block_type*)(this))->get_bitwidth(); } { return ((block_type*)(this))->get_bitwidth(); }
@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }

View File

@@ -3,6 +3,7 @@
#include "triton/driver/error.h" #include "triton/driver/error.h"
#include "triton/driver/llvm.h" #include "triton/driver/llvm.h"
#include "triton/ir/builder.h" #include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h" #include "triton/ir/enums.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
@@ -11,12 +12,10 @@
#include <pybind11/buffer_info.h> #include <pybind11/buffer_info.h>
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "Python.h" #include "Python.h"
#include <regex> #include <regex>
#include <sstream> #include <sstream>
#include <stdexcept>
#include <string> #include <string>
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
@@ -542,6 +541,84 @@ void init_triton_codegen(py::module &&m) {
}, py::return_value_policy::take_ownership); }, py::return_value_policy::take_ownership);
} }
/*****************************************************************************/
/* User-facing language features */
/*****************************************************************************/
void init_triton_frontend(py::module &&m) {
using ret = py::return_value_policy;
// programming model
m.def("program_id", &ir::dispatch::program_id, ret::reference);
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
// binary
m.def("add", &ir::dispatch::add, ret::reference);
m.def("sub", &ir::dispatch::sub, ret::reference);
m.def("mul", &ir::dispatch::mul, ret::reference);
m.def("truediv", &ir::dispatch::truediv, ret::reference);
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
m.def("mod", &ir::dispatch::mod, ret::reference);
m.def("and_", &ir::dispatch::and_, ret::reference);
m.def("or_", &ir::dispatch::or_, ret::reference);
m.def("xor_", &ir::dispatch::xor_, ret::reference);
m.def("lshr", &ir::dispatch::lshr, ret::reference);
m.def("shl", &ir::dispatch::shl, ret::reference);
// unary
m.def("plus", &ir::dispatch::plus, ret::reference);
m.def("minus", &ir::dispatch::minus, ret::reference);
m.def("invert", &ir::dispatch::invert, ret::reference);
// comparison
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
m.def("less_than", &ir::dispatch::less_than, ret::reference);
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
m.def("equal", &ir::dispatch::equal, ret::reference);
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
// block creation
m.def("arange", &ir::dispatch::arange, ret::reference);
m.def("zeros", &ir::dispatch::zeros, ret::reference);
// type manipuatation
m.def("cat", &ir::dispatch::cat, ret::reference);
m.def("reshape", &ir::dispatch::reshape, ret::reference);
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
m.def("cast", &ir::dispatch::cast, ret::reference);
// memory
m.def("load", &ir::dispatch::load, ret::reference);
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing
m.def("where", &ir::dispatch::where, ret::reference);
// reduction
m.def("min", &ir::dispatch::min, ret::reference);
m.def("max", &ir::dispatch::max, ret::reference);
m.def("sum", &ir::dispatch::sum, ret::reference);
m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference);
// math
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
m.def("exp", &ir::dispatch::exp, ret::reference);
m.def("log", &ir::dispatch::log, ret::reference);
m.def("cos", &ir::dispatch::cos, ret::reference);
m.def("sin", &ir::dispatch::sin, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
}
/*****************************************************************************/ /*****************************************************************************/
/* Python bindings for triton::ir */ /* Python bindings for triton::ir */
@@ -551,86 +628,16 @@ void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy; using ret = py::return_value_policy;
using namespace pybind11::literals; using namespace pybind11::literals;
py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
.value("NONE", ir::load_inst::NONE)
.value("CA", ir::load_inst::CA)
.value("CG", ir::load_inst::CG)
.export_values();
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
.value("NORMAL", ir::load_inst::NORMAL)
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
.export_values();
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
.value("ADD", ir::reduce_inst::ADD)
.value("FADD", ir::reduce_inst::FADD)
.value("MIN", ir::reduce_inst::MIN)
.value("MAX", ir::reduce_inst::MAX)
.value("FMIN", ir::reduce_inst::FMIN)
.value("FMAX", ir::reduce_inst::FMAX)
.value("XOR", ir::reduce_inst::XOR);
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
.value("ADD", ir::atomic_rmw_op_t::Add)
.value("FADD", ir::atomic_rmw_op_t::FAdd)
.value("AND", ir::atomic_rmw_op_t::And)
.value("OR", ir::atomic_rmw_op_t::Or)
.value("XOR", ir::atomic_rmw_op_t::Xor)
.value("XCHG", ir::atomic_rmw_op_t::Xchg)
.value("MAX", ir::atomic_rmw_op_t::Max)
.value("MIN", ir::atomic_rmw_op_t::Min)
.value("UMIN", ir::atomic_rmw_op_t::UMin)
.value("UMAX", ir::atomic_rmw_op_t::UMax);
py::class_<ir::context>(m, "context") py::class_<ir::context>(m, "context")
.def(py::init<>()); .def(py::init<>());
py::class_<ir::value>(m, "value") auto value = py::class_<ir::value>(m, "value");
.def("multiple_of", [](ir::value *self, int val) { value.def_property("name", &ir::value::get_name, &ir::value::set_name);
if (auto *instr = dynamic_cast<ir::instruction*>(self)) { value.def_property_readonly("type", &ir::value::get_type);
instr->set_metadata(ir::metadata::multiple_of, val);
} else
throw std::runtime_error("multiple_of");
})
.def("max_contiguous", [](ir::value *self, int val) {
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
instr->set_metadata(ir::metadata::max_contiguous, val);
} else
throw std::runtime_error("max_contiguous");
})
.def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
instr->set_fdiv_ieee_rounding(val);
else
throw std::runtime_error("set_fdiv_ieee_rounding");
})
.def("is_phi", [](ir::value *self) {
if (auto *pn = dynamic_cast<ir::phi_node*>(self))
return true;
return false;
})
.def("ops", [](ir::value *self) {
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
return instr->ops();
}
throw std::runtime_error("cannot use ops()");
})
.def("replace_all_uses_with", &ir::value::replace_all_uses_with)
.def("erase_from_parent", [](ir::value *self) {
if (auto *instr = dynamic_cast<ir::instruction*>(self))
return instr->erase_from_parent();
throw std::runtime_error("cannot use erase_from_parent");
})
.def_property("name", &ir::value::get_name, &ir::value::set_name)
.def_property_readonly("type", &ir::value::get_type);
py::class_<ir::user, ir::value>(m, "user"); py::class_<ir::user, ir::value>(m, "user");
py::class_<ir::constant, ir::user>(m, "constant") py::class_<ir::constant, ir::user>(m, "constant");
.def("get_null_value", &ir::constant::get_null_value, ret::reference)
.def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);
py::class_<ir::undef_value, ir::constant>(m, "undef") py::class_<ir::undef_value, ir::constant>(m, "undef")
.def("get", &ir::undef_value::get, ret::reference); .def("get", &ir::undef_value::get, ret::reference);
@@ -641,17 +648,16 @@ void init_triton_ir(py::module &&m) {
.def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); .def("__bool__", [](ir::constant_int *self) { return self->get_value(); });
py::class_<ir::constant_fp, ir::constant>(m, "constant_float") py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
.def_property_readonly("value", &ir::constant_fp::get_value) .def_property_readonly("value", &ir::constant_fp::get_value);
.def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);
py::class_<ir::instruction, ir::user>(m, "instruction") py::class_<ir::instruction, ir::user>(m, "instruction");
.def("get_parent", [](ir::instruction *self) { py::class_<ir::phi_node, ir::user>(m, "phi_node");
return self->get_parent();
}, ret::reference);
py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
.def("add_incoming", &ir::phi_node::add_incoming);
py::class_<ir::type>(m, "type") py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("make_ptr", &ir::pointer_type::get, ret::reference) .def("make_ptr", &ir::pointer_type::get, ret::reference)
.def("make_function", &ir::function_type::get, ret::reference) .def("make_function", &ir::function_type::get, ret::reference)
.def("make_block", &ir::block_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference)
@@ -666,38 +672,34 @@ void init_triton_ir(py::module &&m) {
.def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference) .def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
.def("get_block_shapes", &ir::type::get_block_shapes)
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("is_void", &ir::type::is_void_ty) .def("is_void", &ir::type::is_void_ty)
.def("is_bool", &ir::type::is_bool_ty)
.def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp8", &ir::type::is_fp8_ty)
.def("is_fp16", &ir::type::is_fp16_ty) .def("is_fp16", &ir::type::is_fp16_ty)
.def("is_bf16", &ir::type::is_bf16_ty) .def("is_bf16", &ir::type::is_bf16_ty)
.def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp32", &ir::type::is_fp32_ty)
.def("is_fp64", &ir::type::is_fp64_ty) .def("is_fp64", &ir::type::is_fp64_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); }) .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); }) .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); }) .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); }) .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); }) .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
.def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty) .def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def("repr", &ir::type::repr) .def("repr", &ir::type::repr)
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty) .def_property_readonly("scalar", &ir::type::get_scalar_ty)
.def_property_readonly("context", &ir::type::get_context, ret::reference) .def_property_readonly("context", &ir::type::get_context, ret::reference);
.def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
.def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);
py::class_<ir::pointer_type, ir::type>(m, "pointer_type") py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference) .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
.def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type"); py::class_<ir::function_type, ir::type>(m, "function_type");
py::class_<ir::integer_type, ir::type>(m, "integer_type"); py::class_<ir::integer_type, ir::type>(m, "integer_type");
@@ -707,15 +709,16 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::module>(m, "module") py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>()) .def(py::init<std::string, ir::builder &>())
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
const auto metadatas = self->get_metadatas(); .def("seal_block", &ir::module::seal_block)
auto it = metadatas.find(name); .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
if (it != metadatas.end()) .def("set_type", &ir::module::set_type)
if (auto *instr = dynamic_cast<ir::instruction*>(value)) { .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
instr->set_metadata(it->second.first, it->second.second); .def("get_values", &ir::module::get_values, ret::reference)
} .def("set_values", &ir::module::set_values)
}) .def("get_types", &ir::module::get_types, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); .def("set_types", &ir::module::set_types)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t; using eattr = ir::attribute_kind_t;
py::enum_<eattr>(m, "attribute_kind") py::enum_<eattr>(m, "attribute_kind")
@@ -739,13 +742,6 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::basic_block, ir::value>(m, "basic_block") py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference) .def("create", &ir::basic_block::create, ret::reference)
.def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
.def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
ir::basic_block::iterator it = self->get_first_non_phi();
if (it == self->end())
return nullptr;
return *it;
}, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder>(m, "builder", py::dynamic_attr()) py::class_<ir::builder>(m, "builder", py::dynamic_attr())
@@ -756,162 +752,17 @@ void init_triton_ir(py::module &&m) {
.def("br", &ir::builder::create_br, ret::reference) .def("br", &ir::builder::create_br, ret::reference)
.def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference)
// insertion block/point, insert points are represented as (*bb, *instr)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
.def("get_insert_point", [](ir::builder *self) { // constants
ir::basic_block *bb = self->get_insert_block();
ir::basic_block::iterator it = self->get_insert_point();
ir::instruction *instr = it == bb->end() ? nullptr : *it;
return std::make_pair(bb, instr);
}, ret::reference)
.def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
ir::basic_block *bb = pt.first;
ir::instruction *instr = pt.second;
if (instr) {
if (bb != instr->get_parent())
throw std::runtime_error("invalid insertion point, instr not in bb");
self->set_insert_point(instr);
} else {
assert(bb);
self->set_insert_point(bb);
}
})
// Constants
.def("get_int1", &ir::builder::get_int1, ret::reference) .def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) .def("get_int32", &ir::builder::get_int32, ret::reference)
.def("get_uint32", &ir::builder::get_int32, ret::reference) .def("get_int64", &ir::builder::get_int64, ret::reference)
.def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) .def("get_uint32", &ir::builder::get_uint32, ret::reference)
.def("get_uint64", &ir::builder::get_int64, ret::reference) .def("get_uint64", &ir::builder::get_uint64, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference) .def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference) .def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference) .def("get_range", &ir::builder::get_range, ret::reference);
// Types
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
.def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference)
.def("get_float_ty", &ir::builder::get_float_ty, ret::reference)
.def("get_double_ty", &ir::builder::get_double_ty, ret::reference)
// terminator instructions
.def("create_br", &ir::builder::create_br, ret::reference)
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
// Cast instructions
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
.def("create_cast", &ir::builder::create_cast, ret::reference)
.def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
.def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
.def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
.def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
.def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
.def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
.def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
.def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
.def("create_downcast", &ir::builder::create_downcast, ret::reference)
// phi
.def("create_phi", &ir::builder::create_phi, ret::reference)
// Binary instructions
.def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
.def("create_fmul", &ir::builder::create_fmul, ret::reference)
.def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
.def("create_frem", &ir::builder::create_frem, ret::reference)
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
.def("create_mul", &ir::builder::create_mul, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
.def("create_srem", &ir::builder::create_srem, ret::reference)
.def("create_urem", &ir::builder::create_urem, ret::reference)
.def("create_add", &ir::builder::create_add, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sub", &ir::builder::create_sub, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_shl", &ir::builder::create_shl, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
// GEP
.def("create_gep", &ir::builder::create_gep, ret::reference)
// Comparison (int)
.def("create_icmp", &ir::builder::create_icmp, ret::reference)
.def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference)
.def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference)
.def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference)
.def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference)
.def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference)
.def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference)
.def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference)
.def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference)
.def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference)
.def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference)
// Comparison (float)
.def("create_fcmp", &ir::builder::create_fcmp, ret::reference)
.def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference)
.def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
.def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
.def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference)
.def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference)
.def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference)
.def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference)
.def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference)
.def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference)
.def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference)
.def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference)
.def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference)
// Logical
.def("create_and", &ir::builder::create_and, ret::reference)
.def("create_xor", &ir::builder::create_xor, ret::reference)
.def("create_or", &ir::builder::create_or, ret::reference)
// Input/Output
.def("create_load", &ir::builder::create_load, ret::reference)
.def("create_store", &ir::builder::create_store, ret::reference)
.def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
.def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
// Block instruction
.def("create_splat", &ir::builder::create_splat, ret::reference)
.def("create_reshape", &ir::builder::create_reshape, ret::reference)
.def("create_cat", &ir::builder::create_cat, ret::reference)
.def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
// atomic
.def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
.def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
// Built-in instruction
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
.def("create_exp", &ir::builder::create_exp, ret::reference)
.def("create_cos", &ir::builder::create_cos, ret::reference)
.def("create_sin", &ir::builder::create_sin, ret::reference)
.def("create_log", &ir::builder::create_log, ret::reference)
.def("create_dot", &ir::builder::create_dot, ret::reference)
.def("create_trans", &ir::builder::create_trans, ret::reference)
.def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
.def("create_reduce", &ir::builder::create_reduce, ret::reference)
.def("create_select", &ir::builder::create_select, ret::reference)
// Intrinsics
// These have no place in the IR, and hopefully they can be removed at some point
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
.def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
.def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
.def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
.def("create_barrier", &ir::builder::create_barrier, ret::reference)
.def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
.def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
} }
void init_triton(py::module &m) { void init_triton(py::module &m) {
@@ -919,4 +770,5 @@ void init_triton(py::module &m) {
init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_ir(std::move(subm.def_submodule("ir"))); init_triton_ir(std::move(subm.def_submodule("ir")));
init_triton_frontend(std::move(subm.def_submodule("frontend")));
} }

View File

@@ -37,7 +37,7 @@ matmul_data = {
(256, 256, 256): {'float16': 0.027}, (256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158}, (512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466}, (1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.695}, (2048, 2048, 2048): {'float16': 0.680},
(4096, 4096, 4096): {'float16': 0.831}, (4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849}, (8192, 8192, 8192): {'float16': 0.849},
# tall-skinny # tall-skinny

View File

@@ -1,4 +1,5 @@
# flake8: noqa: F821,F841 # flake8: noqa: F821,F841
import copy
import itertools import itertools
import re import re
from typing import Optional, Union from typing import Optional, Union
@@ -584,6 +585,7 @@ def test_f8_f16_roundtrip():
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
f8_output = triton.reinterpret(f8_output_tensor, tl.float8) f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
print(f16.dtype, f8_output.dtype)
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
assert torch.all(f8_tensor == f8_output_tensor) assert torch.all(f8_tensor == f8_output_tensor)
@@ -991,6 +993,27 @@ def test_noop(device='cuda'):
kernel[(1, )](x) kernel[(1, )](x)
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
x = torch.tensor([3.14159], device='cuda')
pgm = kernel[(1, )](value, x)
# Parse out the type of the 'VALUE' parameter from the Triton IR.
triton_ir = pgm.asm['ttir']
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
assert ir_value_type == value_type
@pytest.mark.parametrize( @pytest.mark.parametrize(
"value, overflow", "value, overflow",
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]

View File

@@ -1,5 +1,4 @@
import os import os
import re
import shutil import shutil
import pytest import pytest
@@ -103,30 +102,3 @@ def test_specialize(mode):
for i in [1, 2, 4, 8, 16, 32]: for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512) function[(1,)](x, i, BLOCK=512)
assert counter == target assert counter == target
@pytest.mark.parametrize("value, value_type", [
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
cache_str = None
def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs['key'].split('-')
triton.code_gen.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.code_gen.JITFunction.cache_hook = None
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type

View File

@@ -6,8 +6,7 @@ __version__ = '2.0.0'
# or pybind11 shows `munmap_chunk(): invalid pointer` # or pybind11 shows `munmap_chunk(): invalid pointer`
import torch import torch
# submodules # submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
JITFunction, Config, Autotuner, reinterpret
from . import language from . import language
from . import code_gen from . import code_gen
from . import testing from . import testing

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import ast import ast
import builtins import builtins
import functools import functools
@@ -13,7 +11,7 @@ import tempfile
import textwrap import textwrap
import time import time
import warnings import warnings
from typing import Dict, Optional, Set, Tuple, Union from typing import Dict
import torch import torch
from filelock import FileLock from filelock import FileLock
@@ -24,13 +22,48 @@ from .tools.disasm import extract
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
def get_value(self, name):
# search node.id in local scope
ret = None
if name in self.lscope:
ret = self.lscope[name]
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
# search node.id in builtins
elif name in self.builtins:
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.language.block):
handle = self.module.get_value(name)
return triton.language.block(handle)
return ret
def set_value(self, name, value):
if isinstance(value, _triton.ir.value):
value = triton.language.block(value)
if isinstance(value, triton.language.block):
self.module.set_value(name, value.handle)
self.module.set_type(name, value.handle.type)
self.lscope[name] = value
def is_triton_object(self, value):
return isinstance(value, triton.language.block)
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
def __init__(self, context, prototype, gscope, attributes, constants, kwargs): def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context) self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder) self.module = _triton.ir.module('', self.builder)
self.prototype = prototype self.prototype = prototype
self.gscope = gscope self.gscope = gscope
self.lscope = dict() self.lscope = dict()
self.is_arg_lscope = dict() # name => is_arg: {str: bool}
self.attributes = attributes self.attributes = attributes
self.constants = constants self.constants = constants
self.kwargs = kwargs self.kwargs = kwargs
@@ -44,146 +77,6 @@ class CodeGenerator(ast.NodeVisitor):
'isinstance': isinstance, 'isinstance': isinstance,
'getattr': getattr, 'getattr': getattr,
} }
# SSA-construction
# [name, bb] => triton.language.tensor
self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
# bb => {name => phi}
self.incomplete_phis = {}
self.sealed_blocks: Set[_triton.ir.basic_block] = set()
def get_value(self, name):
''' This function:
1. make sure `name` is defined
2. if `name` is triton.language.tensor, get stored tensor by calling
`self._get_tensor()`
'''
# search node.id in local scope
ret = None
if name in self.lscope:
ret = self.lscope[name]
# search node.id in global scope
elif name in self.gscope:
ret = self.gscope[name]
# search node.id in builtins
elif name in self.builtins:
ret = self.builtins[name]
else:
raise ValueError(f'{name} is not defined')
if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]:
return self._get_tensor(name)
return ret
def set_value(self, name: str,
value: Union[triton.language.tensor, triton.language.constexpr],
is_arg: bool = False) -> None:
''' This function:
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
2. store tensor in self.lvalue
'''
self.lscope[name] = value
# if this value is an argument, we don't need to create phis for it
self.is_arg_lscope[name] = is_arg
if isinstance(value, triton.language.tensor) and not is_arg:
self._set_value(name, self.builder.get_insert_block(), value)
#
# SSA-construction
#
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
if not bb:
bb = self.builder.get_insert_block()
# local value numbering
if (name, bb) in self.lvalues:
return self.lvalues[(name, bb)]
# global value numbering
saved_insert_point = self.builder.get_insert_point()
result = self._get_tensor_recursive(name, bb)
self.builder.set_insert_point(saved_insert_point)
return result
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
preds = bb.get_predecessors()
type = self.lscope[name].type
# some preds haven't been filled, create a phi as a proxy of the value
if bb not in self.sealed_blocks:
result = self._make_phi(type, len(preds), bb)
if bb in self.incomplete_phis:
self.incomplete_phis[bb][name] = result
else:
self.incomplete_phis[bb] = {name: result}
elif len(preds) == 1:
# one predecessor: no phi needed, try get value from pred
result = self._get_tensor(name, preds[0])
else: # multiple preds
assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)'
phi = self._make_phi(type, len(preds), bb)
self._set_value(name, bb, phi)
result = self._add_phi_operands(name, phi)
self._set_value(name, bb, result)
return result
# returns a new phi tensor, which encausulate an ir.phi_node
def _make_phi(self,
type: triton.language.dtype,
num_values: int,
bb: _triton.ir.basic_block) -> triton.language.tensor:
instr = bb.get_first_non_phi()
self.builder.set_insert_point((bb, instr))
ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
if instr:
self.builder.set_insert_block(bb)
return triton.language.tensor(ir_phi, type)
# complete a phi node. (TODO: rename this as _complete_phis?)
# Note: since we try to remove tryival phi, the return tensor might not be a phi
def _add_phi_operands(self, name: str,
phi: triton.language.tensor) -> triton.language.tensor:
bb = phi.handle.get_parent()
for pred in bb.get_predecessors():
v = self._get_tensor(name, pred)
phi.handle.add_incoming(v.handle, pred)
phi = self._try_remove_trivial_phi(phi)
return phi
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
self.lvalues[(name, bb)] = value
# TODO: why we need this?
self.module.set_instr_metadata(name, value.handle)
def _seal_block(self, bb: _triton.ir.basic_block):
# complete all incomplete phis
if bb in self.incomplete_phis:
for name, phi in self.incomplete_phis[bb].items():
result = self._add_phi_operands(name, phi)
# it's possible that this phi is trivial
if self._get_tensor(name, bb).handle == phi.handle:
self._set_value(name, bb, result)
del self.incomplete_phis[bb]
self.sealed_blocks.add(bb)
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
if len(unique_handles) != 1: # non-trivial phi
return phi
v = unique_handles.pop()
phi.handle.replace_all_uses_with(v)
phi.handle.erase_from_parent()
# TODO: remove trivial phis recursively
return triton.language.tensor(v, phi.type)
def is_triton_tensor(self, value):
return isinstance(value, triton.language.tensor)
#
# AST visitor
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
def visit_Module(self, node): def visit_Module(self, node):
ast.NodeVisitor.generic_visit(self, node) ast.NodeVisitor.generic_visit(self, node)
@@ -220,7 +113,7 @@ class CodeGenerator(ast.NodeVisitor):
if inline: if inline:
pass pass
else: else:
fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder)) fn = self.module.get_or_insert_function(node.name, self.prototype)
arg_values = [] arg_values = []
idx = 0 idx = 0
for i, arg_name in enumerate(arg_names): for i, arg_name in enumerate(arg_names):
@@ -237,17 +130,17 @@ class CodeGenerator(ast.NodeVisitor):
attr = _triton.ir.attribute(attr, self.attributes[i]) attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(idx + 1, attr) fn.add_attr(idx + 1, attr)
fn.args[idx].name = arg_name fn.args[idx].name = arg_name
arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) arg_values.append(fn.args[idx])
idx += 1 idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value, is_arg=True) self.set_value(arg_name, arg_value)
if inline: if inline:
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
return self.last_ret return self.last_ret
else: else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self._seal_block(entry) self.module.seal_block(entry)
self.builder.set_insert_block(entry) self.builder.set_insert_block(entry)
# visit function body # visit function body
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
@@ -294,12 +187,11 @@ class CodeGenerator(ast.NodeVisitor):
if not isinstance(values, tuple): if not isinstance(values, tuple):
values = [values] values = [values]
for name, value in zip(names, values): for name, value in zip(names, values):
# TODO: can we store constexpr here to support constant folding?
# by default, constexpr are assigned into python variable # by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr): if isinstance(value, triton.language.constexpr):
value = value.value value = value.value
if not isinstance(value, triton.language.tensor): if not isinstance(value, triton.language.block):
value = triton.language.core._to_tensor(value, self.builder) value = triton.language.core._to_ir(value, self.builder)
self.set_value(name, value) self.set_value(name, value)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
@@ -328,9 +220,9 @@ class CodeGenerator(ast.NodeVisitor):
def visit_BinOp(self, node): def visit_BinOp(self, node):
lhs = self.visit(node.left) lhs = self.visit(node.left)
rhs = self.visit(node.right) rhs = self.visit(node.right)
if isinstance(lhs, triton.language.constexpr): if isinstance(lhs, triton.language.core.constexpr):
lhs = lhs.value lhs = lhs.value
if isinstance(rhs, triton.language.constexpr): if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value rhs = rhs.value
fn = { fn = {
ast.Add: '__add__', ast.Add: '__add__',
@@ -346,9 +238,9 @@ class CodeGenerator(ast.NodeVisitor):
ast.BitOr: '__or__', ast.BitOr: '__or__',
ast.BitXor: '__xor__', ast.BitXor: '__xor__',
}[type(node.op)] }[type(node.op)]
if self.is_triton_tensor(lhs): if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder) return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs): elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:] fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder) return getattr(rhs, fn)(lhs, _builder=self.builder)
else: else:
@@ -356,15 +248,15 @@ class CodeGenerator(ast.NodeVisitor):
def visit_If(self, node): def visit_If(self, node):
cond = self.visit(node.test) cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor): if isinstance(cond, triton.language.block):
cond = cond.to(triton.language.int1, _builder=self.builder) cond = cond.to(triton.language.int1, _builder=self.builder)
current_bb = self.builder.get_insert_block() current_bb = self.builder.get_insert_block()
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
self._seal_block(then_bb) self.module.seal_block(then_bb)
if else_bb: if else_bb:
self._seal_block(else_bb) self.module.seal_block(else_bb)
self.builder.cond_br(cond.handle, then_bb, else_bb) self.builder.cond_br(cond.handle, then_bb, else_bb)
else: else:
self.builder.cond_br(cond.handle, then_bb, endif_bb) self.builder.cond_br(cond.handle, then_bb, endif_bb)
@@ -379,7 +271,7 @@ class CodeGenerator(ast.NodeVisitor):
# TODO: last statement is a terminator? # TODO: last statement is a terminator?
if not is_terminator: if not is_terminator:
self.builder.br(endif_bb) self.builder.br(endif_bb)
self._seal_block(endif_bb) self.module.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb) self.builder.set_insert_block(endif_bb)
else: else:
if isinstance(cond, triton.language.constexpr): if isinstance(cond, triton.language.constexpr):
@@ -404,9 +296,9 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.ops) == 1 assert len(node.ops) == 1
lhs = self.visit(node.left) lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0]) rhs = self.visit(node.comparators[0])
if isinstance(lhs, triton.language.constexpr): if isinstance(lhs, triton.language.core.constexpr):
lhs = lhs.value lhs = lhs.value
if isinstance(rhs, triton.language.constexpr): if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value rhs = rhs.value
if type(node.ops[0]) == ast.Is: if type(node.ops[0]) == ast.Is:
return triton.language.constexpr(lhs is rhs) return triton.language.constexpr(lhs is rhs)
@@ -420,9 +312,9 @@ class CodeGenerator(ast.NodeVisitor):
ast.Gt: '__gt__', ast.Gt: '__gt__',
ast.GtE: '__ge__', ast.GtE: '__ge__',
}[type(node.ops[0])] }[type(node.ops[0])]
if self.is_triton_tensor(lhs): if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder) return getattr(lhs, fn)(rhs, _builder=self.builder)
elif self.is_triton_tensor(rhs): elif self.is_triton_object(rhs):
fn = fn[:2] + 'r' + fn[2:] fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder) return getattr(rhs, fn)(lhs, _builder=self.builder)
else: else:
@@ -433,21 +325,21 @@ class CodeGenerator(ast.NodeVisitor):
if type(node.op) == ast.Not: if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op) return triton.language.constexpr(not op)
if isinstance(op, triton.language.constexpr): if isinstance(op, triton.language.core.constexpr):
op = op.value op = op.value
fn = { fn = {
ast.USub: '__neg__', ast.USub: '__neg__',
ast.UAdd: '__pos__', ast.UAdd: '__pos__',
ast.Invert: '__invert__', ast.Invert: '__invert__',
}[type(node.op)] }[type(node.op)]
if self.is_triton_tensor(op): if self.is_triton_object(op):
return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)() return getattr(op, fn)()
def visit_While(self, node): def visit_While(self, node):
current_bb = self.builder.get_insert_block() current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
def continue_fn(): def continue_fn():
cond = self.visit(node.test) cond = self.visit(node.test)
@@ -458,9 +350,9 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
continue_fn() continue_fn()
stop_bb = self.builder.get_insert_block() stop_bb = self.builder.get_insert_block()
self._seal_block(stop_bb) self.module.seal_block(stop_bb)
self._seal_block(loop_bb) self.module.seal_block(loop_bb)
self._seal_block(next_bb) self.module.seal_block(next_bb)
self.builder.set_insert_block(next_bb) self.builder.set_insert_block(next_bb)
for stmt in node.orelse: for stmt in node.orelse:
@@ -470,7 +362,7 @@ class CodeGenerator(ast.NodeVisitor):
assert node.ctx.__class__.__name__ == "Load" assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value) lhs = self.visit(node.value)
slices = self.visit(node.slice) slices = self.visit(node.slice)
if self.is_triton_tensor(lhs): if self.is_triton_object(lhs):
return lhs.__getitem__(slices, _builder=self.builder) return lhs.__getitem__(slices, _builder=self.builder)
return lhs[slices] return lhs[slices]
@@ -513,8 +405,8 @@ class CodeGenerator(ast.NodeVisitor):
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation # code generation
current_bb = self.builder.get_insert_block() current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
def continue_fn(): def continue_fn():
self.visit(step_node) self.visit(step_node)
@@ -529,9 +421,9 @@ class CodeGenerator(ast.NodeVisitor):
# TODO: handle case where body breaks control flow # TODO: handle case where body breaks control flow
continue_fn() continue_fn()
stop_bb = self.builder.get_insert_block() stop_bb = self.builder.get_insert_block()
self._seal_block(stop_bb) self.module.seal_block(stop_bb)
self._seal_block(loop_bb) self.module.seal_block(loop_bb)
self._seal_block(next_bb) self.module.seal_block(next_bb)
self.builder.set_insert_block(next_bb) self.builder.set_insert_block(next_bb)
for stmt in node.orelse: for stmt in node.orelse:
@@ -559,7 +451,7 @@ class CodeGenerator(ast.NodeVisitor):
args = [self.visit(arg) for arg in node.args] args = [self.visit(arg) for arg in node.args]
if isinstance(fn, JITFunction): if isinstance(fn, JITFunction):
return fn(*args, generator=self, **kws) return fn(*args, generator=self, **kws)
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
sys.modules[fn.__module__] is triton.language.core: sys.modules[fn.__module__] is triton.language.core:
return fn(*args, _builder=self.builder, **kws) return fn(*args, _builder=self.builder, **kws)
if fn in self.builtins.values(): if fn in self.builtins.values():
@@ -699,7 +591,7 @@ class Kernel:
} }
if hasattr(obj, 'data_ptr'): if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype] return type_names[obj.dtype]
if isinstance(obj, triton.language.constexpr): if isinstance(obj, triton.language.core.constexpr):
obj = obj.value obj = obj.value
if isinstance(obj, int): if isinstance(obj, int):
if -2**31 <= obj < 2**31: if -2**31 <= obj < 2**31:
@@ -731,34 +623,34 @@ class Kernel:
return 'scalar', name return 'scalar', name
@staticmethod @staticmethod
def _to_triton_ir(obj): def _to_triton_ir(context, obj):
which, name = obj which, name = obj
type_map = { type_map = {
'I': triton.language.int32, 'I': _triton.ir.type.get_int32,
'L': triton.language.int64, 'L': _triton.ir.type.get_int64,
'f': triton.language.float32, 'f': _triton.ir.type.get_fp32,
'B': triton.language.int1, 'B': _triton.ir.type.get_int1,
'f8': triton.language.float8, 'f8': _triton.ir.type.get_fp8,
'f16': triton.language.float16, 'f16': _triton.ir.type.get_fp16,
'bf16': triton.language.bfloat16, 'bf16': _triton.ir.type.get_bf16,
'f32': triton.language.float32, 'f32': _triton.ir.type.get_fp32,
'f64': triton.language.float64, 'f64': _triton.ir.type.get_fp64,
'i1': triton.language.int1, 'i1': _triton.ir.type.get_int1,
'i8': triton.language.int8, 'i8': _triton.ir.type.get_int8,
'i16': triton.language.int16, 'i16': _triton.ir.type.get_int16,
'i32': triton.language.int32, 'i32': _triton.ir.type.get_int32,
'i64': triton.language.int64, 'i64': _triton.ir.type.get_int64,
'u8': triton.language.uint8, 'u8': _triton.ir.type.get_uint8,
'u16': triton.language.uint16, 'u16': _triton.ir.type.get_uint16,
'u32': triton.language.uint32, 'u32': _triton.ir.type.get_uint32,
'u64': triton.language.uint64, 'u64': _triton.ir.type.get_uint64,
} }
# convert torch.Tensor to Triton IR pointers # convert torch.Tensor to Triton IR pointers
if which == 'ptr': if which == 'ptr':
elt_ty = type_map[name] elt_ty = type_map[name](context)
return triton.language.pointer_type(elt_ty, 1) return _triton.ir.type.make_ptr(elt_ty, 1)
# default path returns triton.ir.type directly # default path returns triton.ir.type directly
return type_map[name] return type_map[name](context)
@staticmethod @staticmethod
def pow2_divisor(N): def pow2_divisor(N):
@@ -1038,31 +930,25 @@ class JITFunction:
assert isinstance(tree.body[0], ast.FunctionDef) assert isinstance(tree.body[0], ast.FunctionDef)
return tree return tree
# Called by CodeGenerator.visit_Call()
def __call__(self, *args, generator: CodeGenerator, **kwargs): def __call__(self, *args, generator: CodeGenerator, **kwargs):
try: try:
from inspect import getcallargs from inspect import getcallargs
arg_values = getcallargs(self.fn, *args, **kwargs) arg_values = getcallargs(self.fn, *args, **kwargs)
arg_values = [arg_values[name] for name in self.arg_names] arg_values = [arg_values[name] for name in self.arg_names]
arg_values = [arg if isinstance(arg, triton.language.tensor) arg_values = [arg if isinstance(arg, triton.language.block)
else triton.language.constexpr(arg) for arg in arg_values] else triton.language.constexpr(arg) for arg in arg_values]
# Record values in the caller (parent scope)
gscope = generator.gscope.copy() gscope = generator.gscope.copy()
lscope = generator.lscope.copy() lscope = generator.lscope.copy()
values = generator.module.get_values().copy()
# TODO: clear values other than args types = generator.module.get_types().copy()
lvalues = generator.lvalues.copy()
# types = generator.module.get_types().copy()
generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.gscope = sys.modules[self.fn.__module__].__dict__
generator.lscope = dict() generator.lscope = dict()
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
generator.gscope = gscope generator.gscope = gscope
generator.lscope = lscope generator.lscope = lscope
generator.module.set_values(values)
generator.lvalues = lvalues generator.module.set_types(types)
# generator.module.set_types(types)
return ret return ret
except Exception as e: except Exception as e:
node = generator.last_node node = generator.last_node
@@ -1147,9 +1033,9 @@ class JITFunction:
# create IR module # create IR module
context = _triton.ir.context() context = _triton.ir.context()
# get just-in-time proto-type of kernel # get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
ret_type = triton.language.void ret_type = _triton.ir.type.get_void(context)
prototype = triton.language.function_type(ret_type, arg_types) prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR # generate Triton-IR
# export symbols visible from self into code-generator object # export symbols visible from self into code-generator object
gscope = self.__globals__ gscope = self.__globals__

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff