[FRONTEND] Semantic analysis refactor (#491)

Moved dispatch.cc to semantic.py (@ptillet)
Integer signedness analysis was moved from C++ to python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
This commit is contained in:
Philippe Tillet
2022-04-06 16:13:53 -07:00
committed by GitHub
parent 2bed6fc850
commit 9f08ecd684
19 changed files with 2174 additions and 1745 deletions

View File

@@ -41,10 +41,8 @@ public:
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_uint32(uint32_t val);
value *get_uint64(uint64_t val);
value *get_int32(uint32_t val);
value *get_int64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
@@ -55,11 +53,9 @@ public:
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_uint8_ty();
type *get_uint16_ty();
type *get_uint32_ty();
type *get_uint64_ty();
type *get_fp8_ty();
type *get_half_ty();
type *get_bf16_ty();
type *get_float_ty();
type *get_double_ty();
// Insert
@@ -78,7 +74,9 @@ public:
value* create_ret_void();
value* create_ret(value *ret);
// Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_int_to_ptr(value *src, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
@@ -100,11 +98,11 @@ public:
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -155,11 +153,25 @@ public:
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_cat(value *lhs, value *rhs);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Atomic instruction
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_atomic_max(value *ptr, value *val, value *msk);
value *create_atomic_umax(value *ptr, value *val, value *msk);
value *create_atomic_min(value *ptr, value *val, value *msk);
value *create_atomic_umin(value *ptr, value *val, value *msk);
value *create_atomic_fadd(value *ptr, value *val, value *msk);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_atomic_and(value *ptr, value *val, value *msk);
value *create_atomic_or(value *ptr, value *val, value *msk);
value *create_atomic_xor(value *ptr, value *val, value *msk);
value *create_atomic_xchg(value *ptr, value *val, value *msk);
// Utilities
value *create_clock();
value *create_globaltimer();
// Built-in instruction
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_cos(value* arg);
value *create_sin(value* arg);

View File

@@ -26,7 +26,6 @@ public:
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types

View File

@@ -1,117 +0,0 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
// math
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// utilities
static ir::value *globaltimer(ir::builder *builder);
static ir::value *clock(ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -36,7 +36,6 @@ class alloc_const;
class value_constructor {
typedef std::pair<std::string, basic_block*> val_key_t;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
@@ -57,7 +56,6 @@ public:
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Metadata
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
private:
ir::builder& builder_;
@@ -66,13 +64,13 @@ private:
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
std::map<value*, value**> current_phi_;
std::map<std::string, md_pair_t> metadatas_;
};
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
friend class function;
public:
@@ -83,13 +81,10 @@ private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder& builder);
builder& get_builder();
// Setters
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
builder &get_builder() { return builder_; };
const std::string& get_name() { return name_; };
// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
@@ -114,17 +109,19 @@ public:
// Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
//
// Metadata
void print(std::ostream &os);
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
private:
std::string name_;
builder& builder_;
builder &builder_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
};
}

View File

@@ -16,8 +16,6 @@ class value;
class integer_type;
class constant_int;
enum class signedness { SIGNED, UNSIGNED };
/* Type */
class type {
public:
@@ -61,8 +59,6 @@ public:
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const;
signedness get_integer_signedness() const;
bool is_integer_signed() const;
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
@@ -87,9 +83,6 @@ public:
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth, signedness sn) {
return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
@@ -118,10 +111,6 @@ public:
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
static integer_type *get_uint8_ty(context &ctx);
static integer_type *get_uint16_ty(context &ctx);
static integer_type *get_uint32_ty(context &ctx);
static integer_type *get_uint64_ty(context &ctx);
// repr
std::string tile_repr() const {
@@ -148,7 +137,7 @@ public:
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
@@ -171,21 +160,18 @@ class integer_type: public type {
private:
// constructors
integer_type(context &ctx, unsigned bitwidth, signedness sn)
: type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
signedness get_signedness() const { return signedness_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
signedness signedness_;
};
class composite_type: public type{