[FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py Integer signedness now moved from C++ to python Cleaner frontend type Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -38,10 +38,8 @@ public:
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
value *get_uint32(uint32_t val);
|
||||
value *get_uint64(uint64_t val);
|
||||
value *get_int32(uint32_t val);
|
||||
value *get_int64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
@@ -52,11 +50,9 @@ public:
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_uint8_ty();
|
||||
type *get_uint16_ty();
|
||||
type *get_uint32_ty();
|
||||
type *get_uint64_ty();
|
||||
type *get_fp8_ty();
|
||||
type *get_half_ty();
|
||||
type *get_bf16_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
@@ -74,7 +70,9 @@ public:
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value* create_bitcast(value *src, type *dest_ty);
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_int_to_ptr(value *src, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
@@ -93,11 +91,11 @@ public:
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
@@ -145,11 +143,22 @@ public:
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Atomic instruction
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_atomic_max(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umax(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_min(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umin(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_fadd(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_and(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_or(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xor(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xchg(value *ptr, value *val, value *msk);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
|
@@ -26,7 +26,6 @@ public:
|
||||
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||
// Block types
|
||||
|
@@ -1,113 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_DISPATCH_H_
|
||||
#define _TRITON_IR_DISPATCH_H_
|
||||
|
||||
#include "triton/ir/builder.h"
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/*----------------------------------------------
|
||||
higher level functions that follow the likely
|
||||
semantics of most expected frontends
|
||||
----------------------------------------------*/
|
||||
|
||||
struct semantic_error: public std::runtime_error {
|
||||
semantic_error(const std::string& msg):
|
||||
std::runtime_error(msg) { }
|
||||
};
|
||||
|
||||
struct dispatch{
|
||||
typedef ir::type::block_shapes_t shape_t;
|
||||
|
||||
|
||||
// programming model
|
||||
static ir::value *program_id(int axis, ir::builder *builder);
|
||||
static ir::value *num_programs(int axis, ir::builder *builder);
|
||||
|
||||
// binary operators
|
||||
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
|
||||
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// unary operators
|
||||
static ir::value *plus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *minus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *invert(ir::value *input, ir::builder *builder);
|
||||
|
||||
// comparison operators
|
||||
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// block creation
|
||||
static ir::value* arange(int start, int end, ir::builder *builder);
|
||||
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
|
||||
|
||||
|
||||
// casting ops
|
||||
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
||||
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
|
||||
// memory operators
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
|
||||
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
|
||||
|
||||
// indexing
|
||||
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
|
||||
|
||||
// reduction
|
||||
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
|
||||
|
||||
// math
|
||||
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sin(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *debug_barrier(ir::builder *builder);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -57,26 +57,10 @@ private:
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder& builder);
|
||||
builder& get_builder();
|
||||
// Setters
|
||||
void set_value(const std::string& name, basic_block* block, value *x);
|
||||
void set_value(const std::string& name, value* x);
|
||||
void set_const(const std::string& name);
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
const std::map<val_key_t, value*>& get_values() { return values_; }
|
||||
const std::map<std::string, type*>& get_types() { return types_; }
|
||||
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
||||
void set_types(const std::map<std::string, type*>& types) { types_ = types; }
|
||||
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
|
||||
builder &get_builder() { return builder_; };
|
||||
const std::string& get_name() { return name_; };
|
||||
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
value *get_value(const std::string& name);
|
||||
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
||||
const std::string& get_name();
|
||||
std::function<ir::value*()> get_continue_fn();
|
||||
// Seal block -- no more predecessors will be added
|
||||
void seal_block(basic_block *block);
|
||||
// Functions
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
@@ -89,21 +73,14 @@ public:
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
|
||||
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
builder& builder_;
|
||||
std::map<val_key_t, value*> values_;
|
||||
std::map<std::string, type*> types_;
|
||||
std::set<std::string> const_;
|
||||
std::set<basic_block*> sealed_blocks_;
|
||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||
builder &builder_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::function<ir::value*()> continue_fn_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
|
@@ -16,8 +16,6 @@ class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
enum class signedness { SIGNED, UNSIGNED };
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
@@ -61,8 +59,6 @@ public:
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
signedness get_integer_signedness() const;
|
||||
bool is_integer_signed() const;
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
@@ -85,9 +81,6 @@ public:
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_integer_ty(unsigned bitwidth, signedness sn) {
|
||||
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
|
||||
}
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
@@ -115,10 +108,6 @@ public:
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
static integer_type *get_uint8_ty(context &ctx);
|
||||
static integer_type *get_uint16_ty(context &ctx);
|
||||
static integer_type *get_uint32_ty(context &ctx);
|
||||
static integer_type *get_uint64_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
@@ -145,7 +134,7 @@ public:
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
|
||||
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
@@ -168,21 +157,18 @@ class integer_type: public type {
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth, signedness sn)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
signedness get_signedness() const { return signedness_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
signedness signedness_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
|
Reference in New Issue
Block a user