This reverts commit 539961072c
.
This commit is contained in:
@@ -38,8 +38,10 @@ public:
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(uint32_t val);
|
||||
value *get_int64(uint64_t val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
value *get_uint32(uint32_t val);
|
||||
value *get_uint64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
@@ -50,9 +52,11 @@ public:
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_fp8_ty();
|
||||
type *get_uint8_ty();
|
||||
type *get_uint16_ty();
|
||||
type *get_uint32_ty();
|
||||
type *get_uint64_ty();
|
||||
type *get_half_ty();
|
||||
type *get_bf16_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
@@ -70,9 +74,7 @@ public:
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value* create_bitcast(value *src, type *dest_ty);
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_int_to_ptr(value *src, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
@@ -91,11 +93,11 @@ public:
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
@@ -143,22 +145,11 @@ public:
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Atomic instruction
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_atomic_max(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umax(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_min(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umin(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_fadd(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_and(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_or(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xor(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xchg(value *ptr, value *val, value *msk);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
|
@@ -26,6 +26,7 @@ public:
|
||||
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||
// Block types
|
||||
|
113
include/triton/ir/dispatch.h
Normal file
113
include/triton/ir/dispatch.h
Normal file
@@ -0,0 +1,113 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_DISPATCH_H_
|
||||
#define _TRITON_IR_DISPATCH_H_
|
||||
|
||||
#include "triton/ir/builder.h"
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/*----------------------------------------------
|
||||
higher level functions that follow the likely
|
||||
semantics of most expected frontends
|
||||
----------------------------------------------*/
|
||||
|
||||
struct semantic_error: public std::runtime_error {
|
||||
semantic_error(const std::string& msg):
|
||||
std::runtime_error(msg) { }
|
||||
};
|
||||
|
||||
struct dispatch{
|
||||
typedef ir::type::block_shapes_t shape_t;
|
||||
|
||||
|
||||
// programming model
|
||||
static ir::value *program_id(int axis, ir::builder *builder);
|
||||
static ir::value *num_programs(int axis, ir::builder *builder);
|
||||
|
||||
// binary operators
|
||||
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
|
||||
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// unary operators
|
||||
static ir::value *plus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *minus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *invert(ir::value *input, ir::builder *builder);
|
||||
|
||||
// comparison operators
|
||||
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// block creation
|
||||
static ir::value* arange(int start, int end, ir::builder *builder);
|
||||
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
|
||||
|
||||
|
||||
// casting ops
|
||||
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
||||
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
|
||||
// memory operators
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
|
||||
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
|
||||
|
||||
// indexing
|
||||
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
|
||||
|
||||
// reduction
|
||||
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
|
||||
|
||||
// math
|
||||
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sin(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *debug_barrier(ir::builder *builder);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -57,10 +57,26 @@ private:
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
|
||||
builder &get_builder() { return builder_; };
|
||||
const std::string& get_name() { return name_; };
|
||||
module(const std::string &name, builder& builder);
|
||||
builder& get_builder();
|
||||
// Setters
|
||||
void set_value(const std::string& name, basic_block* block, value *x);
|
||||
void set_value(const std::string& name, value* x);
|
||||
void set_const(const std::string& name);
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
const std::map<val_key_t, value*>& get_values() { return values_; }
|
||||
const std::map<std::string, type*>& get_types() { return types_; }
|
||||
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
||||
void set_types(const std::map<std::string, type*>& types) { types_ = types; }
|
||||
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
value *get_value(const std::string& name);
|
||||
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
||||
const std::string& get_name();
|
||||
std::function<ir::value*()> get_continue_fn();
|
||||
// Seal block -- no more predecessors will be added
|
||||
void seal_block(basic_block *block);
|
||||
// Functions
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
@@ -73,14 +89,21 @@ public:
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
builder &builder_;
|
||||
builder& builder_;
|
||||
std::map<val_key_t, value*> values_;
|
||||
std::map<std::string, type*> types_;
|
||||
std::set<std::string> const_;
|
||||
std::set<basic_block*> sealed_blocks_;
|
||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::function<ir::value*()> continue_fn_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
|
@@ -16,6 +16,8 @@ class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
enum class signedness { SIGNED, UNSIGNED };
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
@@ -59,6 +61,8 @@ public:
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
signedness get_integer_signedness() const;
|
||||
bool is_integer_signed() const;
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
@@ -81,6 +85,9 @@ public:
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_integer_ty(unsigned bitwidth, signedness sn) {
|
||||
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
|
||||
}
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
@@ -108,6 +115,10 @@ public:
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
static integer_type *get_uint8_ty(context &ctx);
|
||||
static integer_type *get_uint16_ty(context &ctx);
|
||||
static integer_type *get_uint32_ty(context &ctx);
|
||||
static integer_type *get_uint64_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
@@ -134,7 +145,7 @@ public:
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
|
||||
case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
@@ -157,18 +168,21 @@ class integer_type: public type {
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
|
||||
integer_type(context &ctx, unsigned bitwidth, signedness sn)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
signedness get_signedness() const { return signedness_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
signedness signedness_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
|
@@ -48,12 +48,18 @@ void builder::set_insert_point(basic_block *block){
|
||||
value *builder::get_int1(bool val)
|
||||
{ return constant_int::get(type::get_int1_ty(ctx_), val); }
|
||||
|
||||
value *builder::get_int32(uint32_t val)
|
||||
value *builder::get_int32(int32_t val)
|
||||
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_int64(uint64_t val)
|
||||
value *builder::get_uint32(uint32_t val)
|
||||
{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_int64(int64_t val)
|
||||
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_uint64(uint64_t val)
|
||||
{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
|
||||
|
||||
value *builder::get_float16(float val)
|
||||
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
|
||||
|
||||
@@ -84,15 +90,21 @@ type *builder::get_int32_ty()
|
||||
type *builder::get_int64_ty()
|
||||
{ return type::get_int64_ty(ctx_); }
|
||||
|
||||
type *builder::get_fp8_ty()
|
||||
{ return type::get_fp8_ty(ctx_); }
|
||||
type *builder::get_uint8_ty()
|
||||
{ return type::get_uint8_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint16_ty()
|
||||
{ return type::get_uint16_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint32_ty()
|
||||
{ return type::get_uint32_ty(ctx_); }
|
||||
|
||||
type *builder::get_uint64_ty()
|
||||
{ return type::get_uint64_ty(ctx_); }
|
||||
|
||||
type *builder::get_half_ty()
|
||||
{ return type::get_fp16_ty(ctx_); }
|
||||
|
||||
type *builder::get_bf16_ty()
|
||||
{ return type::get_bf16_ty(ctx_); }
|
||||
|
||||
type *builder::get_float_ty()
|
||||
{ return type::get_fp32_ty(ctx_); }
|
||||
|
||||
@@ -127,8 +139,6 @@ value *builder::create_ret_void() {
|
||||
return create_cast(OPCODE, src, dst_ty);\
|
||||
}
|
||||
|
||||
DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
|
||||
DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
|
||||
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
|
||||
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
|
||||
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
|
||||
@@ -321,28 +331,6 @@ value *builder::create_downcast(value *arg) {
|
||||
return insert(downcast_inst::create(arg));
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
||||
#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
|
||||
value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
|
||||
return create_atomic_rmw(OPCODE, ptr, val, mask);\
|
||||
}
|
||||
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
|
||||
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -359,6 +347,9 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
|
||||
return insert(atomic_cas_inst::create(ptr, cmp, val));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
|
||||
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
|
||||
}
|
||||
|
||||
value *builder::create_exp(value *arg){
|
||||
return insert(exp_inst::create(arg));
|
||||
|
@@ -19,12 +19,18 @@ context_impl::context_impl(context &ctx)
|
||||
fp32_ty(ctx, type::FP32TyID),
|
||||
fp64_ty(ctx, type::FP64TyID),
|
||||
// integers
|
||||
int1_ty(ctx, 1),
|
||||
int8_ty(ctx, 8),
|
||||
int16_ty(ctx, 16),
|
||||
int32_ty(ctx, 32),
|
||||
int64_ty(ctx, 64),
|
||||
int128_ty(ctx, 128) {}
|
||||
int1_ty(ctx, 1, signedness::SIGNED),
|
||||
int8_ty(ctx, 8, signedness::SIGNED),
|
||||
int16_ty(ctx, 16, signedness::SIGNED),
|
||||
int32_ty(ctx, 32, signedness::SIGNED),
|
||||
int64_ty(ctx, 64, signedness::SIGNED),
|
||||
int128_ty(ctx, 128, signedness::SIGNED),
|
||||
uint8_ty(ctx, 8, signedness::UNSIGNED),
|
||||
uint16_ty(ctx, 16, signedness::UNSIGNED),
|
||||
uint32_ty(ctx, 32, signedness::UNSIGNED),
|
||||
uint64_ty(ctx, 64, signedness::UNSIGNED){
|
||||
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// context
|
||||
|
882
lib/ir/dispatch.cc
Normal file
882
lib/ir/dispatch.cc
Normal file
@@ -0,0 +1,882 @@
|
||||
#include "triton/ir/dispatch.h"
|
||||
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
|
||||
|
||||
[[ noreturn ]] void throw_unreachable(std::string key) {
|
||||
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
|
||||
"This is likely a bug on our side.");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Programming Model
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
|
||||
return builder->create_get_program_id(axis);
|
||||
}
|
||||
|
||||
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
|
||||
return builder->create_get_num_programs(axis);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implicit Casting Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
|
||||
int a_rank = a_ty->get_integer_bitwidth();
|
||||
int b_rank = b_ty->get_integer_bitwidth();
|
||||
auto a_sn = a_ty->get_integer_signedness();
|
||||
auto b_sn = b_ty->get_integer_signedness();
|
||||
// Rules for signedness taken from "Usual arithmetic conversions" on
|
||||
// https://en.cppreference.com/w/c/language/conversion.
|
||||
if (a_sn == b_sn) {
|
||||
return a_rank > b_rank ? a_ty : b_ty;
|
||||
} else if (a_sn == signedness::UNSIGNED) {
|
||||
return a_rank >= b_rank ? a_ty : b_ty;
|
||||
} else if (b_sn == signedness::UNSIGNED) {
|
||||
return b_rank >= a_rank ? b_ty : a_ty;
|
||||
} else {
|
||||
throw_unreachable("integer_promote");
|
||||
}
|
||||
}
|
||||
|
||||
enum class DivOrMod { NO, YES };
|
||||
|
||||
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
|
||||
context &ctx = a_ty->get_context();
|
||||
// 1) if one operand is double, the other is implicitly
|
||||
// converted to double
|
||||
if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
|
||||
return type::get_fp64_ty(ctx);
|
||||
// 2) if one operand is float, the other is implicitly
|
||||
// converted to float
|
||||
if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
|
||||
return type::get_fp32_ty(ctx);
|
||||
// 3 ) if one operand is half, the other is implicitly converted to half
|
||||
// unless we're doing / or %, which do not exist natively in PTX for fp16.
|
||||
if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
|
||||
if (div_or_mod == DivOrMod::YES) {
|
||||
return type::get_fp32_ty(ctx);
|
||||
} else {
|
||||
return type::get_fp16_ty(ctx);
|
||||
}
|
||||
}
|
||||
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
|
||||
throw_unreachable("computation_type");
|
||||
// 4 ) both operands are integer and undergo
|
||||
// integer promotion
|
||||
if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
|
||||
throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
|
||||
}
|
||||
return integer_promote(a_ty, b_ty);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
|
||||
throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
|
||||
}
|
||||
|
||||
void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
|
||||
|
||||
if(type_a->is_pointer_ty()){
|
||||
if(!allow_ptr_a)
|
||||
throw_incompatible_types(type_a, type_b);
|
||||
// T* + U* with T != U
|
||||
if(type_b->is_pointer_ty() && (type_a != type_b))
|
||||
throw_incompatible_types(type_a, type_b);
|
||||
// T* + float
|
||||
if(type_b->is_floating_point_ty())
|
||||
throw_incompatible_types(type_a, type_b);
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
|
||||
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
|
||||
bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
|
||||
// implicit broadcasting
|
||||
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
|
||||
// implicit typecasting
|
||||
ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
|
||||
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
|
||||
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
|
||||
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
|
||||
if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
|
||||
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
|
||||
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
|
||||
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
|
||||
}
|
||||
}
|
||||
|
||||
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, true, true);
|
||||
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
// offset + ptr
|
||||
// ptr + offset
|
||||
if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty())
|
||||
std::swap(input, other);
|
||||
if (input_scalar_ty->is_pointer_ty())
|
||||
return builder->create_gep(input, {other});
|
||||
// float + float
|
||||
else if (input_scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fadd(input, other);
|
||||
// int + int
|
||||
else if (input_scalar_ty->is_integer_ty())
|
||||
return builder->create_add(input, other);
|
||||
throw_unreachable("add");
|
||||
}
|
||||
|
||||
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, true, false);
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// ptr - offset
|
||||
if (scalar_ty->is_pointer_ty())
|
||||
return builder->create_gep(input, {dispatch::minus(other, builder)});
|
||||
// float + float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fsub(input, other);
|
||||
// int + int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_sub(input, other);
|
||||
throw_unreachable("sub");
|
||||
}
|
||||
|
||||
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder);
|
||||
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
|
||||
// float * float
|
||||
if (scalar_ty->is_floating_point_ty())
|
||||
return builder->create_fmul(input, other);
|
||||
// int * int
|
||||
else if (scalar_ty->is_integer_ty())
|
||||
return builder->create_mul(input, other);
|
||||
throw_unreachable("mul");
|
||||
}
|
||||
|
||||
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
// float / int
|
||||
if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
|
||||
other = cast(other, input_scalar_ty, builder);
|
||||
// int / float
|
||||
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
|
||||
input = cast(input, other_scalar_ty, builder);
|
||||
// int / int (cast to float32)
|
||||
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
|
||||
input = cast(input, builder->get_float_ty(), builder);
|
||||
other = cast(other, builder->get_float_ty(), builder);
|
||||
}
|
||||
// float / float (cast to highest exponent type)
|
||||
else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
|
||||
if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
|
||||
other = cast(other, input_scalar_ty, builder);
|
||||
else
|
||||
input = cast(input, other_scalar_ty, builder);
|
||||
}
|
||||
// unreachable
|
||||
else
|
||||
throw_unreachable("div");
|
||||
return builder->create_fdiv(input, other);
|
||||
}
|
||||
|
||||
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
|
||||
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
|
||||
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
|
||||
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
|
||||
input = dispatch::cast(input, ret_ty, builder);
|
||||
other = dispatch::cast(other, ret_ty, builder);
|
||||
if (ret_ty->is_integer_signed()) {
|
||||
return builder->create_sdiv(input, other);
|
||||
} else {
|
||||
return builder->create_udiv(input, other);
|
||||
}
|
||||
}
|
||||
throw_unreachable("floordiv");
|
||||
}
|
||||
|
||||
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
|
||||
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
|
||||
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
|
||||
if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
|
||||
throw semantic_error("both operands of fdiv must have floating point scalar type");
|
||||
binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
|
||||
ir::value* ret = builder->create_fdiv(input, other);
|
||||
if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret))
|
||||
binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Remainder dispatch (`%`). Floating-point operands lower to frem;
// integer operands must agree in signedness and lower to srem/urem.
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  // float % float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_frem(input, other);
  // int % int
  else if (scalar_ty->is_integer_ty()) {
    // signed%unsigned (or vice versa) rarely means what the user intends, so reject it
    if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
      throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
    }
    if (scalar_ty->is_integer_signed()) {
      return builder->create_srem(input, other);
    } else {
      return builder->create_urem(input, other);
    }
  }
  throw_unreachable("mod");
}
|
||||
|
||||
|
||||
// Shared validation for bitwise binary ops: both operands must be integers;
// they are promoted in place to a common integer type.
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, false);
  ir::type *lhs_sca_ty = input->get_type()->get_scalar_ty();
  ir::type *rhs_sca_ty = other->get_type()->get_scalar_ty();
  bool both_int = lhs_sca_ty->is_integer_ty() && rhs_sca_ty->is_integer_ty();
  if(!both_int)
    throw_incompatible_types(lhs_sca_ty, rhs_sca_ty);
  ir::type *promoted_ty = integer_promote(lhs_sca_ty, rhs_sca_ty);
  if (lhs_sca_ty != promoted_ty)
    input = dispatch::cast(input, promoted_ty, builder);
  if (rhs_sca_ty != promoted_ty)
    other = dispatch::cast(other, promoted_ty, builder);
}
|
||||
|
||||
// Bitwise AND after integer type checking/promotion.
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  return builder->create_and(input, other);
}
|
||||
|
||||
// Bitwise OR after integer type checking/promotion.
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  return builder->create_or(input, other);
}
|
||||
|
||||
|
||||
// Bitwise XOR after integer type checking/promotion.
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  return builder->create_xor(input, other);
}
|
||||
|
||||
|
||||
// Logical (zero-filling) right shift after integer type checking/promotion.
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  return builder->create_lshr(input, other);
}
|
||||
|
||||
|
||||
// Left shift after integer type checking/promotion.
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  return builder->create_shl(input, other);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Unary plus is the identity: the operand is returned unchanged.
ir::value *dispatch::plus(ir::value *input, ir::builder *) {
  return input;
}
|
||||
|
||||
// Unary minus, implemented as 0 - input. Pointers are rejected.
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  if(sca_ty->is_pointer_ty())
    throw semantic_error("wrong type argument to unary minus (" + sca_ty->repr() + ")");
  ir::value *zero = ir::constant::get_null_value(sca_ty);
  return dispatch::sub(zero, input, builder);
}
|
||||
|
||||
// Bitwise NOT, implemented as input ^ all-ones. Pointers and floats are rejected.
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  bool invalid = sca_ty->is_pointer_ty() || sca_ty->is_floating_point_ty();
  if(invalid)
    throw semantic_error("wrong type argument to unary invert (" + sca_ty->repr() + ")");
  ir::value *all_ones = ir::constant::get_all_ones_value(sca_ty);
  return dispatch::xor_(input, all_ones, builder);
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Comparison Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Elementwise `>`: ordered fcmp for floats, signedness-aware icmp for ints.
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  // floating-point comparison (ordered)
  if (sca_ty->is_floating_point_ty())
    return builder->create_fcmpOGT(input, other);
  // integer comparison
  if (sca_ty->is_integer_ty())
    return sca_ty->is_integer_signed() ? builder->create_icmpSGT(input, other)
                                       : builder->create_icmpUGT(input, other);
  throw_unreachable("greater_than");
}
|
||||
|
||||
// Elementwise `>=`: ordered fcmp for floats, signedness-aware icmp for ints.
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  // floating-point comparison (ordered)
  if (sca_ty->is_floating_point_ty())
    return builder->create_fcmpOGE(input, other);
  // integer comparison
  if (sca_ty->is_integer_ty())
    return sca_ty->is_integer_signed() ? builder->create_icmpSGE(input, other)
                                       : builder->create_icmpUGE(input, other);
  throw_unreachable("greater_equal");
}
|
||||
|
||||
// Elementwise `<`: ordered fcmp for floats, signedness-aware icmp for ints.
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  // floating-point comparison (ordered)
  if (sca_ty->is_floating_point_ty())
    return builder->create_fcmpOLT(input, other);
  // integer comparison
  if (sca_ty->is_integer_ty())
    return sca_ty->is_integer_signed() ? builder->create_icmpSLT(input, other)
                                       : builder->create_icmpULT(input, other);
  throw_unreachable("less_than");
}
|
||||
|
||||
// Elementwise `<=`: ordered fcmp for floats, signedness-aware icmp for ints.
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  // floating-point comparison (ordered)
  if (sca_ty->is_floating_point_ty())
    return builder->create_fcmpOLE(input, other);
  // integer comparison
  if (sca_ty->is_integer_ty())
    return sca_ty->is_integer_signed() ? builder->create_icmpSLE(input, other)
                                       : builder->create_icmpULE(input, other);
  throw_unreachable("less_equal");
}
|
||||
|
||||
// Elementwise `==`: ordered fcmp for floats, plain icmp for ints.
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *sca_ty = input->get_type()->get_scalar_ty();
  if (sca_ty->is_floating_point_ty())
    return builder->create_fcmpOEQ(input, other);
  if (sca_ty->is_integer_ty())
    return builder->create_icmpEQ(input, other);
  throw_unreachable("equal");
}
|
||||
|
||||
// Elementwise `!=`. Floats use an UNORDERED compare (UNE) so a NaN operand
// yields true; integers use a plain icmp NE.
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  // float != float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpUNE(input, other);
  // int != int
  else if (scalar_ty->is_integer_ty())
    return builder->create_icmpNE(input, other);
  // BUGFIX: previously reported "equal", pointing diagnostics at the wrong function
  throw_unreachable("not_equal");
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Block Creation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// 1D block of consecutive integers in [start, end).
ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
  return builder->get_range(start, end);
}
|
||||
|
||||
// Block of the given shape filled with the zero value of `dtype`.
ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
  ir::value *_0 = ir::constant::get_null_value(dtype);
  return builder->create_splat(_0, shape);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shape Manipulation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// Reshape `input` to `dst_shape`; the total element count must be unchanged.
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
  unsigned dst_numel = 1;
  for(unsigned dim: dst_shape)
    dst_numel *= dim;
  if(input->get_type()->get_tile_num_elements() != dst_numel)
    throw semantic_error("cannot reshape block of different shape");
  return builder->create_reshape(input, dst_shape);
}
|
||||
|
||||
// Concatenate two blocks into one.
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  return builder->create_cat(lhs, rhs);
}
|
||||
|
||||
// Broadcast `input` to `shape`. Scalars are splatted; blocks must have the
// same rank and are broadcast unless already of the target shape.
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
  if (!input->get_type()->is_block_ty())
    return builder->create_splat(input, shape);
  auto src_shape = input->get_type()->get_block_shapes();
  if (src_shape.size() != shape.size())
    throw std::runtime_error("Cannot broadcast");
  if (src_shape == shape)
    return input;
  return builder->create_broadcast(input, shape);
}
|
||||
|
||||
// Mutually broadcast `lhs` and `rhs` to a common shape (numpy-style):
// a scalar is splatted to the other operand's shape; two blocks must have
// equal rank, and a size-1 dimension stretches to match the other operand.
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
  ir::type *lhs_ty = lhs->get_type();
  ir::type *rhs_ty = rhs->get_type();
  bool lhs_block = lhs_ty->is_block_ty();
  bool rhs_block = rhs_ty->is_block_ty();
  // make_shape_compatible(block, scalar)
  if (lhs_block && !rhs_block) {
    rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
  }
  // make_shape_compatible(scalar, block)
  else if (!lhs_block && rhs_block) {
    lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
  }
  // make_shape_compatible(block, block)
  else if (lhs_block && rhs_block) {
    auto lhs_shape = lhs_ty->get_block_shapes();
    auto rhs_shape = rhs_ty->get_block_shapes();
    if (lhs_shape.size() != rhs_shape.size())
      throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
    ir::type::block_shapes_t ret_shape;
    for (size_t i = 0; i < lhs_shape.size(); ++i) {
      unsigned left = lhs_shape[i];
      unsigned right = rhs_shape[i];
      if (left == right || right == 1)
        ret_shape.push_back(left);
      else if (left == 1)
        ret_shape.push_back(right);
      else
        throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
                                 ": " + std::to_string(left) + " and " + std::to_string(right));
    }
    if (lhs_shape != ret_shape)
      lhs = builder->create_broadcast(lhs, ret_shape);
    if (rhs_shape != ret_shape)
      rhs = builder->create_broadcast(rhs, ret_shape);
  }
  return std::make_tuple(lhs, rhs);
}
|
||||
|
||||
// Reinterpret `input`'s bits as `dst_ty` without changing the bit pattern.
// Pointer casts are delegated to dispatch::cast; otherwise the source and
// destination scalar types must have identical primitive bit widths.
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
  ir::type *src_ty = input->get_type();
  // preserve the block shape when bitcasting a block value
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
  if(src_ty == dst_ty)
    return input;
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
  // pointer conversions have their own rules; defer to the general cast
  if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
    return cast(input, dst_ty, builder);
  // Bitcast requires equal bit widths
  int src_bits = src_sca_ty->get_primitive_size_in_bits();
  int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
  if( src_bits!= dst_bits)
    // BUGFIX: the message was missing the space before "to", producing e.g. "size 32to data-type"
    throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
                             " to data-type of size " + std::to_string(dst_bits));
  return builder->create_cast(ir::BitCast, input, dst_ty);
}
|
||||
|
||||
// General-purpose cast. Handles bf16 round-trips through fp32, fp width
// changes, integer width/signedness changes, fp<->int, ptr<->int, ptr<->ptr,
// and casts to bool. Block inputs keep their shape.
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  // a block input makes the destination a block of the same shape
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
  if(src_ty == dst_ty)
    return input;
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
  // bf16 converts directly only to/from fp32: route every other bf16
  // conversion through an intermediate fp32 cast
  if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
     (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
    return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
  }
  // FP Truncation: destination has fewer mantissa bits
  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
                     dst_sca_ty->is_floating_point_ty() &&
                     src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
  if (truncate_fp)
    return builder->create_fp_trunc(input, dst_ty);
  // FP Extension: destination has more mantissa bits
  bool ext_fp = src_sca_ty->is_floating_point_ty() &&
                dst_sca_ty->is_floating_point_ty() &&
                src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
  if (ext_fp)
    return builder->create_fp_ext(input, dst_ty);
  // Int cast: width or signedness differs
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
      (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
       src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
    // int1 is never sign-extended, even though it may be classified as signed
    bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
    return builder->create_int_cast(input, dst_ty, sign_extend);
  }
  // Float -> Int
  if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
    // fp -> bool goes through the unsigned conversion
    if(dst_sca_ty->is_bool_ty())
      return builder->create_fp_to_ui(input, dst_ty);
    else
      return builder->create_fp_to_si(input, dst_ty);
  }
  // int -> Float: bool and unsigned sources use the unsigned conversion
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
    if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
      return builder->create_ui_to_fp(input, dst_ty);
    else
      return builder->create_si_to_fp(input, dst_ty);
  }
  // Ptr -> Int: only a full 64-bit address or a bool (null test) is supported
  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){
    int bitwidth = dst_sca_ty->get_integer_bitwidth();
    if(bitwidth == 64)
      return builder->create_cast(ir::PtrToInt, input, dst_ty);
    if(bitwidth == 1)
      return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder),
                                 builder->get_int64(0),
                                 builder);
  }
  // non-Ptr -> Ptr
  if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
    return builder->create_cast(ir::IntToPtr, input, dst_ty);
  // Ptr -> Ptr
  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
    return builder->create_cast(ir::BitCast, input, dst_ty);
  // * -> Bool: compare the value (pointers first cast to int64) against zero
  if (dst_sca_ty->is_bool_ty()) {
    if (src_sca_ty->is_pointer_ty())
      input = cast(input, builder->get_int64_ty(), builder);
    ir::value *other = builder->get_int64(0);
    // NOTE(review): the splat is guarded by is_bool_ty() but then reads
    // get_block_shapes(); is_block_ty() looks like the intended check — confirm
    if (src_ty->is_bool_ty())
      other = builder->create_splat(other, src_ty->get_block_shapes());
    return builder->create_icmpNE(input, other);
  }
  throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Memory Operators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Load from `ptr`, optionally predicated by `mask` with fallback `other`.
// `cache_modifier` selects the PTX cache operator (".ca"/".cg"),
// `eviction_policy` selects "evict_last"/"evict_first". bool* pointers are
// rewritten to int8* since bools are stored as bytes.
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
    throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
  // broadcast mask/other to the pointer's block shape
  if(ptr->get_type()->is_block_ty()){
    if(mask)
      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
    if(other)
      other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
  }
  // the fallback value must match the loaded element type
  if(other)
    other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
  // treat bool* as int8*
  if(elt_ty == builder->get_int1_ty()){
    elt_ty = builder->get_int8_ty();
    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
    ptr = dispatch::cast(ptr, ptr_ty, builder);
  }
  // cache modifier
  load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
  if (!cache_modifier.empty()) {
    if (cache_modifier == ".ca")
      cache = load_inst::CA;
    else if (cache_modifier == ".cg")
      cache = load_inst::CG;
    else
      throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
  }
  // eviction policy
  load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
  if(!eviction_policy.empty()){
    if (eviction_policy == "evict_last")
      eviction = load_inst::EVICT_LAST;
    else if(eviction_policy == "evict_first")
      eviction = load_inst::EVICT_FIRST;
    else
      // BUGFIX: added the missing space after "policy" (was "Eviction policyevict_...")
      throw std::runtime_error(std::string("Eviction policy ") + eviction_policy + " not supported");
  }
  if (!mask && !other)
    return builder->create_load(ptr, cache, eviction, is_volatile);
  if (!mask)
    throw std::runtime_error("`other` cannot be provided without `mask`");
  // (removed an unused local that fetched the block shapes but never used them)
  // a masked load with no fallback uses an undef (splatted for blocks)
  if(!other){
    other = ir::undef_value::get(elt_ty);
    if(ptr->get_type()->is_block_ty())
      other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
  }
  return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
}
|
||||
|
||||
// Store `val` to `ptr`, optionally predicated by a boolean `mask`.
// bool* pointers are rewritten to int8* since bools are stored as bytes.
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
  if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
    throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
  // broadcast value/mask to the pointer's block shape
  if(ptr->get_type()->is_block_ty())
    val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
  if(mask)
    mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
  // treat bool* as int8*
  if(elt_ty == builder->get_int1_ty()){
    elt_ty = builder->get_int8_ty();
    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
    ptr = dispatch::cast(ptr, ptr_ty, builder);
  }
  // cast to target data-type
  val = dispatch::cast(val, elt_ty, builder);
  if (!mask)
    return builder->create_store(ptr, val);
  if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
    throw semantic_error("Mask must have boolean scalar type");
  return builder->create_masked_store(ptr, val, mask);
}
|
||||
|
||||
// Atomic compare-and-swap: stores `val` at `ptr` if the current value equals `cmp`.
ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
  return builder->create_atomic_cas(ptr, cmp, val);
}
|
||||
|
||||
// Shared validation for atomic RMW ops: `ptr` must be a pointer; `val` and
// `mask` are broadcast to the pointer's block shape; `val` is cast to the
// pointee type; a missing mask is replaced by an all-true one.
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
  ir::type *ptr_full_ty = ptr->get_type();
  if(!ptr_full_ty->get_scalar_ty()->is_pointer_ty())
    throw semantic_error("Pointer argument of store instruction is " + ptr_full_ty->repr());
  if(ptr_full_ty->is_block_ty()){
    if(mask)
      mask = dispatch::broadcast(mask, ptr_full_ty->get_block_shapes(), builder);
    if(val)
      val = dispatch::broadcast(val, ptr_full_ty->get_block_shapes(), builder);
  }
  val = dispatch::cast(val, ptr_full_ty->get_scalar_ty()->get_pointer_element_ty(), builder);
  if(mask)
    return;
  // default mask: every lane enabled
  mask = builder->get_int1(true);
  if(ptr_full_ty->is_block_ty())
    mask = builder->create_splat(mask, ptr_full_ty->get_block_shapes());
}
|
||||
|
||||
// Atomic max. Integers map directly to signed/unsigned atomic max.
// Floats have no native atomic max, so the value is reinterpreted as int32:
// for non-negative floats the int bit pattern is order-preserving (atomic
// signed max), for negative floats the order is reversed (atomic unsigned min).
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type* sca_ty = val->get_type()->get_scalar_ty();
  // direct call to atomic_max for integers
  if(sca_ty->is_integer_ty()) {
    if (sca_ty->is_integer_signed()) {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
    } else {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
    }
  }
  // for float
  // return atomic_smax(i_ptr, i_val) if val >= 0
  // return atomic_umin(i_ptr, i_val) if val < 0
  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
  // each lane takes exactly one of the two paths, selected by its sign
  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
  return where(pos, pos_ret, neg_ret, builder);
}
|
||||
|
||||
// Atomic min. Integers map directly to signed/unsigned atomic min.
// Floats mirror atomic_max: the int32 bit pattern is order-preserving for
// non-negative floats (atomic signed min) and order-reversing for negative
// ones (atomic unsigned max).
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type* sca_ty = val->get_type()->get_scalar_ty();
  // direct call to atomic_min for integers
  if(sca_ty->is_integer_ty()) {
    if (sca_ty->is_integer_signed()) {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
    } else {
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
    }
  }
  // for float
  // return atomic_smin(i_ptr, i_val) if val >= 0
  // return atomic_umax(i_ptr, i_val) if val < 0
  ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
  ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
  ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
  ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
  // each lane takes exactly one of the two paths, selected by its sign
  ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
  ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
  return where(pos, pos_ret, neg_ret, builder);
}
|
||||
|
||||
// Atomic add: FAdd for floating-point element types, integer Add otherwise.
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type* sca_ty = val->get_type()->get_scalar_ty();
  ir::atomic_rmw_op_t op = ir::atomic_rmw_op_t::Add;
  if (sca_ty->is_floating_point_ty())
    op = ir::atomic_rmw_op_t::FAdd;
  return builder->create_atomic_rmw(op, ptr, val, mask);
}
|
||||
|
||||
// Atomic bitwise AND.
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
}
|
||||
|
||||
// Atomic bitwise OR.
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
}
|
||||
|
||||
// Atomic bitwise XOR.
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
}
|
||||
|
||||
// Atomic exchange: stores `val` at `ptr` and yields the previous value.
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
  atom_red_typechecking(ptr, val, mask, builder);
  // (removed an unused local that fetched the scalar type but never read it)
  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linear Algebra
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Matrix product lhs[M,K] @ rhs[K,N] with a zero-initialized [M,N]
// accumulator whose element type matches the operands (int32 or fp32).
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
  ir::value *acc = lhs->get_type()->is_int_or_tileint_ty()
                     ? builder->get_int32(0)
                     : builder->get_float32(0);
  unsigned M = lhs->get_type()->get_block_shapes()[0];
  unsigned N = rhs->get_type()->get_block_shapes()[1];
  acc = builder->create_splat(acc, {M, N});
  bool _allow_tf32 = allow_tf32->get_value() != 0;
  return builder->create_dot(lhs, rhs, acc, _allow_tf32);
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Indexing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Ternary select where(cond, x, y). The condition is cast to int1; for a
// block condition, x and y are broadcast to its shape, then both branches
// are promoted to a common computation type.
ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
  condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
  if(condition->get_type()->is_block_ty()){
    auto cond_shape = condition->get_type()->get_block_shapes();
    x = dispatch::broadcast(x, cond_shape, builder);
    y = dispatch::broadcast(y, cond_shape, builder);
  }
  ir::type* common_ty = computation_type(x->get_type()->get_scalar_ty(),
                                         y->get_type()->get_scalar_ty(),
                                         DivOrMod::NO);
  x = dispatch::cast(x, common_ty, builder);
  y = dispatch::cast(y, common_ty, builder);
  return builder->create_select(condition, x, y);
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reductions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Common helper for the reduction wrappers: widens small integer inputs to
// int32, then emits a reduce instruction with the float or int op.
// `name` is only used for the unreachable diagnostic.
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
                       ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  // input is extended to 32-bits if necessary
  // this increases numerical accuracy and can be done pretty much for free
  // on GPUs
  if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
    input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
  if (scalar_ty->is_floating_point_ty())
    return builder->create_reduce(input, FLOAT_OP, axis);
  else if (scalar_ty->is_integer_ty())
    return builder->create_reduce(input, INT_OP, axis);
  throw_unreachable(name);
}
|
||||
|
||||
// Min-reduction over `axis` (FMIN for floats, MIN for integers).
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}
|
||||
|
||||
// Max-reduction over `axis` (FMAX for floats, MAX for integers).
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}
|
||||
|
||||
// Sum-reduction over `axis` (FADD for floats, ADD for integers).
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
}
|
||||
|
||||
// XOR-reduction over `axis`. Only defined for integer inputs.
ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  if (!scalar_ty->is_integer_ty())
    throw semantic_error("xor_sum only supported for integers");
  // BUGFIX: report "xor_sum" (was "sum") if reduce_impl hits its unreachable path
  return reduce_impl(input, axis, builder, "xor_sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// High half of the full-width unsigned product of x and y.
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
  binary_op_type_checking(x, y, builder);
  return builder->insert(umulhi_inst::create(x, y));
}
|
||||
|
||||
// Elementwise exponential.
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
  return builder->create_exp(x);
}
|
||||
|
||||
// Elementwise natural logarithm.
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
  return builder->create_log(x);
}
|
||||
|
||||
// Elementwise cosine.
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
  return builder->create_cos(x);
}
|
||||
|
||||
// Elementwise sine.
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
  return builder->create_sin(x);
}
|
||||
|
||||
// Elementwise square root.
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
  return builder->create_sqrt(x);
}
|
||||
|
||||
|
||||
//
|
||||
|
||||
// Attach "multiple_of" metadata to an instruction-valued `x` (compiler hint
// that its value is a multiple of `value`). Non-instruction values are rejected.
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
  auto *instr = dynamic_cast<ir::instruction*>(x);
  if(instr == nullptr)
    throw_unreachable("multiple_of");
  instr->set_metadata(ir::metadata::multiple_of, value);
  return instr;
}
|
||||
|
||||
// Attach "max_contiguous" metadata to an instruction-valued `x` (compiler
// hint about contiguity). Non-instruction values are rejected.
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
  auto *instr = dynamic_cast<ir::instruction*>(x);
  if(instr == nullptr)
    throw_unreachable("max_contiguous");
  instr->set_metadata(ir::metadata::max_contiguous, value);
  return instr;
}
|
||||
|
||||
// Insert a barrier instruction.
ir::value *dispatch::debug_barrier(ir::builder *builder) {
  return builder->create_barrier();
}
|
||||
|
||||
|
||||
}
|
||||
}
|
140
lib/ir/module.cc
140
lib/ir/module.cc
@@ -9,6 +9,146 @@
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
/* Module */
|
||||
// Construct a module bound to `builder`. The nullptr block is pre-inserted
// into the sealed set so lookups with no enclosing block are treated as sealed.
module::module(const std::string &name, builder &builder)
  : name_(name), builder_(builder) {
  sealed_blocks_.insert(nullptr);
}
|
||||
|
||||
// Accessor for the IR builder this module was constructed with.
ir::builder& module::get_builder() {
  return builder_;
}
|
||||
|
||||
// Bind `value` to `name` within `block`, and propagate any metadata
// registered under that name onto the value if it is an instruction.
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
  values_[val_key_t{name, block}] = value;
  auto it = metadatas_.find(name);
  if(auto *x = dynamic_cast<ir::instruction*>(value))
    if(it != metadatas_.end()){
      x->set_metadata(it->second.first, it->second.second);
    }
  // value->set_name(name);
}
|
||||
|
||||
// Convenience overload: bind `value` in the builder's current insert block.
void module::set_value(const std::string& name, ir::value *value){
  return set_value(name, builder_.get_insert_block(), value);
}
|
||||
|
||||
// Mark `name` as a constant (read by get_value_recursive to skip phi creation).
void module::set_const(const std::string& name){
  const_.insert(name);
}
|
||||
|
||||
// Install the callback used to generate "continue" control flow.
void module::set_continue_fn(std::function<ir::value*()> fn) {
  continue_fn_ = fn;
}
|
||||
|
||||
// Retrieve the currently installed "continue" callback.
std::function<ir::value*()> module::get_continue_fn() {
  return continue_fn_;
}
|
||||
|
||||
// Create a phi node in `block`. Phis must precede all other instructions,
// so the builder's insertion point is temporarily moved to the first
// non-phi instruction (if the block has one) and restored afterwards.
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
  basic_block::iterator insert = block->get_first_non_phi();
  bool moved = insert != block->end();
  if(moved)
    builder_.set_insert_point(insert);
  ir::phi_node *res = builder_.create_phi(ty, num_values);
  if(moved)
    builder_.set_insert_point(block);
  return res;
}
|
||||
|
||||
// If `phi` has exactly one distinct incoming value besides itself, it is
// trivial: replace every use with that value, remove the phi, and
// recursively re-check phi users that may have become trivial in turn.
// Returns the replacement value, or the phi itself when non-trivial.
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
  // find non-self references
  std::set<ir::value*> non_self_ref;
  std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
               [phi](ir::value* op){ return op != phi && op; });
  // non-trivial
  if(non_self_ref.size() != 1)
    return phi;
  // unique value or self-reference
  ir::value *same = *non_self_ref.begin();
  assert(same != nullptr);
  phi->replace_all_uses_with(same);
  phi->erase_from_parent();
  // NOTE(review): get_users() is read after erase_from_parent(); this relies
  // on the phi object staying valid after removal from its block — confirm
  std::set<ir::user*> users = phi->get_users();
  for(ir::user* u: users)
    if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
      if(uphi != phi)
        try_remove_trivial_phis(uphi);
  return same;
}
|
||||
|
||||
|
||||
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
||||
// already initialized
|
||||
if(phi->get_num_operands())
|
||||
return phi;
|
||||
ir::basic_block *block = phi->get_parent();
|
||||
for(ir::basic_block *pred: block->get_predecessors()){
|
||||
ir::value *value = get_value(name, pred);
|
||||
phi->add_incoming(value, pred);
|
||||
}
|
||||
return phi;
|
||||
}
|
||||
|
||||
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *result;
|
||||
bool is_const = const_.find(name) != const_.end();
|
||||
auto &preds = block->get_predecessors();
|
||||
ir::type *ty = types_.at(name);
|
||||
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||
result = (ir::value*)incomplete_phis_[block][name];
|
||||
}
|
||||
else if(preds.size() <= 1){
|
||||
bool has_pred = preds.size();
|
||||
result = get_value(name, has_pred?preds.front():nullptr);
|
||||
}
|
||||
else{
|
||||
ir::phi_node* phi = make_phi(ty, 1, block);
|
||||
set_value(name, block, phi);
|
||||
result = add_phi_operands(name, phi);
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
|
||||
result = try_remove_trivial_phis(phi);
|
||||
}
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
|
||||
result = try_remove_trivial_phis(phi);
|
||||
}
|
||||
set_value(name, block, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
||||
ir::basic_block* save_block = builder_.get_insert_block();
|
||||
ir::basic_block::iterator save_pt = builder_.get_insert_point();
|
||||
val_key_t key(name, block);
|
||||
if(values_.find(key) != values_.end()){
|
||||
return values_.at(key);
|
||||
}
|
||||
ir::value *result = get_value_recursive(name, block);
|
||||
builder_.set_insert_point(save_block);
|
||||
if(save_pt != save_block->end())
|
||||
builder_.set_insert_point(save_pt);
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name) {
|
||||
return get_value(name, builder_.get_insert_block());
|
||||
}
|
||||
|
||||
const std::string& module::get_name() {
|
||||
return name_;
|
||||
}
|
||||
|
||||
void module::seal_block(ir::basic_block *block){
|
||||
for(auto &x: incomplete_phis_[block]){
|
||||
add_phi_operands(x.first, x.second);
|
||||
if(get_value(x.first) == x.second)
|
||||
set_value(x.first, try_remove_trivial_phis(x.second));
|
||||
}
|
||||
sealed_blocks_.insert(block);
|
||||
incomplete_phis_[block].clear();
|
||||
}
|
||||
|
||||
/* functions */
|
||||
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
||||
function *&fn = (function*&)symbols_[name];
|
||||
|
@@ -36,6 +36,16 @@ unsigned type::get_primitive_size_in_bits() const {
|
||||
unsigned type::get_integer_bitwidth() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
|
||||
|
||||
signedness type::get_integer_signedness() const
|
||||
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
|
||||
|
||||
bool type::is_integer_signed() const {
|
||||
if (id_ != IntegerTyID) {
|
||||
throw std::logic_error("type is " + repr() + ", not integer");
|
||||
}
|
||||
return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
|
||||
}
|
||||
|
||||
unsigned type::get_tile_bitwidth() const
|
||||
{ return ((block_type*)(this))->get_bitwidth(); }
|
||||
|
||||
@@ -135,6 +145,10 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
|
||||
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
|
||||
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
|
||||
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
|
||||
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
|
||||
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
|
||||
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
|
||||
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }
|
||||
|
||||
|
||||
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/dispatch.h"
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -11,12 +12,10 @@
|
||||
#include <pybind11/buffer_info.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include "Python.h"
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
@@ -542,6 +541,84 @@ void init_triton_codegen(py::module &&m) {
|
||||
}, py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* User-facing language features */
|
||||
/*****************************************************************************/
|
||||
|
||||
void init_triton_frontend(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
|
||||
// programming model
|
||||
m.def("program_id", &ir::dispatch::program_id, ret::reference);
|
||||
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
|
||||
// binary
|
||||
m.def("add", &ir::dispatch::add, ret::reference);
|
||||
m.def("sub", &ir::dispatch::sub, ret::reference);
|
||||
m.def("mul", &ir::dispatch::mul, ret::reference);
|
||||
m.def("truediv", &ir::dispatch::truediv, ret::reference);
|
||||
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
|
||||
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
|
||||
m.def("mod", &ir::dispatch::mod, ret::reference);
|
||||
m.def("and_", &ir::dispatch::and_, ret::reference);
|
||||
m.def("or_", &ir::dispatch::or_, ret::reference);
|
||||
m.def("xor_", &ir::dispatch::xor_, ret::reference);
|
||||
m.def("lshr", &ir::dispatch::lshr, ret::reference);
|
||||
m.def("shl", &ir::dispatch::shl, ret::reference);
|
||||
// unary
|
||||
m.def("plus", &ir::dispatch::plus, ret::reference);
|
||||
m.def("minus", &ir::dispatch::minus, ret::reference);
|
||||
m.def("invert", &ir::dispatch::invert, ret::reference);
|
||||
// comparison
|
||||
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
|
||||
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
|
||||
m.def("less_than", &ir::dispatch::less_than, ret::reference);
|
||||
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
|
||||
m.def("equal", &ir::dispatch::equal, ret::reference);
|
||||
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
|
||||
// block creation
|
||||
m.def("arange", &ir::dispatch::arange, ret::reference);
|
||||
m.def("zeros", &ir::dispatch::zeros, ret::reference);
|
||||
// type manipuatation
|
||||
m.def("cat", &ir::dispatch::cat, ret::reference);
|
||||
m.def("reshape", &ir::dispatch::reshape, ret::reference);
|
||||
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
|
||||
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
|
||||
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
|
||||
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
|
||||
m.def("cast", &ir::dispatch::cast, ret::reference);
|
||||
// memory
|
||||
m.def("load", &ir::dispatch::load, ret::reference);
|
||||
m.def("store", &ir::dispatch::store, ret::reference);
|
||||
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
|
||||
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
|
||||
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
|
||||
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
|
||||
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
|
||||
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
|
||||
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
|
||||
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
|
||||
// linear algebra
|
||||
m.def("dot", &ir::dispatch::dot, ret::reference);
|
||||
// indexing
|
||||
m.def("where", &ir::dispatch::where, ret::reference);
|
||||
// reduction
|
||||
m.def("min", &ir::dispatch::min, ret::reference);
|
||||
m.def("max", &ir::dispatch::max, ret::reference);
|
||||
m.def("sum", &ir::dispatch::sum, ret::reference);
|
||||
m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference);
|
||||
// math
|
||||
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
|
||||
m.def("exp", &ir::dispatch::exp, ret::reference);
|
||||
m.def("log", &ir::dispatch::log, ret::reference);
|
||||
m.def("cos", &ir::dispatch::cos, ret::reference);
|
||||
m.def("sin", &ir::dispatch::sin, ret::reference);
|
||||
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
|
||||
// internal (debugging only)
|
||||
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
|
||||
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
|
||||
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
|
||||
}
|
||||
|
||||
/*****************************************************************************/
|
||||
/* Python bindings for triton::ir */
|
||||
@@ -551,86 +628,16 @@ void init_triton_ir(py::module &&m) {
|
||||
using ret = py::return_value_policy;
|
||||
using namespace pybind11::literals;
|
||||
|
||||
py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
|
||||
.value("NONE", ir::load_inst::NONE)
|
||||
.value("CA", ir::load_inst::CA)
|
||||
.value("CG", ir::load_inst::CG)
|
||||
.export_values();
|
||||
|
||||
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
|
||||
.value("NORMAL", ir::load_inst::NORMAL)
|
||||
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
|
||||
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
|
||||
.export_values();
|
||||
|
||||
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
|
||||
.value("ADD", ir::reduce_inst::ADD)
|
||||
.value("FADD", ir::reduce_inst::FADD)
|
||||
.value("MIN", ir::reduce_inst::MIN)
|
||||
.value("MAX", ir::reduce_inst::MAX)
|
||||
.value("FMIN", ir::reduce_inst::FMIN)
|
||||
.value("FMAX", ir::reduce_inst::FMAX)
|
||||
.value("XOR", ir::reduce_inst::XOR);
|
||||
|
||||
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
|
||||
.value("ADD", ir::atomic_rmw_op_t::Add)
|
||||
.value("FADD", ir::atomic_rmw_op_t::FAdd)
|
||||
.value("AND", ir::atomic_rmw_op_t::And)
|
||||
.value("OR", ir::atomic_rmw_op_t::Or)
|
||||
.value("XOR", ir::atomic_rmw_op_t::Xor)
|
||||
.value("XCHG", ir::atomic_rmw_op_t::Xchg)
|
||||
.value("MAX", ir::atomic_rmw_op_t::Max)
|
||||
.value("MIN", ir::atomic_rmw_op_t::Min)
|
||||
.value("UMIN", ir::atomic_rmw_op_t::UMin)
|
||||
.value("UMAX", ir::atomic_rmw_op_t::UMax);
|
||||
|
||||
py::class_<ir::context>(m, "context")
|
||||
.def(py::init<>());
|
||||
|
||||
py::class_<ir::value>(m, "value")
|
||||
.def("multiple_of", [](ir::value *self, int val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::multiple_of, val);
|
||||
} else
|
||||
throw std::runtime_error("multiple_of");
|
||||
})
|
||||
.def("max_contiguous", [](ir::value *self, int val) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
instr->set_metadata(ir::metadata::max_contiguous, val);
|
||||
} else
|
||||
throw std::runtime_error("max_contiguous");
|
||||
})
|
||||
.def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
|
||||
if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
|
||||
instr->set_fdiv_ieee_rounding(val);
|
||||
else
|
||||
throw std::runtime_error("set_fdiv_ieee_rounding");
|
||||
})
|
||||
.def("is_phi", [](ir::value *self) {
|
||||
if (auto *pn = dynamic_cast<ir::phi_node*>(self))
|
||||
return true;
|
||||
return false;
|
||||
})
|
||||
.def("ops", [](ir::value *self) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
|
||||
return instr->ops();
|
||||
}
|
||||
throw std::runtime_error("cannot use ops()");
|
||||
})
|
||||
.def("replace_all_uses_with", &ir::value::replace_all_uses_with)
|
||||
.def("erase_from_parent", [](ir::value *self) {
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(self))
|
||||
return instr->erase_from_parent();
|
||||
throw std::runtime_error("cannot use erase_from_parent");
|
||||
})
|
||||
.def_property("name", &ir::value::get_name, &ir::value::set_name)
|
||||
.def_property_readonly("type", &ir::value::get_type);
|
||||
auto value = py::class_<ir::value>(m, "value");
|
||||
value.def_property("name", &ir::value::get_name, &ir::value::set_name);
|
||||
value.def_property_readonly("type", &ir::value::get_type);
|
||||
|
||||
py::class_<ir::user, ir::value>(m, "user");
|
||||
|
||||
py::class_<ir::constant, ir::user>(m, "constant")
|
||||
.def("get_null_value", &ir::constant::get_null_value, ret::reference)
|
||||
.def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);
|
||||
py::class_<ir::constant, ir::user>(m, "constant");
|
||||
|
||||
py::class_<ir::undef_value, ir::constant>(m, "undef")
|
||||
.def("get", &ir::undef_value::get, ret::reference);
|
||||
@@ -641,17 +648,16 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("__bool__", [](ir::constant_int *self) { return self->get_value(); });
|
||||
|
||||
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value)
|
||||
.def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value);
|
||||
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction")
|
||||
.def("get_parent", [](ir::instruction *self) {
|
||||
return self->get_parent();
|
||||
}, ret::reference);
|
||||
py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
|
||||
.def("add_incoming", &ir::phi_node::add_incoming);
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction");
|
||||
py::class_<ir::phi_node, ir::user>(m, "phi_node");
|
||||
|
||||
py::class_<ir::type>(m, "type")
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("make_ptr", &ir::pointer_type::get, ret::reference)
|
||||
.def("make_function", &ir::function_type::get, ret::reference)
|
||||
.def("make_block", &ir::block_type::get, ret::reference)
|
||||
@@ -666,38 +672,34 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_int16", &ir::type::get_int16_ty, ret::reference)
|
||||
.def("get_int32", &ir::type::get_int32_ty, ret::reference)
|
||||
.def("get_int64", &ir::type::get_int64_ty, ret::reference)
|
||||
.def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference)
|
||||
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference)
|
||||
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
|
||||
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
|
||||
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
|
||||
|
||||
.def("get_block_shapes", &ir::type::get_block_shapes)
|
||||
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
.def("is_floating", &ir::type::is_floating_point_ty)
|
||||
.def("is_block", &ir::type::is_block_ty)
|
||||
.def("is_void", &ir::type::is_void_ty)
|
||||
.def("is_bool", &ir::type::is_bool_ty)
|
||||
.def("is_fp8", &ir::type::is_fp8_ty)
|
||||
.def("is_fp16", &ir::type::is_fp16_ty)
|
||||
.def("is_bf16", &ir::type::is_bf16_ty)
|
||||
.def("is_fp32", &ir::type::is_fp32_ty)
|
||||
.def("is_fp64", &ir::type::is_fp64_ty)
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
|
||||
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
|
||||
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
|
||||
.def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty)
|
||||
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); })
|
||||
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); })
|
||||
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); })
|
||||
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); })
|
||||
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); })
|
||||
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
|
||||
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
|
||||
|
||||
.def("repr", &ir::type::repr)
|
||||
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
|
||||
.def_property_readonly("scalar", &ir::type::get_scalar_ty)
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference)
|
||||
.def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
|
||||
.def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);
|
||||
.def_property_readonly("context", &ir::type::get_context, ret::reference);
|
||||
|
||||
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference)
|
||||
.def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);
|
||||
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
|
||||
|
||||
py::class_<ir::function_type, ir::type>(m, "function_type");
|
||||
py::class_<ir::integer_type, ir::type>(m, "integer_type");
|
||||
@@ -707,15 +709,16 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::module>(m, "module")
|
||||
.def(py::init<std::string, ir::builder &>())
|
||||
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
|
||||
const auto metadatas = self->get_metadatas();
|
||||
auto it = metadatas.find(name);
|
||||
if (it != metadatas.end())
|
||||
if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
|
||||
instr->set_metadata(it->second.first, it->second.second);
|
||||
}
|
||||
})
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference);
|
||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||
.def("seal_block", &ir::module::seal_block)
|
||||
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
|
||||
.def("set_type", &ir::module::set_type)
|
||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
||||
.def("get_values", &ir::module::get_values, ret::reference)
|
||||
.def("set_values", &ir::module::set_values)
|
||||
.def("get_types", &ir::module::get_types, ret::reference)
|
||||
.def("set_types", &ir::module::set_types)
|
||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||
|
||||
using eattr = ir::attribute_kind_t;
|
||||
py::enum_<eattr>(m, "attribute_kind")
|
||||
@@ -739,13 +742,6 @@ void init_triton_ir(py::module &&m) {
|
||||
|
||||
py::class_<ir::basic_block, ir::value>(m, "basic_block")
|
||||
.def("create", &ir::basic_block::create, ret::reference)
|
||||
.def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
|
||||
.def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
|
||||
ir::basic_block::iterator it = self->get_first_non_phi();
|
||||
if (it == self->end())
|
||||
return nullptr;
|
||||
return *it;
|
||||
}, ret::reference)
|
||||
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
||||
|
||||
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
|
||||
@@ -756,162 +752,17 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("br", &ir::builder::create_br, ret::reference)
|
||||
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// insertion block/point, insert points are represented as (*bb, *instr)
|
||||
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
|
||||
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
|
||||
.def("get_insert_point", [](ir::builder *self) {
|
||||
ir::basic_block *bb = self->get_insert_block();
|
||||
ir::basic_block::iterator it = self->get_insert_point();
|
||||
ir::instruction *instr = it == bb->end() ? nullptr : *it;
|
||||
return std::make_pair(bb, instr);
|
||||
}, ret::reference)
|
||||
.def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
|
||||
ir::basic_block *bb = pt.first;
|
||||
ir::instruction *instr = pt.second;
|
||||
if (instr) {
|
||||
if (bb != instr->get_parent())
|
||||
throw std::runtime_error("invalid insertion point, instr not in bb");
|
||||
self->set_insert_point(instr);
|
||||
} else {
|
||||
assert(bb);
|
||||
self->set_insert_point(bb);
|
||||
}
|
||||
})
|
||||
// Constants
|
||||
// constants
|
||||
.def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
.def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
|
||||
.def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||
.def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||
.def("get_uint64", &ir::builder::get_int64, ret::reference)
|
||||
.def("get_int32", &ir::builder::get_int32, ret::reference)
|
||||
.def("get_int64", &ir::builder::get_int64, ret::reference)
|
||||
.def("get_uint32", &ir::builder::get_uint32, ret::reference)
|
||||
.def("get_uint64", &ir::builder::get_uint64, ret::reference)
|
||||
.def("get_float16", &ir::builder::get_float16, ret::reference)
|
||||
.def("get_float32", &ir::builder::get_float32, ret::reference)
|
||||
.def("get_range", &ir::builder::get_range, ret::reference)
|
||||
// Types
|
||||
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
|
||||
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
|
||||
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
|
||||
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
|
||||
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
|
||||
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
|
||||
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
|
||||
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
|
||||
.def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference)
|
||||
.def("get_float_ty", &ir::builder::get_float_ty, ret::reference)
|
||||
.def("get_double_ty", &ir::builder::get_double_ty, ret::reference)
|
||||
// terminator instructions
|
||||
.def("create_br", &ir::builder::create_br, ret::reference)
|
||||
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
|
||||
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
|
||||
// Cast instructions
|
||||
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
|
||||
.def("create_cast", &ir::builder::create_cast, ret::reference)
|
||||
.def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
|
||||
.def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
|
||||
.def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
|
||||
.def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
|
||||
.def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
|
||||
.def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
|
||||
.def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
|
||||
.def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
|
||||
.def("create_downcast", &ir::builder::create_downcast, ret::reference)
|
||||
// phi
|
||||
.def("create_phi", &ir::builder::create_phi, ret::reference)
|
||||
// Binary instructions
|
||||
.def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
|
||||
.def("create_fmul", &ir::builder::create_fmul, ret::reference)
|
||||
.def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
|
||||
.def("create_frem", &ir::builder::create_frem, ret::reference)
|
||||
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
|
||||
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
|
||||
.def("create_mul", &ir::builder::create_mul, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
|
||||
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
|
||||
.def("create_srem", &ir::builder::create_srem, ret::reference)
|
||||
.def("create_urem", &ir::builder::create_urem, ret::reference)
|
||||
.def("create_add", &ir::builder::create_add, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_sub", &ir::builder::create_sub, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_shl", &ir::builder::create_shl, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
|
||||
py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
// GEP
|
||||
.def("create_gep", &ir::builder::create_gep, ret::reference)
|
||||
// Comparison (int)
|
||||
.def("create_icmp", &ir::builder::create_icmp, ret::reference)
|
||||
.def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference)
|
||||
.def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference)
|
||||
.def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference)
|
||||
.def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference)
|
||||
.def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference)
|
||||
.def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference)
|
||||
.def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference)
|
||||
.def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference)
|
||||
.def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference)
|
||||
.def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference)
|
||||
// Comparison (float)
|
||||
.def("create_fcmp", &ir::builder::create_fcmp, ret::reference)
|
||||
.def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference)
|
||||
.def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
|
||||
.def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
|
||||
.def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference)
|
||||
.def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference)
|
||||
.def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference)
|
||||
.def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference)
|
||||
.def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference)
|
||||
.def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference)
|
||||
.def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference)
|
||||
.def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference)
|
||||
.def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference)
|
||||
// Logical
|
||||
.def("create_and", &ir::builder::create_and, ret::reference)
|
||||
.def("create_xor", &ir::builder::create_xor, ret::reference)
|
||||
.def("create_or", &ir::builder::create_or, ret::reference)
|
||||
// Input/Output
|
||||
.def("create_load", &ir::builder::create_load, ret::reference)
|
||||
.def("create_store", &ir::builder::create_store, ret::reference)
|
||||
.def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
|
||||
.def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
|
||||
// Block instruction
|
||||
.def("create_splat", &ir::builder::create_splat, ret::reference)
|
||||
.def("create_reshape", &ir::builder::create_reshape, ret::reference)
|
||||
.def("create_cat", &ir::builder::create_cat, ret::reference)
|
||||
.def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
|
||||
// atomic
|
||||
.def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
|
||||
.def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
|
||||
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
|
||||
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
|
||||
.def("create_exp", &ir::builder::create_exp, ret::reference)
|
||||
.def("create_cos", &ir::builder::create_cos, ret::reference)
|
||||
.def("create_sin", &ir::builder::create_sin, ret::reference)
|
||||
.def("create_log", &ir::builder::create_log, ret::reference)
|
||||
.def("create_dot", &ir::builder::create_dot, ret::reference)
|
||||
.def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
.def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
.def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
.def("create_select", &ir::builder::create_select, ret::reference)
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
.def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
|
||||
.def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
|
||||
.def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
|
||||
.def("create_barrier", &ir::builder::create_barrier, ret::reference)
|
||||
.def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
|
||||
.def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
|
||||
.def("get_range", &ir::builder::get_range, ret::reference);
|
||||
}
|
||||
|
||||
void init_triton(py::module &m) {
|
||||
@@ -919,4 +770,5 @@ void init_triton(py::module &m) {
|
||||
init_triton_codegen(std::move(subm.def_submodule("code_gen")));
|
||||
init_triton_runtime(std::move(subm.def_submodule("runtime")));
|
||||
init_triton_ir(std::move(subm.def_submodule("ir")));
|
||||
init_triton_frontend(std::move(subm.def_submodule("frontend")));
|
||||
}
|
||||
|
@@ -37,7 +37,7 @@ matmul_data = {
|
||||
(256, 256, 256): {'float16': 0.027},
|
||||
(512, 512, 512): {'float16': 0.158},
|
||||
(1024, 1024, 1024): {'float16': 0.466},
|
||||
(2048, 2048, 2048): {'float16': 0.695},
|
||||
(2048, 2048, 2048): {'float16': 0.680},
|
||||
(4096, 4096, 4096): {'float16': 0.831},
|
||||
(8192, 8192, 8192): {'float16': 0.849},
|
||||
# tall-skinny
|
||||
|
@@ -1,4 +1,5 @@
|
||||
# flake8: noqa: F821,F841
|
||||
import copy
|
||||
import itertools
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
@@ -584,6 +585,7 @@ def test_f8_f16_roundtrip():
|
||||
|
||||
f8_output_tensor = torch.empty_like(f16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
print(f16.dtype, f8_output.dtype)
|
||||
copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
@@ -991,6 +993,27 @@ def test_noop(device='cuda'):
|
||||
kernel[(1, )](x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'),
|
||||
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
|
||||
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
pass
|
||||
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
pgm = kernel[(1, )](value, x)
|
||||
|
||||
# Parse out the type of the 'VALUE' parameter from the Triton IR.
|
||||
triton_ir = pgm.asm['ttir']
|
||||
ir_value_match = re.match(r'\s*def void kernel\((\w+) VALUE ', triton_ir)
|
||||
ir_value_type = None if ir_value_match is None else ir_value_match.group(1)
|
||||
assert ir_value_type == value_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value, overflow",
|
||||
[(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]
|
||||
|
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
@@ -103,30 +102,3 @@ def test_specialize(mode):
|
||||
for i in [1, 2, 4, 8, 16, 32]:
|
||||
function[(1,)](x, i, BLOCK=512)
|
||||
assert counter == target
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value, value_type", [
|
||||
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
|
||||
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
|
||||
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
|
||||
])
|
||||
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
|
||||
|
||||
@triton.jit
|
||||
def kernel(VALUE, X):
|
||||
pass
|
||||
|
||||
cache_str = None
|
||||
|
||||
def get_cache_str(*args, **kwargs):
|
||||
nonlocal cache_str
|
||||
cache_str = kwargs['key'].split('-')
|
||||
triton.code_gen.JITFunction.cache_hook = get_cache_str
|
||||
reset_tmp_dir()
|
||||
x = torch.tensor([3.14159], device='cuda')
|
||||
kernel[(1, )](value, x)
|
||||
triton.code_gen.JITFunction.cache_hook = None
|
||||
|
||||
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
|
||||
spec_type = None if cache_str_match is None else cache_str_match.group(1)
|
||||
assert spec_type == value_type
|
||||
|
@@ -6,8 +6,7 @@ __version__ = '2.0.0'
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch
|
||||
# submodules
|
||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
|
||||
JITFunction, Config, Autotuner, reinterpret
|
||||
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret
|
||||
from . import language
|
||||
from . import code_gen
|
||||
from . import testing
|
||||
|
@@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import builtins
|
||||
import functools
|
||||
@@ -13,7 +11,7 @@ import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from typing import Dict, Optional, Set, Tuple, Union
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from filelock import FileLock
|
||||
@@ -24,13 +22,48 @@ from .tools.disasm import extract
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def get_value(self, name):
|
||||
# search node.id in local scope
|
||||
ret = None
|
||||
if name in self.lscope:
|
||||
ret = self.lscope[name]
|
||||
# search node.id in global scope
|
||||
elif name in self.gscope:
|
||||
ret = self.gscope[name]
|
||||
# search node.id in builtins
|
||||
elif name in self.builtins:
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if isinstance(ret, triton.language.block):
|
||||
handle = self.module.get_value(name)
|
||||
return triton.language.block(handle)
|
||||
return ret
|
||||
|
||||
def set_value(self, name, value):
|
||||
if isinstance(value, _triton.ir.value):
|
||||
value = triton.language.block(value)
|
||||
if isinstance(value, triton.language.block):
|
||||
self.module.set_value(name, value.handle)
|
||||
self.module.set_type(name, value.handle.type)
|
||||
self.lscope[name] = value
|
||||
|
||||
def is_triton_object(self, value):
|
||||
return isinstance(value, triton.language.block)
|
||||
|
||||
def visit_compound_statement(self, stmts):
|
||||
for stmt in stmts:
|
||||
self.last_ret = self.visit(stmt)
|
||||
if isinstance(stmt, ast.Return):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
self.module = _triton.ir.module('', self.builder)
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.is_arg_lscope = dict() # name => is_arg: {str: bool}
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.kwargs = kwargs
|
||||
@@ -44,146 +77,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
'isinstance': isinstance,
|
||||
'getattr': getattr,
|
||||
}
|
||||
# SSA-construction
|
||||
# [name, bb] => triton.language.tensor
|
||||
self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
|
||||
# bb => {name => phi}
|
||||
self.incomplete_phis = {}
|
||||
self.sealed_blocks: Set[_triton.ir.basic_block] = set()
|
||||
|
||||
def get_value(self, name):
|
||||
''' This function:
|
||||
1. make sure `name` is defined
|
||||
2. if `name` is triton.language.tensor, get stored tensor by calling
|
||||
`self._get_tensor()`
|
||||
'''
|
||||
# search node.id in local scope
|
||||
ret = None
|
||||
if name in self.lscope:
|
||||
ret = self.lscope[name]
|
||||
# search node.id in global scope
|
||||
elif name in self.gscope:
|
||||
ret = self.gscope[name]
|
||||
# search node.id in builtins
|
||||
elif name in self.builtins:
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]:
|
||||
return self._get_tensor(name)
|
||||
return ret
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[triton.language.tensor, triton.language.constexpr],
|
||||
is_arg: bool = False) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
2. store tensor in self.lvalue
|
||||
'''
|
||||
self.lscope[name] = value
|
||||
# if this value is an argument, we don't need to create phis for it
|
||||
self.is_arg_lscope[name] = is_arg
|
||||
if isinstance(value, triton.language.tensor) and not is_arg:
|
||||
self._set_value(name, self.builder.get_insert_block(), value)
|
||||
|
||||
#
|
||||
# SSA-construction
|
||||
#
|
||||
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
|
||||
if not bb:
|
||||
bb = self.builder.get_insert_block()
|
||||
# local value numbering
|
||||
if (name, bb) in self.lvalues:
|
||||
return self.lvalues[(name, bb)]
|
||||
# global value numbering
|
||||
saved_insert_point = self.builder.get_insert_point()
|
||||
result = self._get_tensor_recursive(name, bb)
|
||||
self.builder.set_insert_point(saved_insert_point)
|
||||
return result
|
||||
|
||||
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
preds = bb.get_predecessors()
|
||||
type = self.lscope[name].type
|
||||
# some preds haven't been filled, create a phi as a proxy of the value
|
||||
if bb not in self.sealed_blocks:
|
||||
result = self._make_phi(type, len(preds), bb)
|
||||
if bb in self.incomplete_phis:
|
||||
self.incomplete_phis[bb][name] = result
|
||||
else:
|
||||
self.incomplete_phis[bb] = {name: result}
|
||||
elif len(preds) == 1:
|
||||
# one predecessor: no phi needed, try get value from pred
|
||||
result = self._get_tensor(name, preds[0])
|
||||
else: # multiple preds
|
||||
assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)'
|
||||
phi = self._make_phi(type, len(preds), bb)
|
||||
self._set_value(name, bb, phi)
|
||||
result = self._add_phi_operands(name, phi)
|
||||
self._set_value(name, bb, result)
|
||||
return result
|
||||
|
||||
# returns a new phi tensor, which encausulate an ir.phi_node
|
||||
def _make_phi(self,
|
||||
type: triton.language.dtype,
|
||||
num_values: int,
|
||||
bb: _triton.ir.basic_block) -> triton.language.tensor:
|
||||
instr = bb.get_first_non_phi()
|
||||
self.builder.set_insert_point((bb, instr))
|
||||
ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
|
||||
if instr:
|
||||
self.builder.set_insert_block(bb)
|
||||
return triton.language.tensor(ir_phi, type)
|
||||
|
||||
# complete a phi node. (TODO: rename this as _complete_phis?)
|
||||
# Note: since we try to remove tryival phi, the return tensor might not be a phi
|
||||
def _add_phi_operands(self, name: str,
|
||||
phi: triton.language.tensor) -> triton.language.tensor:
|
||||
bb = phi.handle.get_parent()
|
||||
for pred in bb.get_predecessors():
|
||||
v = self._get_tensor(name, pred)
|
||||
phi.handle.add_incoming(v.handle, pred)
|
||||
phi = self._try_remove_trivial_phi(phi)
|
||||
return phi
|
||||
|
||||
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
|
||||
self.lvalues[(name, bb)] = value
|
||||
# TODO: why we need this?
|
||||
self.module.set_instr_metadata(name, value.handle)
|
||||
|
||||
def _seal_block(self, bb: _triton.ir.basic_block):
|
||||
# complete all incomplete phis
|
||||
if bb in self.incomplete_phis:
|
||||
for name, phi in self.incomplete_phis[bb].items():
|
||||
result = self._add_phi_operands(name, phi)
|
||||
# it's possible that this phi is trivial
|
||||
if self._get_tensor(name, bb).handle == phi.handle:
|
||||
self._set_value(name, bb, result)
|
||||
del self.incomplete_phis[bb]
|
||||
self.sealed_blocks.add(bb)
|
||||
|
||||
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
|
||||
unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
|
||||
if len(unique_handles) != 1: # non-trivial phi
|
||||
return phi
|
||||
v = unique_handles.pop()
|
||||
phi.handle.replace_all_uses_with(v)
|
||||
phi.handle.erase_from_parent()
|
||||
# TODO: remove trivial phis recursively
|
||||
return triton.language.tensor(v, phi.type)
|
||||
|
||||
def is_triton_tensor(self, value):
|
||||
return isinstance(value, triton.language.tensor)
|
||||
|
||||
#
|
||||
# AST visitor
|
||||
#
|
||||
def visit_compound_statement(self, stmts):
|
||||
for stmt in stmts:
|
||||
self.last_ret = self.visit(stmt)
|
||||
if isinstance(stmt, ast.Return):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
|
||||
def visit_Module(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@@ -220,7 +113,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if inline:
|
||||
pass
|
||||
else:
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder))
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
@@ -237,17 +130,17 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
||||
arg_values.append(fn.args[idx])
|
||||
idx += 1
|
||||
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value, is_arg=True)
|
||||
self.set_value(arg_name, arg_value)
|
||||
if inline:
|
||||
self.visit_compound_statement(node.body)
|
||||
return self.last_ret
|
||||
else:
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
self._seal_block(entry)
|
||||
self.module.seal_block(entry)
|
||||
self.builder.set_insert_block(entry)
|
||||
# visit function body
|
||||
self.visit_compound_statement(node.body)
|
||||
@@ -294,12 +187,11 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not isinstance(values, tuple):
|
||||
values = [values]
|
||||
for name, value in zip(names, values):
|
||||
# TODO: can we store constexpr here to support constant folding?
|
||||
# by default, constexpr are assigned into python variable
|
||||
if isinstance(value, triton.language.constexpr):
|
||||
value = value.value
|
||||
if not isinstance(value, triton.language.tensor):
|
||||
value = triton.language.core._to_tensor(value, self.builder)
|
||||
if not isinstance(value, triton.language.block):
|
||||
value = triton.language.core._to_ir(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
@@ -328,9 +220,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_BinOp(self, node):
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
if isinstance(lhs, triton.language.core.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
rhs = rhs.value
|
||||
fn = {
|
||||
ast.Add: '__add__',
|
||||
@@ -346,9 +238,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.BitOr: '__or__',
|
||||
ast.BitXor: '__xor__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_tensor(lhs):
|
||||
if self.is_triton_object(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
elif self.is_triton_object(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
@@ -356,15 +248,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_If(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
if isinstance(cond, triton.language.block):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
current_bb = self.builder.get_insert_block()
|
||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
self._seal_block(then_bb)
|
||||
self.module.seal_block(then_bb)
|
||||
if else_bb:
|
||||
self._seal_block(else_bb)
|
||||
self.module.seal_block(else_bb)
|
||||
self.builder.cond_br(cond.handle, then_bb, else_bb)
|
||||
else:
|
||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
@@ -379,7 +271,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: last statement is a terminator?
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
self._seal_block(endif_bb)
|
||||
self.module.seal_block(endif_bb)
|
||||
self.builder.set_insert_block(endif_bb)
|
||||
else:
|
||||
if isinstance(cond, triton.language.constexpr):
|
||||
@@ -404,9 +296,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(node.ops) == 1
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.comparators[0])
|
||||
if isinstance(lhs, triton.language.constexpr):
|
||||
if isinstance(lhs, triton.language.core.constexpr):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.constexpr):
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
rhs = rhs.value
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return triton.language.constexpr(lhs is rhs)
|
||||
@@ -420,9 +312,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.Gt: '__gt__',
|
||||
ast.GtE: '__ge__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_tensor(lhs):
|
||||
if self.is_triton_object(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
elif self.is_triton_tensor(rhs):
|
||||
elif self.is_triton_object(rhs):
|
||||
fn = fn[:2] + 'r' + fn[2:]
|
||||
return getattr(rhs, fn)(lhs, _builder=self.builder)
|
||||
else:
|
||||
@@ -433,21 +325,21 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if type(node.op) == ast.Not:
|
||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.constexpr):
|
||||
if isinstance(op, triton.language.core.constexpr):
|
||||
op = op.value
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
ast.Invert: '__invert__',
|
||||
}[type(node.op)]
|
||||
if self.is_triton_tensor(op):
|
||||
if self.is_triton_object(op):
|
||||
return getattr(op, fn)(_builder=self.builder)
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
def continue_fn():
|
||||
cond = self.visit(node.test)
|
||||
@@ -458,9 +350,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
self.module.seal_block(stop_bb)
|
||||
self.module.seal_block(loop_bb)
|
||||
self.module.seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -470,7 +362,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert node.ctx.__class__.__name__ == "Load"
|
||||
lhs = self.visit(node.value)
|
||||
slices = self.visit(node.slice)
|
||||
if self.is_triton_tensor(lhs):
|
||||
if self.is_triton_object(lhs):
|
||||
return lhs.__getitem__(slices, _builder=self.builder)
|
||||
return lhs[slices]
|
||||
|
||||
@@ -513,8 +405,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
def continue_fn():
|
||||
self.visit(step_node)
|
||||
@@ -529,9 +421,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# TODO: handle case where body breaks control flow
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
self.module.seal_block(stop_bb)
|
||||
self.module.seal_block(loop_bb)
|
||||
self.module.seal_block(next_bb)
|
||||
self.builder.set_insert_block(next_bb)
|
||||
|
||||
for stmt in node.orelse:
|
||||
@@ -559,7 +451,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
if isinstance(fn, JITFunction):
|
||||
return fn(*args, generator=self, **kws)
|
||||
if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
|
||||
if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
|
||||
sys.modules[fn.__module__] is triton.language.core:
|
||||
return fn(*args, _builder=self.builder, **kws)
|
||||
if fn in self.builtins.values():
|
||||
@@ -699,7 +591,7 @@ class Kernel:
|
||||
}
|
||||
if hasattr(obj, 'data_ptr'):
|
||||
return type_names[obj.dtype]
|
||||
if isinstance(obj, triton.language.constexpr):
|
||||
if isinstance(obj, triton.language.core.constexpr):
|
||||
obj = obj.value
|
||||
if isinstance(obj, int):
|
||||
if -2**31 <= obj < 2**31:
|
||||
@@ -731,34 +623,34 @@ class Kernel:
|
||||
return 'scalar', name
|
||||
|
||||
@staticmethod
|
||||
def _to_triton_ir(obj):
|
||||
def _to_triton_ir(context, obj):
|
||||
which, name = obj
|
||||
type_map = {
|
||||
'I': triton.language.int32,
|
||||
'L': triton.language.int64,
|
||||
'f': triton.language.float32,
|
||||
'B': triton.language.int1,
|
||||
'f8': triton.language.float8,
|
||||
'f16': triton.language.float16,
|
||||
'bf16': triton.language.bfloat16,
|
||||
'f32': triton.language.float32,
|
||||
'f64': triton.language.float64,
|
||||
'i1': triton.language.int1,
|
||||
'i8': triton.language.int8,
|
||||
'i16': triton.language.int16,
|
||||
'i32': triton.language.int32,
|
||||
'i64': triton.language.int64,
|
||||
'u8': triton.language.uint8,
|
||||
'u16': triton.language.uint16,
|
||||
'u32': triton.language.uint32,
|
||||
'u64': triton.language.uint64,
|
||||
'I': _triton.ir.type.get_int32,
|
||||
'L': _triton.ir.type.get_int64,
|
||||
'f': _triton.ir.type.get_fp32,
|
||||
'B': _triton.ir.type.get_int1,
|
||||
'f8': _triton.ir.type.get_fp8,
|
||||
'f16': _triton.ir.type.get_fp16,
|
||||
'bf16': _triton.ir.type.get_bf16,
|
||||
'f32': _triton.ir.type.get_fp32,
|
||||
'f64': _triton.ir.type.get_fp64,
|
||||
'i1': _triton.ir.type.get_int1,
|
||||
'i8': _triton.ir.type.get_int8,
|
||||
'i16': _triton.ir.type.get_int16,
|
||||
'i32': _triton.ir.type.get_int32,
|
||||
'i64': _triton.ir.type.get_int64,
|
||||
'u8': _triton.ir.type.get_uint8,
|
||||
'u16': _triton.ir.type.get_uint16,
|
||||
'u32': _triton.ir.type.get_uint32,
|
||||
'u64': _triton.ir.type.get_uint64,
|
||||
}
|
||||
# convert torch.Tensor to Triton IR pointers
|
||||
if which == 'ptr':
|
||||
elt_ty = type_map[name]
|
||||
return triton.language.pointer_type(elt_ty, 1)
|
||||
elt_ty = type_map[name](context)
|
||||
return _triton.ir.type.make_ptr(elt_ty, 1)
|
||||
# default path returns triton.ir.type directly
|
||||
return type_map[name]
|
||||
return type_map[name](context)
|
||||
|
||||
@staticmethod
|
||||
def pow2_divisor(N):
|
||||
@@ -1038,31 +930,25 @@ class JITFunction:
|
||||
assert isinstance(tree.body[0], ast.FunctionDef)
|
||||
return tree
|
||||
|
||||
# Called by CodeGenerator.visit_Call()
|
||||
def __call__(self, *args, generator: CodeGenerator, **kwargs):
|
||||
try:
|
||||
from inspect import getcallargs
|
||||
arg_values = getcallargs(self.fn, *args, **kwargs)
|
||||
arg_values = [arg_values[name] for name in self.arg_names]
|
||||
arg_values = [arg if isinstance(arg, triton.language.tensor)
|
||||
arg_values = [arg if isinstance(arg, triton.language.block)
|
||||
else triton.language.constexpr(arg) for arg in arg_values]
|
||||
|
||||
# Record values in the caller (parent scope)
|
||||
gscope = generator.gscope.copy()
|
||||
lscope = generator.lscope.copy()
|
||||
|
||||
# TODO: clear values other than args
|
||||
lvalues = generator.lvalues.copy()
|
||||
# types = generator.module.get_types().copy()
|
||||
values = generator.module.get_values().copy()
|
||||
types = generator.module.get_types().copy()
|
||||
generator.gscope = sys.modules[self.fn.__module__].__dict__
|
||||
generator.lscope = dict()
|
||||
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values)
|
||||
generator.gscope = gscope
|
||||
generator.lscope = lscope
|
||||
|
||||
generator.lvalues = lvalues
|
||||
# generator.module.set_types(types)
|
||||
|
||||
generator.module.set_values(values)
|
||||
generator.module.set_types(types)
|
||||
return ret
|
||||
except Exception as e:
|
||||
node = generator.last_node
|
||||
@@ -1147,9 +1033,9 @@ class JITFunction:
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
|
||||
ret_type = _triton.ir.type.get_void(context)
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user