[FRONTEND] Semantic analysis refactor (#491)

Moved dispatch.cc to semantic.py (@ptillet)
Moved integer signedness analysis from C++ to Python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Author:       Philippe Tillet
Date:         2022-04-06 16:13:53 -07:00
Committed by: GitHub
Parent:       2bed6fc850
Commit:       9f08ecd684
19 changed files with 2174 additions and 1745 deletions
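With signedness removed from the IR type system (see the type.h and context changes below), the C++ layer follows the LLVM convention of signless integers: whether an operation is signed now lives in the opcode chosen by the Python frontend, not in the type. A minimal sketch against the builder API kept by this commit; the emit_div helper, the builder `b`, and the operands are illustrative, not part of the patch:

// Sketch only: signedness is chosen per operation by the (Python) frontend.
using namespace triton;
ir::value *emit_div(ir::builder &b, ir::value *lhs, ir::value *rhs, bool is_signed) {
  // lhs/rhs are plain "i32"-style values; the type no longer records signedness.
  return is_signed ? b.create_sdiv(lhs, rhs)   // signed division
                   : b.create_udiv(lhs, rhs);  // unsigned division
}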

View File

@@ -41,10 +41,8 @@ public:
   iterator get_insert_point() { return insert_point_;}
   // Constants
   value *get_int1(bool val);
-  value *get_int32(int32_t val);
-  value *get_int64(int64_t val);
-  value *get_uint32(uint32_t val);
-  value *get_uint64(uint64_t val);
+  value *get_int32(uint32_t val);
+  value *get_int64(uint64_t val);
   value *get_float16(float val);
   value *get_float32(float val);
   value *get_range(int32_t lo, int32_t hi);
@@ -55,11 +53,9 @@ public:
   type *get_int16_ty();
   type *get_int32_ty();
   type *get_int64_ty();
-  type *get_uint8_ty();
-  type *get_uint16_ty();
-  type *get_uint32_ty();
-  type *get_uint64_ty();
+  type *get_fp8_ty();
   type *get_half_ty();
+  type *get_bf16_ty();
   type *get_float_ty();
   type *get_double_ty();
   // Insert
@@ -78,7 +74,9 @@ public:
   value* create_ret_void();
   value* create_ret(value *ret);
   // Cast instructions
+  value* create_bitcast(value *src, type *dest_ty);
   value *create_cast(cast_op_t op, value *v, type *dst_ty);
+  value* create_int_to_ptr(value *src, type *dst_ty);
   value* create_ptr_to_int(value *src, type *dst_ty);
   value* create_si_to_fp(value *src, type *dst_ty);
   value* create_ui_to_fp(value *src, type *dst_ty);
@@ -100,11 +98,11 @@ public:
   value *create_frem(value *lhs, value *rhs);
   value *create_fadd(value *lhs, value *rhs);
   value *create_fsub(value *lhs, value *rhs);
+  value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_sdiv(value *lhs, value *rhs);
   value *create_udiv(value *lhs, value *rhs);
   value *create_srem(value *lhs, value *rhs);
   value *create_urem(value *lhs, value *rhs);
-  value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
   value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -155,11 +153,25 @@ public:
   value *create_reshape(value *arg, const type::block_shapes_t &shapes);
   value *create_cat(value *lhs, value *rhs);
   value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
+  // Atomic instruction
+  value *create_atomic_cas(value *ptr, value *cmp, value *val);
+  value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
+  value *create_atomic_max(value *ptr, value *val, value *msk);
+  value *create_atomic_umax(value *ptr, value *val, value *msk);
+  value *create_atomic_min(value *ptr, value *val, value *msk);
+  value *create_atomic_umin(value *ptr, value *val, value *msk);
+  value *create_atomic_fadd(value *ptr, value *val, value *msk);
+  value *create_atomic_add(value *ptr, value *val, value *msk);
+  value *create_atomic_and(value *ptr, value *val, value *msk);
+  value *create_atomic_or(value *ptr, value *val, value *msk);
+  value *create_atomic_xor(value *ptr, value *val, value *msk);
+  value *create_atomic_xchg(value *ptr, value *val, value *msk);
+  // Utilities
+  value *create_clock();
+  value *create_globaltimer();
   // Built-in instruction
   value *create_get_program_id(unsigned axis);
   value *create_get_num_programs(unsigned axis);
-  value *create_atomic_cas(value *ptr, value *cmp, value *val);
-  value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
   value *create_exp(value* arg);
   value *create_cos(value* arg);
   value *create_sin(value* arg);
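The atomic helpers above are new convenience entry points on the builder (the generic create_atomic_rmw plus one wrapper per opcode), relocated from the old dispatch layer. A hedged usage sketch; `b`, `ptr`, `val` and `msk` are assumed to already be in scope:

// Sketch: emitting masked atomics directly through the builder.
ir::value *prev  = b.create_atomic_add(ptr, val, msk);                              // dedicated wrapper
ir::value *prev2 = b.create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, msk);   // generic form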

View File

@@ -26,7 +26,6 @@ public:
   type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
   // integer types
   integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
-  integer_type uint8_ty, uint16_ty, uint32_ty, uint64_ty;
   // Pointer types
   std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
   // Block types

View File

@@ -1,117 +0,0 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *fdiv(ir::value *input, ir::value *other, ir::constant_int* ieee_rounding, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache,
const std::string& eviction_policy, int is_volatile, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *xor_sum(ir::value *input, unsigned axis, ir::builder *builder);
// math
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// utilities
static ir::value *globaltimer(ir::builder *builder);
static ir::value *clock(ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif
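For context, this deleted header was the C++ entry point the Python bindings went through: each tl.* operation mapped to a static dispatch::* helper that handled broadcasting and type promotion before emitting IR. Per the commit message, that logic now lives in semantic.py. A sketch of the old call pattern (operand names are placeholders):

// Pre-refactor call pattern (illustrative only):
ir::value *sum = ir::dispatch::add(lhs, rhs, &builder);        // tensor + tensor
ir::value *sel = ir::dispatch::where(cond, x, y, &builder);    // tl.where
ir::value *red = ir::dispatch::sum(sel, /*axis=*/0, &builder); // tl.sum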

View File

@@ -36,7 +36,6 @@ class alloc_const;
 class value_constructor {
   typedef std::pair<std::string, basic_block*> val_key_t;
-  typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;

 private:
   phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
@@ -57,7 +56,6 @@ public:
   // Seal block -- no more predecessors will be added
   void seal_block(basic_block *block);
   // Metadata
-  void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }

 private:
   ir::builder& builder_;
@@ -66,13 +64,13 @@ private:
   std::set<basic_block*> sealed_blocks_;
   std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
   std::map<value*, value**> current_phi_;
-  std::map<std::string, md_pair_t> metadatas_;
 };

 /* Module */
 class module {
   typedef std::pair<std::string, basic_block*> val_key_t;
+  typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
   friend class function;

 public:
@@ -83,13 +81,10 @@ private:
   void push_function(function *fn) { functions_.push_back(fn); }

 public:
-  module(const std::string &name, builder& builder);
-  builder& get_builder();
-  // Setters
-  void set_continue_fn(std::function<ir::value*()> fn);
-  // Getters
-  const std::string& get_name();
-  std::function<ir::value*()> get_continue_fn();
+  module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
+  builder &get_builder() { return builder_; };
+  const std::string& get_name() { return name_; };
   // Functions
   const functions_list_t &get_function_list() const { return functions_; }
   functions_list_t &get_function_list() { return functions_; }
@@ -114,17 +109,19 @@ public:
   // Register global
   void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
   const std::map<std::string, ir::value*>& globals() const { return globals_; }
-  //
+  // Metadata
   void print(std::ostream &os);
+  void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
+  const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }

 private:
   std::string name_;
   builder &builder_;
   functions_list_t functions_;
   symbols_map_t symbols_;
-  std::function<ir::value*()> continue_fn_;
   std::vector<ir::alloc_const*> allocs_;
   std::map<std::string, ir::value*> globals_;
+  std::map<std::string, md_pair_t> metadatas_;
 };

 }
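With this change, ir::module owns the metadata map (and its trivial constructor/accessors are inlined), so per-value metadata such as multiple_of is registered on the module rather than on value_constructor. A hedged sketch; the context `ctx`, the builder constructor signature, and the names below are assumptions for illustration:

// Sketch: constructing a module and registering metadata on it.
ir::builder b(ctx);                                        // assumes an ir::context `ctx`
ir::module mod("my_kernel", b);                            // new inline constructor
mod.add_metadata("arg0", {ir::metadata::multiple_of, 16}); // md_pair_t = {kind, value}
const auto &md = mod.get_metadatas();                      // applied later when values are set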

View File

@@ -16,8 +16,6 @@ class value;
 class integer_type;
 class constant_int;

-enum class signedness { SIGNED, UNSIGNED };
-
 /* Type */
 class type {
 public:
@@ -61,8 +59,6 @@ public:
   // type attributes
   unsigned get_fp_mantissa_width() const;
   unsigned get_integer_bitwidth() const;
-  signedness get_integer_signedness() const;
-  bool is_integer_signed() const;
   unsigned get_tile_bitwidth() const;
   unsigned get_primitive_size_in_bits() const;
   type *get_scalar_ty() const;
@@ -87,9 +83,6 @@ public:
   bool is_metadata_ty() const { return id_ == MetadataTyID; }
   bool is_token_ty() const { return id_ == TokenTyID; }
   bool is_integer_ty() const { return id_ == IntegerTyID; }
-  bool is_integer_ty(unsigned bitwidth, signedness sn) {
-    return is_integer_ty() && get_integer_bitwidth() == bitwidth && get_integer_signedness() == sn;
-  }
   bool is_bool_ty() const { return is_integer_ty(1); }
   bool is_pointer_ty() const { return id_ == PointerTyID; }
   bool is_block_ty() const { return id_ == BlockTyID; }
@@ -118,10 +111,6 @@ public:
   static integer_type *get_int32_ty(context &ctx);
   static integer_type *get_int64_ty(context &ctx);
   static integer_type *get_int128_ty(context &ctx);
-  static integer_type *get_uint8_ty(context &ctx);
-  static integer_type *get_uint16_ty(context &ctx);
-  static integer_type *get_uint32_ty(context &ctx);
-  static integer_type *get_uint64_ty(context &ctx);
   // repr
   std::string tile_repr() const {
@@ -148,7 +137,7 @@ public:
       case LabelTyID: return "label";
       case MetadataTyID: return "md";
       case TokenTyID: return "tok";
-      case IntegerTyID: return (is_integer_signed() ? "i" : "u") + std::to_string(get_integer_bitwidth());
+      case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
       case FunctionTyID: return "fn";
       case PointerTyID: return get_pointer_element_ty()->repr() + "*";
       case StructTyID: return "struct";
@@ -171,21 +160,18 @@ class integer_type: public type {
 private:
   // constructors
-  integer_type(context &ctx, unsigned bitwidth, signedness sn)
-    : type(ctx, IntegerTyID), bitwidth_(bitwidth), signedness_(sn){ }
+  integer_type(context &ctx, unsigned bitwidth)
+    : type(ctx, IntegerTyID), bitwidth_(bitwidth) {}

 public:
   // accessors
   unsigned get_bitwidth() const { return bitwidth_; }
-  signedness get_signedness() const { return signedness_; }
   // factory methods
   static integer_type* get(context &ctx, unsigned width);

 private:
   unsigned bitwidth_;
-  signedness signedness_;
 };

 class composite_type: public type{
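After this change integer types are identified by bitwidth only, so the old i32/u32 pair collapses into a single signless type whose repr() is always "iN". A small sketch; the context `ctx` is assumed to be in scope, and the bitwidth overload of is_integer_ty is the one implied by is_bool_ty above:

// Sketch: signless integers after the refactor.
ir::type *t = ir::type::get_int32_ty(ctx);
assert(t->is_integer_ty(32));   // bitwidth-only query
assert(t->repr() == "i32");     // previously "u32" for the unsigned variant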

View File

@@ -1,3 +1,4 @@
+#include <bits/types/clock_t.h>
 #include <string>
 #include <algorithm>
 #include <iostream>
@@ -48,18 +49,12 @@ void builder::set_insert_point(basic_block *block){
 value *builder::get_int1(bool val)
 { return constant_int::get(type::get_int1_ty(ctx_), val); }

-value *builder::get_int32(int32_t val)
+value *builder::get_int32(uint32_t val)
 { return constant_int::get(type::get_int32_ty(ctx_), val);}

-value *builder::get_uint32(uint32_t val)
-{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
-
-value *builder::get_int64(int64_t val)
+value *builder::get_int64(uint64_t val)
 { return constant_int::get(type::get_int64_ty(ctx_), val);}

-value *builder::get_uint64(uint64_t val)
-{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
-
 value *builder::get_float16(float val)
 { return constant_fp::get(type::get_fp16_ty(ctx_), val); }
@@ -90,21 +85,15 @@ type *builder::get_int32_ty()
 type *builder::get_int64_ty()
 { return type::get_int64_ty(ctx_); }

-type *builder::get_uint8_ty()
-{ return type::get_uint8_ty(ctx_); }
-
-type *builder::get_uint16_ty()
-{ return type::get_uint16_ty(ctx_); }
-
-type *builder::get_uint32_ty()
-{ return type::get_uint32_ty(ctx_); }
-
-type *builder::get_uint64_ty()
-{ return type::get_uint64_ty(ctx_); }
+type *builder::get_fp8_ty()
+{ return type::get_fp8_ty(ctx_); }

 type *builder::get_half_ty()
 { return type::get_fp16_ty(ctx_); }

+type *builder::get_bf16_ty()
+{ return type::get_bf16_ty(ctx_); }
+
 type *builder::get_float_ty()
 { return type::get_fp32_ty(ctx_); }
@@ -140,6 +129,8 @@ value *builder::create_ret(value* val) {
     return create_cast(OPCODE, src, dst_ty);\
   }

+DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
+DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
 DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
 DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
 DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
@@ -358,6 +349,37 @@ value *builder::create_downcast(value *arg) {
   return insert(downcast_inst::create(arg));
 }

+//
+value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
+  return insert(atomic_rmw_inst::create(op, ptr, val, msk));
+}
+
+#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
+  value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
+    return create_atomic_rmw(OPCODE, ptr, val, mask);\
+  }
+
+DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
+DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
+DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
+DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
+DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
+DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
+DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
+DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
+DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
+DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
+
+// Utilities
+value *builder::create_clock() {
+  return insert(clock_inst::create(ctx_));
+}
+
+value *builder::create_globaltimer() {
+  return insert(globaltimer_inst::create(ctx_));
+}
+
 //===----------------------------------------------------------------------===//
 // built-in instructions
 //===----------------------------------------------------------------------===//
@@ -374,9 +396,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
   return insert(atomic_cas_inst::create(ptr, cmp, val));
 }

-value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
-  return insert(atomic_rmw_inst::create(op, ptr, val, msk));
-}
-
 value *builder::create_exp(value *arg){
   return insert(exp_inst::create(arg));
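For reference, the DEFINE_ATOMIC_RMW_INSTR macro above generates one thin wrapper per RMW opcode; for example, DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add) expands to:

// Expansion of one DEFINE_ATOMIC_RMW_INSTR invocation (derived from the macro above):
value *builder::create_atomic_add(value *ptr, value *val, value *mask){
  return create_atomic_rmw(ir::atomic_rmw_op_t::Add, ptr, val, mask);
}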

View File

@@ -19,18 +19,12 @@ context_impl::context_impl(context &ctx)
     fp32_ty(ctx, type::FP32TyID),
     fp64_ty(ctx, type::FP64TyID),
     // integers
-    int1_ty(ctx, 1, signedness::SIGNED),
-    int8_ty(ctx, 8, signedness::SIGNED),
-    int16_ty(ctx, 16, signedness::SIGNED),
-    int32_ty(ctx, 32, signedness::SIGNED),
-    int64_ty(ctx, 64, signedness::SIGNED),
-    int128_ty(ctx, 128, signedness::SIGNED),
-    uint8_ty(ctx, 8, signedness::UNSIGNED),
-    uint16_ty(ctx, 16, signedness::UNSIGNED),
-    uint32_ty(ctx, 32, signedness::UNSIGNED),
-    uint64_ty(ctx, 64, signedness::UNSIGNED){
-}
+    int1_ty(ctx, 1),
+    int8_ty(ctx, 8),
+    int16_ty(ctx, 16),
+    int32_ty(ctx, 32),
+    int64_ty(ctx, 64),
+    int128_ty(ctx, 128) {}

 //===----------------------------------------------------------------------===//
 // context

View File

@@ -1,895 +0,0 @@
#include "triton/ir/dispatch.h"
namespace triton {
namespace ir {
[[ noreturn ]] void throw_unreachable(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
"This is likely a bug on our side.");
}
//===----------------------------------------------------------------------===//
// Programming Model
//===----------------------------------------------------------------------===//
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
}
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
}
//===----------------------------------------------------------------------===//
// Implicit Casting Utilities
//===----------------------------------------------------------------------===//
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
int a_rank = a_ty->get_integer_bitwidth();
int b_rank = b_ty->get_integer_bitwidth();
auto a_sn = a_ty->get_integer_signedness();
auto b_sn = b_ty->get_integer_signedness();
// Rules for signedness taken from "Usual arithmetic conversions" on
// https://en.cppreference.com/w/c/language/conversion.
if (a_sn == b_sn) {
return a_rank > b_rank ? a_ty : b_ty;
} else if (a_sn == signedness::UNSIGNED) {
return a_rank >= b_rank ? a_ty : b_ty;
} else if (b_sn == signedness::UNSIGNED) {
return b_rank >= a_rank ? b_ty : a_ty;
} else {
throw_unreachable("integer_promote");
}
}
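// Illustrative note (not part of the original file): with the rules above,
//   int32 + int64  -> int64   (same signedness: the larger bitwidth wins)
//   uint32 + int32 -> uint32  (unsigned operand of equal or larger bitwidth wins)
//   uint32 + int64 -> int64   (the signed operand has strictly larger bitwidth)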
enum class DivOrMod { NO, YES };
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
context &ctx = a_ty->get_context();
// 1) if one operand is double, the other is implicitly
// converted to double
if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
return type::get_fp64_ty(ctx);
// 2) if one operand is float, the other is implicitly
// converted to float
if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
return type::get_fp32_ty(ctx);
// 3 ) if one operand is half, the other is implicitly converted to half
// unless we're doing / or %, which do not exist natively in PTX for fp16.
if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
if (div_or_mod == DivOrMod::YES) {
return type::get_fp32_ty(ctx);
} else {
return type::get_fp16_ty(ctx);
}
}
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
throw_unreachable("computation_type");
// 4 ) both operands are integer and undergo
// integer promotion
if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
return integer_promote(a_ty, b_ty);
}
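// Illustrative note (not part of the original file): for example fp16 / fp16
// is computed in fp32 (rule 3), fp32 + fp64 is computed in fp64 (rule 1), and
// int32 + uint64 falls through to integer_promote, which yields uint64.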
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
}
void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
if(type_a->is_pointer_ty()){
if(!allow_ptr_a)
throw_incompatible_types(type_a, type_b);
// T* + U* with T != U
if(type_b->is_pointer_ty() && (type_a != type_b))
throw_incompatible_types(type_a, type_b);
// T* + float
if(type_b->is_floating_point_ty())
throw_incompatible_types(type_a, type_b);
}
}
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
// implicit broadcasting
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
// implicit typecasting
ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
}
}
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, true);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// offset + ptr
// ptr + offset
if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty())
std::swap(input, other);
if (input_scalar_ty->is_pointer_ty())
return builder->create_gep(input, {other});
// float + float
else if (input_scalar_ty->is_floating_point_ty())
return builder->create_fadd(input, other);
// int + int
else if (input_scalar_ty->is_integer_ty())
return builder->create_add(input, other);
throw_unreachable("add");
}
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, false);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// ptr - offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(input, {dispatch::minus(other, builder)});
// float + float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(input, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
throw_unreachable("sub");
}
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(input, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(input, other);
throw_unreachable("mul");
}
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float / int
if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
other = cast(other, input_scalar_ty, builder);
// int / float
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
input = cast(input, other_scalar_ty, builder);
// int / int (cast to float32)
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
input = cast(input, builder->get_float_ty(), builder);
other = cast(other, builder->get_float_ty(), builder);
}
// float / float (cast to highest exponent type)
else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
other = cast(other, input_scalar_ty, builder);
else
input = cast(input, other_scalar_ty, builder);
}
// unreachable
else
throw_unreachable("div");
return builder->create_fdiv(input, other);
}
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
input = dispatch::cast(input, ret_ty, builder);
other = dispatch::cast(other, ret_ty, builder);
if (ret_ty->is_integer_signed()) {
return builder->create_sdiv(input, other);
} else {
return builder->create_udiv(input, other);
}
}
throw_unreachable("floordiv");
}
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
throw semantic_error("both operands of fdiv must have floating point scalar type");
binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
ir::value* ret = builder->create_fdiv(input, other);
if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret))
binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
return ret;
}
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float % int
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(input, other);
// int % int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
if (scalar_ty->is_integer_signed()) {
return builder->create_srem(input, other);
} else {
return builder->create_urem(input, other);
}
}
throw_unreachable("mod");
}
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty);
ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
if (ret_sca_ty != input_sca_ty)
input = dispatch::cast(input, ret_sca_ty, builder);
if (ret_sca_ty != other_sca_ty)
other = dispatch::cast(other, ret_sca_ty, builder);
}
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_and(input, other);
}
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_or(input, other);
}
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_xor(input, other);
}
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_lshr(input, other);
}
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_shl(input, other);
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::plus(ir::value *input, ir::builder *) {
return input;
}
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty())
throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")");
ir::value *_0 = ir::constant::get_null_value(input_sca_ty);
return dispatch::sub(_0, input, builder);
}
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty())
throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")");
ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty);
return dispatch::xor_(input, _1, builder);
}
//===----------------------------------------------------------------------===//
// Comparison Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGT(input, other);
} else {
return builder->create_icmpUGT(input, other);
}
}
throw_unreachable("greater_than");
}
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGE(input, other);
} else {
return builder->create_icmpUGE(input, other);
}
}
throw_unreachable("greater_equal");
}
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLT(input, other);
} else {
return builder->create_icmpULT(input, other);
}
}
throw_unreachable("less_than");
}
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int < int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLE(input, other);
} else {
return builder->create_icmpULE(input, other);
}
}
throw_unreachable("less_equal");
}
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
throw_unreachable("equal");
}
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpUNE(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
throw_unreachable("equal");
}
//===----------------------------------------------------------------------===//
// Block Creation
//===----------------------------------------------------------------------===//
ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
}
ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
}
//===----------------------------------------------------------------------===//
// Shape Manipulation
//===----------------------------------------------------------------------===//
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
unsigned numel = 1;
for(unsigned s: dst_shape) numel *= s;
if(input->get_type()->get_tile_num_elements() != numel)
throw semantic_error("cannot reshape block of different shape");
return builder->create_reshape(input, dst_shape);
}
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
return builder->create_cat(lhs, rhs);
}
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
if(shape == src_shape)
return input;
return builder->create_broadcast(input, shape);
}
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
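// Illustrative note (not part of the original file): e.g. a [128, 1] block and
// a [1, 64] block are both broadcast to [128, 64]; the ranks must already
// match, and each dimension pair must be equal or contain a 1.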
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
return cast(input, dst_ty, builder);
// Bitcast
int src_bits = src_sca_ty->get_primitive_size_in_bits();
int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
if( src_bits!= dst_bits)
throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
"to data-type of size " + std::to_string(dst_bits));
return builder->create_cast(ir::BitCast, input, dst_ty);
}
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
//
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
}
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
(src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
return builder->create_int_cast(input, dst_ty, sign_extend);
}
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
return builder->create_fp_to_ui(input, dst_ty);
else
return builder->create_fp_to_si(input, dst_ty);
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
}
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){
int bitwidth = dst_sca_ty->get_integer_bitwidth();
if(bitwidth == 64)
return builder->create_cast(ir::PtrToInt, input, dst_ty);
if(bitwidth == 1)
return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder),
builder->get_int64(0),
builder);
}
if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::IntToPtr, input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, builder->get_int64_ty(), builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
//===----------------------------------------------------------------------===//
// Memory Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
if(other)
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
}
if(other)
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cache modifier
load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
if (!cache_modifier.empty()) {
if (cache_modifier == ".ca")
cache = load_inst::CA;
else if (cache_modifier == ".cg")
cache = load_inst::CG;
else
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
}
// eviction policy
load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
if(!eviction_policy.empty()){
if (eviction_policy == "evict_last")
eviction = load_inst::EVICT_LAST;
else if(eviction_policy == "evict_first")
eviction = load_inst::EVICT_FIRST;
else
throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
}
if (!mask && !other)
return builder->create_load(ptr, cache, eviction, is_volatile);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
auto shape = ptr->get_type()->get_block_shapes();
if(!other){
other = ir::undef_value::get(elt_ty);
if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
}
return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
}
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty())
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cast to target data-type
val = dispatch::cast(val, elt_ty, builder);
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
throw semantic_error("Mask must have boolean scalar type");
return builder->create_masked_store(ptr, val, mask);
}
ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
return builder->create_atomic_cas(ptr, cmp, val);
}
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(val){
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
}
}
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
if(!mask){
mask = builder->get_int1(true);
if(ptr->get_type()->is_block_ty())
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
}
}
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
}
}
// for float
// return atomic_smax(i_ptr, i_val) if val >= 0
// return atomic_umin(i_ptr, i_val) if val < 0
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
return where(pos, pos_ret, neg_ret, builder);
}
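// Illustrative note (not part of the original file): the int32 reinterpretation
// above works because IEEE-754 orders same-signed floats like their bit
// patterns: for non-negative values a signed integer Max over the bits yields
// the float max, while for negative values the unsigned bit pattern grows as
// the float gets more negative, so an unsigned Min over the bits picks the
// float max among the negatives. The final where() then merges the two results
// per lane.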
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_min for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
}
}
// for float
// return atomic_smin(i_ptr, i_val) if val >= 0
// return atomic_umax(i_ptr, i_val) if val < 0
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
return where(pos, pos_ret, neg_ret, builder);
}
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
return builder->create_atomic_rmw(op, ptr, val, mask);
}
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
}
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
}
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
}
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
}
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
ir::value *_0 = nullptr;
if (lhs->get_type()->is_int_or_tileint_ty())
_0 = builder->get_int32(0);
else
_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
bool _allow_tf32 = allow_tf32->get_value() != 0;
return builder->create_dot(lhs, rhs, _0, _allow_tf32);
}
//===----------------------------------------------------------------------===//
// Indexing
//===----------------------------------------------------------------------===//
ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
if(condition->get_type()->is_block_ty()){
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
}
ir::type* x_ty = x->get_type()->get_scalar_ty();
ir::type* y_ty = y->get_type()->get_scalar_ty();
ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO);
x = dispatch::cast(x, ty, builder);
y = dispatch::cast(y, ty, builder);
return builder->create_select(condition, x, y);
}
//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// input is extended to 32-bits if necessary
// this increases numerical accuracy and can be done pretty much for free
// on GPUs
if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
throw_unreachable(name);
}
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
}
ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (!scalar_ty->is_integer_ty())
throw semantic_error("xor_sum only supported for integers");
return reduce_impl(input, axis, builder, "xor_sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
binary_op_type_checking(x, y, builder);
return builder->insert(umulhi_inst::create(x, y));
}
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
return builder->create_exp(x);
}
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
return builder->create_log(x);
}
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
return builder->create_cos(x);
}
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
return builder->create_sin(x);
}
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
ir::value *dispatch::globaltimer(ir::builder *builder) {
return builder->insert(globaltimer_inst::create(builder->get_context()));
}
ir::value *dispatch::clock(ir::builder *builder) {
return builder->insert(clock_inst::create(builder->get_context()));
}
//===----------------------------------------------------------------------===//
// Control Flow
//===----------------------------------------------------------------------===//
//
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("multiple_of");
i->set_metadata(ir::metadata::multiple_of, value);
return i;
}
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("max_contiguous");
i->set_metadata(ir::metadata::max_contiguous, value);
return i;
}
ir::value *dispatch::debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
}
}
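A rough Python rendering of the atomic_add helper above, written against the ir bindings that appear further down in this diff. It is a sketch only: the atom_red_typechecking step is elided and the function name is illustrative, not the actual semantic-analysis API.

# Illustrative sketch -- not the actual Python semantic helper.
import triton._C.libtriton.triton as _triton
ir = _triton.ir

def atomic_add_sketch(ptr, val, mask, builder):
    # atom_red_typechecking(ptr, val, mask, builder) would normally run first
    sca_ty = val.type.scalar
    # pick a float or integer RMW op, mirroring the C++ helper above
    op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
    return builder.create_atomic_rmw(op, ptr, val, mask)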

View File

@@ -9,154 +9,10 @@
namespace triton{ namespace triton{
namespace ir{ namespace ir{
/* */
value_constructor::value_constructor(ir::builder& builder): builder_(builder){
sealed_blocks_.insert(nullptr);
}
void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
if(auto *x = dynamic_cast<ir::instruction*>(value))
if(it != metadatas_.end()){
x->set_metadata(it->second.first, it->second.second);
}
// value->set_name(name);
}
void value_constructor::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
builder_.set_insert_point(insert);
}
ir::phi_node *res = builder_.create_phi(ty, num_values);
if(insert != block->end())
builder_.set_insert_point(block);
return res;
}
ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
[phi](ir::value* op){ return op != phi && op; });
// non-trivial
if(non_self_ref.size() != 1)
return phi;
// unique value or self-reference
ir::value *same = *non_self_ref.begin();
assert(same != nullptr);
phi->replace_all_uses_with(same);
phi->erase_from_parent();
std::vector<ir::user*> users = phi->get_users();
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
try_remove_trivial_phis(uphi);
return same;
}
ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){
// already initialized
if(phi->get_num_operands())
return phi;
ir::basic_block *block = phi->get_parent();
for(ir::basic_block *pred: block->get_predecessors()){
ir::value *value = get_value(name, pred);
phi->add_incoming(value, pred);
}
return phi;
}
ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *result;
auto preds = block->get_predecessors();
ir::type *ty = types_.at(name);
if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
else if(preds.size() <= 1){
bool has_pred = preds.size();
result = get_value(name, has_pred?preds.front():nullptr);
}
else{
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
result = try_remove_trivial_phis(phi);
}
set_value(name, block, result);
return result;
}
ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) {
ir::basic_block* save_block = builder_.get_insert_block();
ir::basic_block::iterator save_pt = builder_.get_insert_point();
val_key_t key(name, block);
// std::cout << values_.size() << std::endl;
// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl;
if(values_.find(key) != values_.end()){
return values_.at(key);
}
ir::value *result = get_value_recursive(name, block);
builder_.set_insert_point(save_block);
if(save_pt != save_block->end())
builder_.set_insert_point(save_pt);
return result;
}
ir::value *value_constructor::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
void value_constructor::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
if(get_value(x.first) == x.second)
set_value(x.first, try_remove_trivial_phis(x.second));
}
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();
}
/* Module */
module::module(const std::string &name, builder &builder)
: name_(name), builder_(builder) {
}
void module::reset_ret_ty(const std::string& name, type* ty) { void module::reset_ret_ty(const std::string& name, type* ty) {
get_function(name)->get_fn_type()->reset_ret_ty(ty); get_function(name)->get_fn_type()->reset_ret_ty(ty);
} }
ir::builder& module::get_builder() {
return builder_;
}
void module::set_continue_fn(std::function<ir::value*()> fn) {
continue_fn_ = fn;
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
const std::string& module::get_name() {
return name_;
}
/* functions */ /* functions */
function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *module::get_or_insert_function(const std::string &name, function_type *ty) {
function *&fn = (function*&)symbols_[name]; function *&fn = (function*&)symbols_[name];

View File

@@ -36,16 +36,6 @@ unsigned type::get_primitive_size_in_bits() const {
unsigned type::get_integer_bitwidth() const unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
signedness type::get_integer_signedness() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
bool type::is_integer_signed() const {
if (id_ != IntegerTyID) {
throw std::logic_error("type is " + repr() + ", not integer");
}
return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
}
unsigned type::get_tile_bitwidth() const unsigned type::get_tile_bitwidth() const
{ return ((block_type*)(this))->get_bitwidth(); } { return ((block_type*)(this))->get_bitwidth(); }
@@ -145,10 +135,6 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }

View File

@@ -3,7 +3,6 @@
#include "triton/driver/error.h" #include "triton/driver/error.h"
#include "triton/driver/llvm.h" #include "triton/driver/llvm.h"
#include "triton/ir/builder.h" #include "triton/ir/builder.h"
#include "triton/ir/dispatch.h"
#include "triton/ir/enums.h" #include "triton/ir/enums.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
@@ -12,10 +11,12 @@
#include <pybind11/buffer_info.h> #include <pybind11/buffer_info.h>
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "Python.h" #include "Python.h"
#include <regex> #include <regex>
#include <sstream> #include <sstream>
#include <stdexcept>
#include <string> #include <string>
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
@@ -541,87 +542,6 @@ void init_triton_codegen(py::module &&m) {
}, py::return_value_policy::take_ownership); }, py::return_value_policy::take_ownership);
} }
/*****************************************************************************/
/* User-facing language features */
/*****************************************************************************/
void init_triton_frontend(py::module &&m) {
using ret = py::return_value_policy;
// programming model
m.def("program_id", &ir::dispatch::program_id, ret::reference);
m.def("num_programs", &ir::dispatch::num_programs, ret::reference);
// binary
m.def("add", &ir::dispatch::add, ret::reference);
m.def("sub", &ir::dispatch::sub, ret::reference);
m.def("mul", &ir::dispatch::mul, ret::reference);
m.def("truediv", &ir::dispatch::truediv, ret::reference);
m.def("floordiv", &ir::dispatch::floordiv, ret::reference);
m.def("fdiv", &ir::dispatch::fdiv, ret::reference);
m.def("mod", &ir::dispatch::mod, ret::reference);
m.def("and_", &ir::dispatch::and_, ret::reference);
m.def("or_", &ir::dispatch::or_, ret::reference);
m.def("xor_", &ir::dispatch::xor_, ret::reference);
m.def("lshr", &ir::dispatch::lshr, ret::reference);
m.def("shl", &ir::dispatch::shl, ret::reference);
// unary
m.def("plus", &ir::dispatch::plus, ret::reference);
m.def("minus", &ir::dispatch::minus, ret::reference);
m.def("invert", &ir::dispatch::invert, ret::reference);
// comparison
m.def("greater_than", &ir::dispatch::greater_than, ret::reference);
m.def("greater_equal", &ir::dispatch::greater_equal, ret::reference);
m.def("less_than", &ir::dispatch::less_than, ret::reference);
m.def("less_equal", &ir::dispatch::less_equal, ret::reference);
m.def("equal", &ir::dispatch::equal, ret::reference);
m.def("not_equal", &ir::dispatch::not_equal, ret::reference);
// block creation
m.def("arange", &ir::dispatch::arange, ret::reference);
m.def("zeros", &ir::dispatch::zeros, ret::reference);
// type manipulation
m.def("cat", &ir::dispatch::cat, ret::reference);
m.def("reshape", &ir::dispatch::reshape, ret::reference);
typedef std::tuple<ir::value *, ir::value *> (*broadcast_ty)(ir::value *, ir::value *, ir::builder *);
typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *);
m.def("broadcast", (broadcast_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("broadcast_to", (broadcast_to_ty)(&ir::dispatch::broadcast), ret::reference);
m.def("bitcast", &ir::dispatch::bitcast, ret::reference);
m.def("cast", &ir::dispatch::cast, ret::reference);
// memory
m.def("load", &ir::dispatch::load, ret::reference);
m.def("store", &ir::dispatch::store, ret::reference);
m.def("atomic_cas", &ir::dispatch::atomic_cas, ret::reference);
m.def("atomic_xchg", &ir::dispatch::atomic_xchg, ret::reference);
m.def("atomic_add", &ir::dispatch::atomic_add, ret::reference);
m.def("atomic_max", &ir::dispatch::atomic_max, ret::reference);
m.def("atomic_min", &ir::dispatch::atomic_min, ret::reference);
m.def("atomic_and", &ir::dispatch::atomic_and, ret::reference);
m.def("atomic_or", &ir::dispatch::atomic_or, ret::reference);
m.def("atomic_xor", &ir::dispatch::atomic_xor, ret::reference);
// linear algebra
m.def("dot", &ir::dispatch::dot, ret::reference);
// indexing
m.def("where", &ir::dispatch::where, ret::reference);
// reduction
m.def("min", &ir::dispatch::min, ret::reference);
m.def("max", &ir::dispatch::max, ret::reference);
m.def("sum", &ir::dispatch::sum, ret::reference);
m.def("xor_sum", &ir::dispatch::xor_sum, ret::reference);
// math
m.def("umulhi", &ir::dispatch::umulhi, ret::reference);
m.def("exp", &ir::dispatch::exp, ret::reference);
m.def("log", &ir::dispatch::log, ret::reference);
m.def("cos", &ir::dispatch::cos, ret::reference);
m.def("sin", &ir::dispatch::sin, ret::reference);
m.def("sqrt", &ir::dispatch::sqrt, ret::reference);
// utilities
m.def("clock", &ir::dispatch::clock, ret::reference);
m.def("globaltimer", &ir::dispatch::globaltimer, ret::reference);
// internal (debugging only)
m.def("multiple_of", &ir::dispatch::multiple_of, ret::reference);
m.def("max_contiguous", &ir::dispatch::max_contiguous, ret::reference);
m.def("debug_barrier", &ir::dispatch::debug_barrier, ret::reference);
}
/*****************************************************************************/ /*****************************************************************************/
/* Python bindings for triton::ir */ /* Python bindings for triton::ir */
@@ -631,16 +551,86 @@ void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy; using ret = py::return_value_policy;
using namespace pybind11::literals; using namespace pybind11::literals;
py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
.value("NONE", ir::load_inst::NONE)
.value("CA", ir::load_inst::CA)
.value("CG", ir::load_inst::CG)
.export_values();
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
.value("NORMAL", ir::load_inst::NORMAL)
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
.export_values();
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
.value("ADD", ir::reduce_inst::ADD)
.value("FADD", ir::reduce_inst::FADD)
.value("MIN", ir::reduce_inst::MIN)
.value("MAX", ir::reduce_inst::MAX)
.value("FMIN", ir::reduce_inst::FMIN)
.value("FMAX", ir::reduce_inst::FMAX)
.value("XOR", ir::reduce_inst::XOR);
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
.value("ADD", ir::atomic_rmw_op_t::Add)
.value("FADD", ir::atomic_rmw_op_t::FAdd)
.value("AND", ir::atomic_rmw_op_t::And)
.value("OR", ir::atomic_rmw_op_t::Or)
.value("XOR", ir::atomic_rmw_op_t::Xor)
.value("XCHG", ir::atomic_rmw_op_t::Xchg)
.value("MAX", ir::atomic_rmw_op_t::Max)
.value("MIN", ir::atomic_rmw_op_t::Min)
.value("UMIN", ir::atomic_rmw_op_t::UMin)
.value("UMAX", ir::atomic_rmw_op_t::UMax);
py::class_<ir::context>(m, "context") py::class_<ir::context>(m, "context")
.def(py::init<>()); .def(py::init<>());
auto value = py::class_<ir::value>(m, "value"); py::class_<ir::value>(m, "value")
value.def_property("name", &ir::value::get_name, &ir::value::set_name); .def("multiple_of", [](ir::value *self, int val) {
value.def_property_readonly("type", &ir::value::get_type); if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
instr->set_metadata(ir::metadata::multiple_of, val);
} else
throw std::runtime_error("multiple_of");
})
.def("max_contiguous", [](ir::value *self, int val) {
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
instr->set_metadata(ir::metadata::max_contiguous, val);
} else
throw std::runtime_error("max_contiguous");
})
.def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
instr->set_fdiv_ieee_rounding(val);
else
throw std::runtime_error("set_fdiv_ieee_rounding");
})
.def("is_phi", [](ir::value *self) {
if (auto *pn = dynamic_cast<ir::phi_node*>(self))
return true;
return false;
})
.def("ops", [](ir::value *self) {
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
return instr->ops();
}
throw std::runtime_error("cannot use ops()");
})
.def("replace_all_uses_with", &ir::value::replace_all_uses_with)
.def("erase_from_parent", [](ir::value *self) {
if (auto *instr = dynamic_cast<ir::instruction*>(self))
return instr->erase_from_parent();
throw std::runtime_error("cannot use erase_from_parent");
})
.def_property("name", &ir::value::get_name, &ir::value::set_name)
.def_property_readonly("type", &ir::value::get_type);
py::class_<ir::user, ir::value>(m, "user"); py::class_<ir::user, ir::value>(m, "user");
py::class_<ir::constant, ir::user>(m, "constant"); py::class_<ir::constant, ir::user>(m, "constant")
.def("get_null_value", &ir::constant::get_null_value, ret::reference)
.def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);
py::class_<ir::undef_value, ir::constant>(m, "undef") py::class_<ir::undef_value, ir::constant>(m, "undef")
.def("get", &ir::undef_value::get, ret::reference); .def("get", &ir::undef_value::get, ret::reference);
@@ -651,18 +641,17 @@ void init_triton_ir(py::module &&m) {
.def("__bool__", [](ir::constant_int *self) { return self->get_value(); }); .def("__bool__", [](ir::constant_int *self) { return self->get_value(); });
py::class_<ir::constant_fp, ir::constant>(m, "constant_float") py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
.def_property_readonly("value", &ir::constant_fp::get_value); .def_property_readonly("value", &ir::constant_fp::get_value)
.def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);
py::class_<ir::instruction, ir::user>(m, "instruction"); py::class_<ir::instruction, ir::user>(m, "instruction")
py::class_<ir::phi_node, ir::user>(m, "phi_node"); .def("get_parent", [](ir::instruction *self) {
return self->get_parent();
}, ret::reference);
py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
.def("add_incoming", &ir::phi_node::add_incoming);
py::class_<ir::type>(m, "type") py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("get_int_width", &ir::type::get_integer_bitwidth)
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("make_ptr", &ir::pointer_type::get, ret::reference) .def("make_ptr", &ir::pointer_type::get, ret::reference)
.def("make_function", &ir::function_type::get, ret::reference) .def("make_function", &ir::function_type::get, ret::reference)
.def("make_block", &ir::block_type::get, ret::reference) .def("make_block", &ir::block_type::get, ret::reference)
@@ -677,35 +666,39 @@ void init_triton_ir(py::module &&m) {
.def("get_int16", &ir::type::get_int16_ty, ret::reference) .def("get_int16", &ir::type::get_int16_ty, ret::reference)
.def("get_int32", &ir::type::get_int32_ty, ret::reference) .def("get_int32", &ir::type::get_int32_ty, ret::reference)
.def("get_int64", &ir::type::get_int64_ty, ret::reference) .def("get_int64", &ir::type::get_int64_ty, ret::reference)
.def("get_uint8", &ir::type::get_uint8_ty, ret::reference) .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference)
.def("get_uint16", &ir::type::get_uint16_ty, ret::reference)
.def("get_uint32", &ir::type::get_uint32_ty, ret::reference)
.def("get_uint64", &ir::type::get_uint64_ty, ret::reference)
.def("get_block_shapes", &ir::type::get_block_shapes)
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("is_struct", &ir::type::is_struct_ty)
.def("is_void", &ir::type::is_void_ty) .def("is_void", &ir::type::is_void_ty)
.def("is_bool", &ir::type::is_bool_ty)
.def("is_fp8", &ir::type::is_fp8_ty) .def("is_fp8", &ir::type::is_fp8_ty)
.def("is_fp16", &ir::type::is_fp16_ty) .def("is_fp16", &ir::type::is_fp16_ty)
.def("is_bf16", &ir::type::is_bf16_ty) .def("is_bf16", &ir::type::is_bf16_ty)
.def("is_fp32", &ir::type::is_fp32_ty) .def("is_fp32", &ir::type::is_fp32_ty)
.def("is_fp64", &ir::type::is_fp64_ty) .def("is_fp64", &ir::type::is_fp64_ty)
.def("is_int1", [](ir::type *self) { return self->is_integer_ty(1, ir::signedness::SIGNED); }) .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
.def("is_int8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::SIGNED); }) .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
.def("is_int16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::SIGNED); }) .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
.def("is_int32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::SIGNED); }) .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
.def("is_int64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::SIGNED); }) .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
.def("is_uint8", [](ir::type *self) { return self->is_integer_ty(8, ir::signedness::UNSIGNED); }) .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty)
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def("is_struct", &ir::type::is_struct_ty)
.def("repr", &ir::type::repr) .def("repr", &ir::type::repr)
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width) .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
.def_property_readonly("scalar", &ir::type::get_scalar_ty) .def_property_readonly("scalar", &ir::type::get_scalar_ty)
.def_property_readonly("context", &ir::type::get_context, ret::reference); .def_property_readonly("context", &ir::type::get_context, ret::reference)
.def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
.def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);
py::class_<ir::pointer_type, ir::type>(m, "pointer_type") py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference); .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference)
.def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type") py::class_<ir::function_type, ir::type>(m, "function_type")
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty) .def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
@@ -723,21 +716,20 @@ void init_triton_ir(py::module &&m) {
.def("get", &ir::struct_type::get, ret::reference) .def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types); .def_property_readonly("num_types", &ir::struct_type::get_num_types);
py::class_<ir::value_constructor>(m, "value_constructor")
.def(py::init<ir::builder&>())
.def("seal_block", &ir::value_constructor::seal_block)
.def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value)
.def("set_type", &ir::value_constructor::set_type)
.def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference)
.def("get_values", &ir::value_constructor::get_values, ret::reference)
.def("set_values", &ir::value_constructor::set_values);
py::class_<ir::module>(m, "module") py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>()) .def(py::init<std::string, ir::builder &>())
.def("has_function", &ir::module::has_function) .def("has_function", &ir::module::has_function)
.def("get_function", &ir::module::get_function, ret::reference) .def("get_function", &ir::module::get_function, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("reset_ret_ty", &ir::module::reset_ret_ty) .def("reset_ret_ty", &ir::module::reset_ret_ty)
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
const auto metadatas = self->get_metadatas();
auto it = metadatas.find(name);
if (it != metadatas.end())
if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
instr->set_metadata(it->second.first, it->second.second);
}
})
.def_property_readonly("builder", &ir::module::get_builder, ret::reference); .def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t; using eattr = ir::attribute_kind_t;
@@ -768,6 +760,13 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::basic_block, ir::value>(m, "basic_block") py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr) .def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
.def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
.def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
ir::basic_block::iterator it = self->get_first_non_phi();
if (it == self->end())
return nullptr;
return *it;
}, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder::iterator>(m, "bb_iterator"); py::class_<ir::builder::iterator>(m, "bb_iterator");
@@ -783,22 +782,168 @@ void init_triton_ir(py::module &&m) {
.def("cond_br", &ir::builder::create_cond_br, ret::reference) .def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference) .def("ret_void", &ir::builder::create_ret_void, ret::reference)
.def("ret", &ir::builder::create_ret, ret::reference) .def("ret", &ir::builder::create_ret, ret::reference)
.def("get_insert_point", &ir::builder::get_insert_point) // insertion block/point, insert points are represented as (*bb, *instr)
.def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference) .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
.def("get_insert_point", [](ir::builder *self) {
ir::basic_block *bb = self->get_insert_block();
ir::basic_block::iterator it = self->get_insert_point();
ir::instruction *instr = it == bb->end() ? nullptr : *it;
return std::make_pair(bb, instr);
}, ret::reference)
.def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
ir::basic_block *bb = pt.first;
ir::instruction *instr = pt.second;
if (instr) {
if (bb != instr->get_parent())
throw std::runtime_error("invalid insertion point, instr not in bb");
self->set_insert_point(instr);
} else {
assert(bb);
self->set_insert_point(bb);
}
})
// Constants
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
.def("get_uint32", &ir::builder::get_int32, ret::reference)
.def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
.def("get_uint64", &ir::builder::get_int64, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference)
// Types
.def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
.def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
.def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
.def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
.def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
.def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
.def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
.def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
.def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference)
.def("get_float_ty", &ir::builder::get_float_ty, ret::reference)
.def("get_double_ty", &ir::builder::get_double_ty, ret::reference)
// terminator instructions
.def("create_br", &ir::builder::create_br, ret::reference)
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
// Cast instructions
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
.def("create_cast", &ir::builder::create_cast, ret::reference)
.def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
.def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
.def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
.def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
.def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
.def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
.def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
.def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
.def("create_downcast", &ir::builder::create_downcast, ret::reference)
// phi
.def("create_phi", &ir::builder::create_phi, ret::reference)
// Binary instructions
.def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
.def("create_fmul", &ir::builder::create_fmul, ret::reference)
.def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
.def("create_frem", &ir::builder::create_frem, ret::reference)
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
.def("create_mul", &ir::builder::create_mul, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
.def("create_srem", &ir::builder::create_srem, ret::reference)
.def("create_urem", &ir::builder::create_urem, ret::reference)
.def("create_add", &ir::builder::create_add, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sub", &ir::builder::create_sub, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_shl", &ir::builder::create_shl, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
// GEP
.def("create_gep", &ir::builder::create_gep, ret::reference)
// Comparison (int)
.def("create_icmp", &ir::builder::create_icmp, ret::reference)
.def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference)
.def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference)
.def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference)
.def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference)
.def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference)
.def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference)
.def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference)
.def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference)
.def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference)
.def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference)
// Comparison (float)
.def("create_fcmp", &ir::builder::create_fcmp, ret::reference)
.def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference)
.def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
.def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
.def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference)
.def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference)
.def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference)
.def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference)
.def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference)
.def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference)
.def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference)
.def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference)
.def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference)
// Logical
.def("create_and", &ir::builder::create_and, ret::reference)
.def("create_xor", &ir::builder::create_xor, ret::reference)
.def("create_or", &ir::builder::create_or, ret::reference)
// Input/Output
.def("create_load", &ir::builder::create_load, ret::reference)
.def("create_store", &ir::builder::create_store, ret::reference)
.def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
.def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
// Block instruction
.def("create_splat", &ir::builder::create_splat, ret::reference)
.def("create_reshape", &ir::builder::create_reshape, ret::reference)
.def("create_cat", &ir::builder::create_cat, ret::reference)
.def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
// atomic
.def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
.def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
// Utilities
.def("create_clock", &ir::builder::create_clock, ret::reference)
.def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
// Built-in instruction
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
.def("create_exp", &ir::builder::create_exp, ret::reference)
.def("create_cos", &ir::builder::create_cos, ret::reference)
.def("create_sin", &ir::builder::create_sin, ret::reference)
.def("create_log", &ir::builder::create_log, ret::reference)
.def("create_dot", &ir::builder::create_dot, ret::reference)
.def("create_trans", &ir::builder::create_trans, ret::reference)
.def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
.def("create_reduce", &ir::builder::create_reduce, ret::reference)
.def("create_select", &ir::builder::create_select, ret::reference)
// struct // struct
.def("insert_value", &ir::builder::create_insert_value, ret::reference) .def("insert_value", &ir::builder::create_insert_value, ret::reference)
.def("extract_value", &ir::builder::create_extract_value, ret::reference) .def("extract_value", &ir::builder::create_extract_value, ret::reference)
// constants // Intrinsics
.def("get_int1", &ir::builder::get_int1, ret::reference) // These have no place in the IR, and hopefully they can be removed at some point
.def("get_int32", &ir::builder::get_int32, ret::reference) .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
.def("get_int64", &ir::builder::get_int64, ret::reference) .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
.def("get_uint32", &ir::builder::get_uint32, ret::reference) .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
.def("get_uint64", &ir::builder::get_uint64, ret::reference) .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
.def("get_float16", &ir::builder::get_float16, ret::reference) .def("create_barrier", &ir::builder::create_barrier, ret::reference)
.def("get_float32", &ir::builder::get_float32, ret::reference) .def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
.def("get_range", &ir::builder::get_range, ret::reference); .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
} }
void init_triton(py::module &m) { void init_triton(py::module &m) {
@@ -806,5 +951,4 @@ void init_triton(py::module &m) {
init_triton_codegen(std::move(subm.def_submodule("code_gen"))); init_triton_codegen(std::move(subm.def_submodule("code_gen")));
init_triton_runtime(std::move(subm.def_submodule("runtime"))); init_triton_runtime(std::move(subm.def_submodule("runtime")));
init_triton_ir(std::move(subm.def_submodule("ir"))); init_triton_ir(std::move(subm.def_submodule("ir")));
init_triton_frontend(std::move(subm.def_submodule("frontend")));
} }
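For a sense of how this binding surface is driven from Python, here is a toy snippet (assuming a built triton extension module). Everything in it is illustrative; in practice the code generator shown later in this diff issues these calls while walking the kernel's AST.

# Toy usage of the pybind'ed IR builder, relying only on bindings shown above.
import triton._C.libtriton.triton as _triton

ctx = _triton.ir.context()
builder = _triton.ir.builder(ctx)
# the parent-function argument of basic_block.create defaults to None,
# which is enough for a stand-alone snippet
entry = _triton.ir.basic_block.create(ctx, "entry")
builder.set_insert_block(entry)

lhs = builder.get_int32(3)              # constant_int
rhs = builder.get_int32(4)
total = builder.create_add(lhs, rhs)    # an add instruction appended to `entry`
builder.ret_void()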

View File

@@ -37,7 +37,7 @@ matmul_data = {
(256, 256, 256): {'float16': 0.027}, (256, 256, 256): {'float16': 0.027},
(512, 512, 512): {'float16': 0.158}, (512, 512, 512): {'float16': 0.158},
(1024, 1024, 1024): {'float16': 0.466}, (1024, 1024, 1024): {'float16': 0.466},
(2048, 2048, 2048): {'float16': 0.680}, (2048, 2048, 2048): {'float16': 0.695},
(4096, 4096, 4096): {'float16': 0.831}, (4096, 4096, 4096): {'float16': 0.831},
(8192, 8192, 8192): {'float16': 0.849}, (8192, 8192, 8192): {'float16': 0.849},
# tall-skinny # tall-skinny

View File

@@ -1,5 +1,4 @@
# flake8: noqa: F821,F841 # flake8: noqa: F821,F841
import copy
import itertools import itertools
import re import re
from typing import Optional, Union from typing import Optional, Union
@@ -12,7 +11,7 @@ from numpy.random import RandomState
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
import triton.language as tl import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret from triton.code_gen import JITFunction, TensorWrapper, reinterpret
int_dtypes = ['int8', 'int16', 'int32', 'int64'] int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -993,11 +992,17 @@ def test_noop(device='cuda'):
@pytest.mark.parametrize("value, value_type", [ @pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, None), (-2**31, 'i32'), (2**31 - 1, 'i32'), (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'), (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
]) ])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None: def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
spec_type = None
def cache_hook(*args, **kwargs):
nonlocal spec_type
spec_type = kwargs["compile"]["arg_types"][0][1]
JITFunction.cache_hook = cache_hook
@triton.jit @triton.jit
def kernel(VALUE, X): def kernel(VALUE, X):
@@ -1006,11 +1011,8 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
x = torch.tensor([3.14159], device='cuda') x = torch.tensor([3.14159], device='cuda')
pgm = kernel[(1, )](value, x) pgm = kernel[(1, )](value, x)
# Parse out the type of the 'VALUE' parameter from the Triton IR. JITFunction.cache_hook = None
triton_ir = pgm.asm['ttir'] assert spec_type == value_type
ir_value_match = re.match(r'\s*def void (\w+)\((\w+) VALUE ', triton_ir)
ir_value_type = None if ir_value_match is None else ir_value_match.group(2)
assert ir_value_type == value_type
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -1045,13 +1047,13 @@ def stub(X, alpha, grid_0, grid_1, grid_2):
tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2]) tl.launch(mult, [X, alpha], [grid_0, grid_1, grid_2])
def test_dyn_par(cond=True, device='cuda'): # def test_dyn_par(cond=True, device='cuda'):
n_pids = 10 # n_pids = 10
# pids = torch.arange(n_pids, device=device) # # pids = torch.arange(n_pids, device=device)
# alpha = 2.0 # # alpha = 2.0
# x_ref = pids * alpha # # x_ref = pids * alpha
x_tri = torch.full((10,), fill_value=-1., device=device) # x_tri = torch.full((10,), fill_value=-1., device=device)
# cond = torch.tensor([cond], device=device) # # cond = torch.tensor([cond], device=device)
stub[(1,)](x_tri, 3.14, n_pids, 1, 1) # stub[(1,)](x_tri, 3.14, n_pids, 1, 1)
print(x_tri) # print(x_tri)
# triton.testing.assert_almost_equal(x_ref, x_tri) # # triton.testing.assert_almost_equal(x_ref, x_tri)

View File

@@ -1,4 +1,5 @@
import os import os
import re
import shutil import shutil
import pytest import pytest
@@ -102,3 +103,30 @@ def test_specialize(mode):
for i in [1, 2, 4, 8, 16, 32]: for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512) function[(1,)](x, i, BLOCK=512)
assert counter == target assert counter == target
@pytest.mark.parametrize("value, value_type", [
(-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'),
(2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'),
(2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
cache_str = None
def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs['key'].split('-')
triton.code_gen.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.code_gen.JITFunction.cache_hook = None
cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1])
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type

View File

@@ -6,7 +6,8 @@ __version__ = '2.0.0'
# or pybind11 shows `munmap_chunk(): invalid pointer` # or pybind11 shows `munmap_chunk(): invalid pointer`
import torch import torch
# submodules # submodules
from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, JITFunction, Config, Autotuner, reinterpret from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \
JITFunction, Config, Autotuner, reinterpret
from . import language from . import language
from . import code_gen from . import code_gen
from . import testing from . import testing

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import ast import ast
import builtins import builtins
import functools import functools
@@ -11,7 +13,7 @@ import tempfile
import textwrap import textwrap
import time import time
import warnings import warnings
from typing import Dict from typing import Dict, Set, Tuple, Union
import torch import torch
from filelock import FileLock from filelock import FileLock
@@ -21,26 +23,26 @@ import triton._C.libtriton.triton as _triton
from .tools.disasm import extract from .tools.disasm import extract
def mangle_ty(type): def mangle_ty(ty):
if type.is_ptr(): if ty.is_ptr():
return 'P' + mangle_ty(type.element) return 'P' + mangle_ty(ty.element_ty)
if type.is_int(): if ty.is_int():
return 'i' + str(type.get_int_width()) return 'i' + str(ty.int_bitwidth)
if type.is_fp8(): if ty.is_fp8():
return 'fp8' return 'fp8'
if type.is_fp16(): if ty.is_fp16():
return 'fp16' return 'fp16'
if type.is_bf16(): if ty.is_bf16():
return 'bf16' return 'bf16'
if type.is_fp32(): if ty.is_fp32():
return 'fp32' return 'fp32'
if type.is_fp64(): if ty.is_fp64():
return 'fp64' return 'fp64'
if type.is_void(): if ty.is_void():
return 'V' return 'V'
if type.is_block(): if ty.is_block():
elt = mangle_ty(type.scalar) elt = mangle_ty(ty.scalar)
shape = '_'.join(map(str, type.shape)) shape = '_'.join(map(str, ty.shape))
return f'{elt}S{shape}S' return f'{elt}S{shape}S'
assert False, "Unsupport type" assert False, "Unsupport type"
@@ -56,8 +58,38 @@ def mangle_fn(name, arg_tys, constants):
return ret return ret
class CodeGenerator(ast.NodeVisitor): def is_triton_tensor(value):
return isinstance(value, triton.language.tensor)
class ValueConstructor:
def __init__(self, module, builder, gscope) -> None:
self.gscope = gscope
self.lscope = dict()
self.builder = builder
self.module = module
# [name, bb] => triton.language.tensor
self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {}
# bb => {name => phi}
self.incomplete_phis = {}
self.sealed_blocks: Set[_triton.ir.basic_block] = set()
#
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'isinstance': isinstance,
'getattr': getattr,
}
def get_value(self, name): def get_value(self, name):
''' This function:
1. makes sure `name` is defined
2. if `name` is a triton.language.tensor, gets the stored tensor by calling
`self._get_tensor()`
'''
# search node.id in local scope # search node.id in local scope
ret = None ret = None
if name in self.lscope: if name in self.lscope:
@@ -70,21 +102,123 @@ class CodeGenerator(ast.NodeVisitor):
ret = self.builtins[name] ret = self.builtins[name]
else: else:
raise ValueError(f'{name} is not defined') raise ValueError(f'{name} is not defined')
if isinstance(ret, triton.language.block): if is_triton_tensor(ret):
handle = self.value_constructor.get_value(name) return self._get_tensor(name, self.builder.get_insert_block())
return triton.language.block(handle)
return ret return ret
def set_value(self, name, value): def set_value(self, name: str,
if isinstance(value, _triton.ir.value): value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
value = triton.language.block(value) ''' This function:
if isinstance(value, triton.language.block): called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
self.value_constructor.set_value(name, value.handle) 1. record local defined name (FIXME: should consider control flow)
self.value_constructor.set_type(name, value.handle.type) 2. store tensor in self.lvalues
'''
self.lscope[name] = value self.lscope[name] = value
if isinstance(value, triton.language.tensor):
self._set_value(name, self.builder.get_insert_block(), value)
def is_triton_object(self, value): #
return isinstance(value, triton.language.block) # SSA-construction
#
def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
# local value numbering
if (name, bb) in self.lvalues:
return self.lvalues[(name, bb)]
# global value numbering
saved_insert_point = self.builder.get_insert_point()
result = self._get_tensor_recursive(name, bb)
self.builder.set_insert_point(saved_insert_point)
return result
def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor:
preds = bb.get_predecessors()
type = self.lscope[name].type
# some preds haven't been filled, create a phi as a proxy of the value
if bb not in self.sealed_blocks:
result = self._make_phi(type, len(preds), bb)
if bb in self.incomplete_phis:
self.incomplete_phis[bb][name] = result
else:
self.incomplete_phis[bb] = {name: result}
elif len(preds) == 1:
# one predecessor: no phi needed, try to get the value from the pred
result = self._get_tensor(name, preds[0])
elif len(preds) == 0:
result = self._get_tensor(name, None)
else: # multiple preds
phi = self._make_phi(type, len(preds), bb)
self._set_value(name, bb, phi)
result = self._add_phi_operands(name, phi)
self._set_value(name, bb, result)
return result
# returns a new phi tensor, which encapsulates an ir.phi_node
def _make_phi(self,
type: triton.language.dtype,
num_values: int,
bb: _triton.ir.basic_block) -> triton.language.tensor:
instr = bb.get_first_non_phi()
self.builder.set_insert_point((bb, instr))
ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values)
if instr:
self.builder.set_insert_block(bb)
return triton.language.tensor(ir_phi, type)
# complete a phi node. (TODO: rename this as _complete_phis?)
# Note: since we try to remove trivial phis, the returned tensor might not be a phi
def _add_phi_operands(self, name: str,
phi: triton.language.tensor) -> triton.language.tensor:
bb = phi.handle.get_parent()
for pred in bb.get_predecessors():
v = self._get_tensor(name, pred)
phi.handle.add_incoming(v.handle, pred)
phi = self._try_remove_trivial_phi(phi)
return phi
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
self.lvalues[(name, bb)] = value
# TODO: why do we need this?
self.module.set_instr_metadata(name, value.handle)
def _seal_block(self, bb: _triton.ir.basic_block):
# complete all incomplete phis
if bb in self.incomplete_phis:
for name, phi in self.incomplete_phis[bb].items():
result = self._add_phi_operands(name, phi)
# it's possible that this phi is trivial
if self._get_tensor(name, bb).handle == phi.handle:
self._set_value(name, bb, result)
del self.incomplete_phis[bb]
self.sealed_blocks.add(bb)
def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor:
unique_handles = {op for op in phi.handle.ops() if op != phi.handle}
if len(unique_handles) != 1: # non-trivial phi
return phi
v = unique_handles.pop()
phi.handle.replace_all_uses_with(v)
phi.handle.erase_from_parent()
# TODO: remove trivial phis recursively
return triton.language.tensor(v, phi.type)
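The helpers above implement the usual on-the-fly SSA construction scheme: local/global value numbering, incomplete phis for blocks whose predecessors are not all known yet, block sealing, and trivial-phi elimination. Below is a self-contained toy version with plain Python objects standing in for Triton's blocks and phi nodes; every name in it is illustrative and none of it is part of the codebase.

class Block:
    def __init__(self, preds=()):
        self.preds = list(preds)

class Phi:
    def __init__(self, block):
        self.block = block
        self.incoming = []                    # (value, predecessor) pairs

class ToySSA:
    def __init__(self):
        self.defs = {}                        # (name, block) -> value
        self.incomplete = {}                  # block -> {name: Phi}
        self.sealed = set()

    def write(self, name, block, value):
        self.defs[(name, block)] = value

    def read(self, name, block):
        if (name, block) in self.defs:        # local value numbering
            return self.defs[(name, block)]
        if block not in self.sealed:          # predecessors not all known yet
            value = Phi(block)
            self.incomplete.setdefault(block, {})[name] = value
        elif len(block.preds) == 1:           # single predecessor: no phi needed
            value = self.read(name, block.preds[0])
        else:                                 # several predecessors: place a phi
            phi = Phi(block)
            self.write(name, block, phi)      # break cycles before recursing
            for pred in block.preds:
                phi.incoming.append((self.read(name, pred), pred))
            value = self.try_remove_trivial(phi)
        self.write(name, block, value)
        return value

    def try_remove_trivial(self, phi):
        # a phi that only ever sees one distinct value (besides itself) is a no-op
        ops = {v for v, _ in phi.incoming if v is not phi}
        return ops.pop() if len(ops) == 1 else phi

    def seal(self, block):
        for name, phi in self.incomplete.pop(block, {}).items():
            for pred in block.preds:
                phi.incoming.append((self.read(name, pred), pred))
            # the real code also rewrites existing uses of a removed phi
            self.write(name, block, self.try_remove_trivial(phi))
        self.sealed.add(block)

entry = Block(); body = Block(preds=[entry])
ssa = ToySSA()
ssa.seal(entry)
ssa.write("x", entry, 1)
ssa.seal(body)
assert ssa.read("x", body) == 1               # the definition flows through entry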
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, prototypes=None, module=None, is_kernel=False):
self.prototypes = dict() if prototypes is None else prototypes
self.builder = _triton.ir.builder(context)
self.module = _triton.ir.module('', self.builder) if module is None else module
self.prototype = prototype
self.attributes = attributes
self.constants = constants
self.last_node = None
self.is_kernel = is_kernel
self.value_constructor = ValueConstructor(self.module, self.builder, gscope)
#
# AST visitor
#
def visit_compound_statement(self, stmts): def visit_compound_statement(self, stmts):
for stmt in stmts: for stmt in stmts:
@@ -93,27 +227,6 @@ class CodeGenerator(ast.NodeVisitor):
break break
return stmts and isinstance(stmt, ast.Return) return stmts and isinstance(stmt, ast.Return)
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False):
self.builder = _triton.ir.builder(context)
self.value_constructor = _triton.ir.value_constructor(self.builder)
self.module = _triton.ir.module('', self.builder) if module is None else module
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.attributes = attributes
self.constants = constants
self.last_node = None
self.is_kernel = is_kernel
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'isinstance': isinstance,
'getattr': getattr,
}
def visit_Module(self, node): def visit_Module(self, node):
ast.NodeVisitor.generic_visit(self, node) ast.NodeVisitor.generic_visit(self, node)
@@ -127,16 +240,10 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Return(self, node): def visit_Return(self, node):
ret = self.visit(node.value) ret = self.visit(node.value)
if ret is None: if ret is None:
return self.builder.ret_void() return triton.language.tensor(self.builder.ret_void(), triton.language.void)
if isinstance(ret, _triton.ir.value): ret = triton.language.core._to_tensor(ret, self.builder)
ret = self.builder.ret(ret) ret = triton.language.tensor(self.builder.ret(ret.handle), ret.type)
return ret return ret
if isinstance(ret, triton.language.block):
ret = ret.handle
if isinstance(ret, triton.language.constexpr):
ret = triton.language.core._to_ir(ret, self.builder)
# TODO: should return tl.block
return self.builder.ret(ret)
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
arg_names, kwarg_names = self.visit(node.args) arg_names, kwarg_names = self.visit(node.args)
@@ -152,8 +259,9 @@ class CodeGenerator(ast.NodeVisitor):
init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
self.visit(init_node) self.visit(init_node)
# initialize function # initialize function
fn_name = mangle_fn(node.name, self.prototype.arg_tys, self.constants) fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants)
fn = self.module.get_or_insert_function(fn_name, self.prototype) self.prototypes[fn_name] = self.prototype
fn = self.module.get_or_insert_function(fn_name, self.prototype.to_ir(self.builder))
fn.set_is_kernel(self.is_kernel) fn.set_is_kernel(self.is_kernel)
arg_values = [] arg_values = []
idx = 0 idx = 0
@@ -171,23 +279,24 @@ class CodeGenerator(ast.NodeVisitor):
attr = _triton.ir.attribute(attr, self.attributes[i]) attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(idx + 1, attr) fn.add_attr(idx + 1, attr)
fn.args[idx].name = arg_name fn.args[idx].name = arg_name
arg_values.append(fn.args[idx]) arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
idx += 1 idx += 1
insert_pt = self.builder.get_insert_block() insert_pt = self.builder.get_insert_block()
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.builder.set_insert_block(entry) self.builder.set_insert_block(entry)
self.value_constructor.seal_block(entry) self.value_constructor._seal_block(entry)
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value) self.value_constructor.set_value(arg_name, arg_value)
# visit function body # visit function body
has_ret = self.visit_compound_statement(node.body) has_ret = self.visit_compound_statement(node.body)
# finalize # finalize
if not has_ret: if not has_ret:
self.builder.ret_void() self.builder.ret_void()
else: else:
self.module.reset_ret_ty(fn_name, self.last_ret.type) # a bit hacky: we only know the return type at the last moment so we update type info here
# self.module.reset_ret_type(node.name) self.module.reset_ret_ty(fn_name, self.last_ret.type.to_ir(self.builder))
self.prototype.ret_type = self.last_ret.type
self.builder.set_insert_block(insert_pt) self.builder.set_insert_block(insert_pt)
def visit_arguments(self, node): def visit_arguments(self, node):
@@ -208,13 +317,13 @@ class CodeGenerator(ast.NodeVisitor):
value = self.visit(node.value)
# constexpr
if annotation == triton.language.constexpr:
-if target in self.lscope:
+if target in self.value_constructor.lscope:
raise ValueError(f'{target} is already defined.'
f' constexpr cannot be reassigned.')
if not isinstance(value, triton.language.constexpr):
value = triton.language.constexpr(value)
-self.lscope[target] = value
-return self.lscope[target]
+self.value_constructor.lscope[target] = value
+return self.value_constructor.lscope[target]
# default: call visit_Assign
return self.visit_Assign(node)
@@ -229,19 +338,21 @@ class CodeGenerator(ast.NodeVisitor):
names = [names]
if not isinstance(values, tuple):
values = [values]
-if isinstance(values[0], _triton.ir.value):
-struct = values[0]
-ty = struct.type
-if ty.is_struct():
-values = [self.builder.extract_value(struct, i) for i in range(ty.num_types)]
+if isinstance(values[0], triton.language.tensor) \
+and isinstance(values[0].type, triton.language.tuple_type):
+struct = values[0].handle
+tys = values[0].type.element_types
+values = [self.builder.extract_value(struct, i) for i in range(len(tys))]
+values = [triton.language.tensor(v, ty) for v, ty in zip(values, tys)]
assert len(values) == len(names)
for name, value in zip(names, values):
+# TODO: can we store constexpr here to support constant folding?
# by default, constexpr are assigned into python variable
if isinstance(value, triton.language.constexpr):
value = value.value
-if not isinstance(value, triton.language.block):
-value = triton.language.core._to_ir(value, self.builder)
-self.set_value(name, value)
+if not isinstance(value, triton.language.tensor):
+value = triton.language.core._to_tensor(value, self.builder)
+self.value_constructor.set_value(name, value)
def visit_AugAssign(self, node):
name = node.target.id
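The tuple-unpacking hunk above relies on a pack/unpack pattern: several tensors are packed into a single IR aggregate with insert_value and later split back out with extract_value, while the element types are carried on the Python side by a tuple_type. A toy, self-contained model of that pattern, assuming nothing about Triton's real API:

    # a plain list stands in for an IR aggregate value
    def pack(values):
        agg = [None] * len(values)              # undef aggregate
        for i, v in enumerate(values):
            agg[i] = v                          # plays the role of insert_value(agg, v, i)
        element_types = [type(v).__name__ for v in values]   # plays the role of tuple_type
        return agg, element_types

    def unpack(agg, element_types):
        return [agg[i] for i in range(len(element_types))]   # plays the role of extract_value(agg, i)

    agg, tys = pack([1.0, 2, True])
    assert unpack(agg, tys) == [1.0, 2, True]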
@@ -249,12 +360,12 @@ class CodeGenerator(ast.NodeVisitor):
rhs = ast.BinOp(lhs, node.op, node.value)
assign = ast.Assign(targets=[node.target], value=rhs)
self.visit(assign)
-return self.get_value(name)
+return self.value_constructor.get_value(name)
def visit_Name(self, node):
if type(node.ctx) == ast.Store:
return node.id
-return self.get_value(node.id)
+return self.value_constructor.get_value(node.id)
def visit_Store(self, node):
ast.NodeVisitor.generic_visit(self, node)
@@ -266,23 +377,22 @@ class CodeGenerator(ast.NodeVisitor):
args = [self.visit(x) for x in node.elts]
mode = type(args[0])
# tuple of values -- create a struct
-if len(args) > 1 and mode == triton.language.block\
+if len(args) > 1 and mode == triton.language.tensor\
and all([type(arg) == mode for arg in args]):
-args = [arg.handle for arg in args]
-tys = [arg.type for arg in args]
-struct_ty = _triton.ir.struct_type.get(tys, True)
-ret = _triton.ir.undef.get(struct_ty)
+tuple_ty = triton.language.tuple_type([arg.type for arg in args])
+ret = _triton.ir.undef.get(tuple_ty.to_ir(self.builder))
for i, arg in enumerate(args):
-ret = self.builder.insert_value(ret, arg, i)
+ret = self.builder.insert_value(ret, arg.handle, i)
+ret = triton.language.tensor(ret, tuple_ty)
return ret
return tuple(args)
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
-if isinstance(lhs, triton.language.core.constexpr):
+if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
-if isinstance(rhs, triton.language.core.constexpr):
+if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.Add: '__add__',
@@ -298,9 +408,9 @@ class CodeGenerator(ast.NodeVisitor):
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}[type(node.op)]
-if self.is_triton_object(lhs):
+if is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
-elif self.is_triton_object(rhs):
+elif is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
@@ -308,15 +418,15 @@ class CodeGenerator(ast.NodeVisitor):
def visit_If(self, node):
cond = self.visit(node.test)
-if isinstance(cond, triton.language.block):
+if isinstance(cond, triton.language.tensor):
cond = cond.to(triton.language.int1, _builder=self.builder)
current_bb = self.builder.get_insert_block()
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
-self.value_constructor.seal_block(then_bb)
+self.value_constructor._seal_block(then_bb)
if else_bb:
-self.value_constructor.seal_block(else_bb)
+self.value_constructor._seal_block(else_bb)
self.builder.cond_br(cond.handle, then_bb, else_bb)
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
@@ -331,7 +441,7 @@ class CodeGenerator(ast.NodeVisitor):
# TODO: last statement is a terminator?
if not is_terminator:
self.builder.br(endif_bb)
-self.value_constructor.seal_block(endif_bb)
+self.value_constructor._seal_block(endif_bb)
self.builder.set_insert_block(endif_bb)
else:
if isinstance(cond, triton.language.constexpr):
@@ -356,9 +466,9 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.ops) == 1
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
-if isinstance(lhs, triton.language.core.constexpr):
+if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
-if isinstance(rhs, triton.language.core.constexpr):
+if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
if type(node.ops[0]) == ast.Is:
return triton.language.constexpr(lhs is rhs)
@@ -372,9 +482,9 @@ class CodeGenerator(ast.NodeVisitor):
ast.Gt: '__gt__',
ast.GtE: '__ge__',
}[type(node.ops[0])]
-if self.is_triton_object(lhs):
+if is_triton_tensor(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
-elif self.is_triton_object(rhs):
+elif is_triton_tensor(rhs):
fn = fn[:2] + 'r' + fn[2:]
return getattr(rhs, fn)(lhs, _builder=self.builder)
else:
@@ -385,21 +495,21 @@ class CodeGenerator(ast.NodeVisitor):
if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op)
-if isinstance(op, triton.language.core.constexpr):
+if isinstance(op, triton.language.constexpr):
op = op.value
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
ast.Invert: '__invert__',
}[type(node.op)]
-if self.is_triton_object(op):
+if is_triton_tensor(op):
return getattr(op, fn)(_builder=self.builder)
return getattr(op, fn)()
def visit_While(self, node):
current_bb = self.builder.get_insert_block()
-loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
-next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
+loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
+next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
def continue_fn():
cond = self.visit(node.test)
@@ -410,9 +520,9 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
continue_fn()
stop_bb = self.builder.get_insert_block()
-self.value_constructor.seal_block(stop_bb)
-self.value_constructor.seal_block(loop_bb)
-self.value_constructor.seal_block(next_bb)
+self.value_constructor._seal_block(stop_bb)
+self.value_constructor._seal_block(loop_bb)
+self.value_constructor._seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
@@ -422,7 +532,7 @@ class CodeGenerator(ast.NodeVisitor):
assert node.ctx.__class__.__name__ == "Load"
lhs = self.visit(node.value)
slices = self.visit(node.slice)
-if self.is_triton_object(lhs):
+if is_triton_tensor(lhs):
return lhs.__getitem__(slices, _builder=self.builder)
return lhs[slices]
@@ -431,7 +541,7 @@ class CodeGenerator(ast.NodeVisitor):
def visit_For(self, node):
iterator = self.visit(node.iter.func)
-if iterator != self.builtins['range']:
+if iterator != self.value_constructor.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# static for loops: all iterator arguments are constexpr
iter_args = [self.visit(arg) for arg in node.iter.args]
@@ -442,7 +552,7 @@ class CodeGenerator(ast.NodeVisitor):
range = iterator(*iter_args)
if len(range) <= 10:
for i in iterator(*iter_args):
-self.lscope[node.target.id] = triton.language.constexpr(i)
+self.value_constructor.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
@@ -465,8 +575,8 @@ class CodeGenerator(ast.NodeVisitor):
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation
current_bb = self.builder.get_insert_block()
-loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
-next_bb = _triton.ir.basic_block.create(self.module.builder.context, "postloop", current_bb.parent)
+loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
+next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
def continue_fn():
self.visit(step_node)
@@ -481,9 +591,9 @@ class CodeGenerator(ast.NodeVisitor):
# TODO: handle case where body breaks control flow
continue_fn()
stop_bb = self.builder.get_insert_block()
-self.value_constructor.seal_block(stop_bb)
-self.value_constructor.seal_block(loop_bb)
-self.value_constructor.seal_block(next_bb)
+self.value_constructor._seal_block(stop_bb)
+self.value_constructor._seal_block(loop_bb)
+self.value_constructor._seal_block(next_bb)
self.builder.set_insert_block(next_bb)
for stmt in node.orelse:
@@ -514,7 +624,7 @@ class CodeGenerator(ast.NodeVisitor):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
-args = [arg if isinstance(arg, triton.language.block)
+args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
@@ -523,25 +633,24 @@ class CodeGenerator(ast.NodeVisitor):
# generate call
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
arg_vals = [arg.handle for arg in args if arg is not None]
-arg_types = [arg.type for arg in arg_vals]
+arg_types = [arg.type for arg in args if arg is not None]
fn_name = mangle_fn(fn.__name__, arg_types, constants)
# generate function def if necessary
if not self.module.has_function(fn_name):
-ret_type = _triton.ir.type.get_void(self.builder.context)
-prototype = _triton.ir.type.make_function(ret_type, arg_types)
+ret_type = triton.language.void
+prototype = triton.language.function_type(ret_type, arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
-generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module)
+generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, prototypes=self.prototypes, module=self.module)
generator.visit(fn.parse())
symbol = self.module.get_function(fn_name)
ret = self.builder.call(symbol, arg_vals)
-if not ret.type.is_void() and not ret.type.is_struct():
-ret = triton.language.block(ret)
+if not ret.type.is_void():
+ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type)
return ret
# built-in function
-if hasattr(fn, '__self__') and self.is_triton_object(fn.__self__) or \
-sys.modules[fn.__module__] is triton.language.core:
+if sys.modules[fn.__module__] is triton.language.core:
ret = fn(*args, _builder=self.builder, **kws)
-if fn in self.builtins.values():
+if fn in self.value_constructor.builtins.values():
args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
for arg in args]
ret = fn(*args, **kws)
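In the call-lowering hunk above, the generator now records each callee's Python-level prototype in self.prototypes, keyed by the mangled function name, so that the raw IR return value can be re-wrapped as a triton.language.tensor carrying the callee's ret_type. A small sketch of that bookkeeping; the mangle/declare helpers here are hypothetical stand-ins for mangle_fn and the surrounding logic, not Triton's actual implementation:

    # hypothetical mangling scheme for illustration only -- not Triton's mangle_fn
    def mangle(name, arg_type_names):
        return name + "__" + "_".join(arg_type_names)

    prototypes = {}                                  # mangled name -> (ret_type, arg_types)

    def declare(name, arg_types, ret_type="void"):
        fn_name = mangle(name, arg_types)
        prototypes[fn_name] = (ret_type, arg_types)
        return fn_name

    fn_name = declare("helper", ["ptr_fp32", "i32"])
    ret_type, _ = prototypes[fn_name]                # later used to re-type the call result
    assert ret_type == "void"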
@@ -698,7 +807,7 @@ class Kernel:
}
if hasattr(obj, 'data_ptr'):
return type_names[obj.dtype]
-if isinstance(obj, triton.language.core.constexpr):
+if isinstance(obj, triton.language.constexpr):
obj = obj.value
if isinstance(obj, int):
if -2**31 <= obj < 2**31:
@@ -730,34 +839,34 @@ class Kernel:
return 'scalar', name
@staticmethod
-def _to_triton_ir(context, obj):
+def _to_triton_ir(obj):
which, name = obj
type_map = {
-'I': _triton.ir.type.get_int32,
-'L': _triton.ir.type.get_int64,
-'f': _triton.ir.type.get_fp32,
-'B': _triton.ir.type.get_int1,
-'f8': _triton.ir.type.get_fp8,
-'f16': _triton.ir.type.get_fp16,
-'bf16': _triton.ir.type.get_bf16,
-'f32': _triton.ir.type.get_fp32,
-'f64': _triton.ir.type.get_fp64,
-'i1': _triton.ir.type.get_int1,
-'i8': _triton.ir.type.get_int8,
-'i16': _triton.ir.type.get_int16,
-'i32': _triton.ir.type.get_int32,
-'i64': _triton.ir.type.get_int64,
-'u8': _triton.ir.type.get_uint8,
-'u16': _triton.ir.type.get_uint16,
-'u32': _triton.ir.type.get_uint32,
-'u64': _triton.ir.type.get_uint64,
+'I': triton.language.int32,
+'L': triton.language.int64,
+'f': triton.language.float32,
+'B': triton.language.int1,
+'f8': triton.language.float8,
+'f16': triton.language.float16,
+'bf16': triton.language.bfloat16,
+'f32': triton.language.float32,
+'f64': triton.language.float64,
+'i1': triton.language.int1,
+'i8': triton.language.int8,
+'i16': triton.language.int16,
+'i32': triton.language.int32,
+'i64': triton.language.int64,
+'u8': triton.language.uint8,
+'u16': triton.language.uint16,
+'u32': triton.language.uint32,
+'u64': triton.language.uint64,
}
# convert torch.Tensor to Triton IR pointers
if which == 'ptr':
-elt_ty = type_map[name](context)
-return _triton.ir.type.make_ptr(elt_ty, 1)
+elt_ty = type_map[name]
+return triton.language.pointer_type(elt_ty, 1)
# default path returns triton.ir.type directly
-return type_map[name](context)
+return type_map[name]
@staticmethod
def pow2_divisor(N):
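The hunk above replaces the C++ type factories (which needed an IR context) with plain triton.language dtype objects, so _to_triton_ir becomes a context-free lookup plus an optional pointer_type wrapper. A rough, self-contained model of that lookup, using strings in place of the real dtype objects:

    # illustrative, context-free mapping from launch-time type codes to frontend dtypes
    TYPE_MAP = {
        'I': 'int32', 'L': 'int64', 'f': 'float32', 'B': 'int1',
        'f16': 'float16', 'bf16': 'bfloat16', 'f32': 'float32', 'f64': 'float64',
        'i8': 'int8', 'i32': 'int32', 'u32': 'uint32',   # remaining widths elided
    }

    def to_frontend_type(descriptor):
        which, name = descriptor              # e.g. ('ptr', 'f32') or ('scalar', 'i32')
        elt = TYPE_MAP[name]
        return f'pointer<{elt}>' if which == 'ptr' else elt

    assert to_frontend_type(('ptr', 'f32')) == 'pointer<float32>'
    assert to_frontend_type(('scalar', 'i32')) == 'int32'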
@@ -1121,9 +1230,9 @@ class JITFunction:
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
-arg_types = [Kernel._to_triton_ir(context, arg) for arg in arg_types]
-ret_type = _triton.ir.type.get_void(context)
-prototype = _triton.ir.type.make_function(ret_type, arg_types)
+arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
+ret_type = triton.language.void
+prototype = triton.language.function_type(ret_type, arg_types)
# generate Triton-IR
# export symbols visible from self into code-generator object
gscope = self.__globals__
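With the pieces above, kernel compilation builds its prototype entirely at the Python level: argument descriptors are mapped to frontend dtypes, combined into a triton.language.function_type with a void return type, and handed to the code generator, which lowers types to IR only when the function is actually emitted. A toy sketch of that flow, with FunctionType as an illustrative stand-in rather than the real class:

    from typing import List, NamedTuple

    class FunctionType(NamedTuple):           # stand-in for triton.language.function_type
        ret_type: str
        param_types: List[str]

    # frontend dtypes as produced at launch time (e.g. by a _to_triton_ir-style lookup)
    param_types = ['pointer<float32>', 'int32']
    prototype = FunctionType('void', param_types)
    # this Python-level prototype, not an _triton.ir type, is what the code generator receives
    assert prototype.ret_type == 'void'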

File diff suppressed because it is too large
File diff suppressed because it is too large