Merge branch 'master' into rcom52_fixes
This commit is contained in:
@@ -12,7 +12,9 @@ namespace ir {
|
||||
class phi_node;
|
||||
class splat_inst;
|
||||
class cast_inst;
|
||||
class cmp_inst;
|
||||
class reshape_inst;
|
||||
class dequantize_inst;
|
||||
class broadcast_inst;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
@@ -33,8 +35,10 @@ private:
|
||||
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
|
||||
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
|
||||
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_default(ir::value* v);
|
||||
std::vector<cst_info> populate_is_constant(ir::value *v);
|
||||
@@ -42,6 +46,7 @@ private:
|
||||
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||
@@ -52,6 +57,7 @@ private:
|
||||
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
@@ -65,6 +71,7 @@ public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get(ir::value* v, unsigned ax) const;
|
||||
std::vector<unsigned> contiguous(ir::value* v) const;
|
||||
std::vector<cst_info> get_cst_info(ir::value* v) const;
|
||||
|
||||
private:
|
||||
std::map<ir::value*, std::vector<cst_info>> is_constant_;
|
||||
|
@@ -25,6 +25,7 @@ private:
|
||||
void update_graph_reduce(ir::instruction *i);
|
||||
void update_graph_reshape(ir::instruction *i);
|
||||
void update_graph_trans(ir::instruction *i);
|
||||
void update_graph_dequantize(ir::instruction *i);
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i,
|
||||
|
@@ -103,12 +103,70 @@ public:
|
||||
|
||||
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
|
||||
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
|
||||
virtual int contig_per_thread(size_t k) = 0;
|
||||
|
||||
protected:
|
||||
std::vector<int> shape_per_cta_;
|
||||
};
|
||||
|
||||
class mma_layout: public distributed_layout {
|
||||
public:
|
||||
enum TensorCoreType : uint8_t {
|
||||
// floating-point tensor core instr
|
||||
FP32_FP16_FP16_FP32 = 0, // default
|
||||
FP32_BF16_BF16_FP32,
|
||||
FP32_TF32_TF32_FP32,
|
||||
// integer tensor core instr
|
||||
INT32_INT1_INT1_INT32, // Not implemented
|
||||
INT32_INT4_INT4_INT32, // Not implemented
|
||||
INT32_INT8_INT8_INT32, // Not implemented
|
||||
//
|
||||
NOT_APPLICABLE,
|
||||
};
|
||||
|
||||
// Used on nvidia GPUs with sm >= 80
|
||||
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
|
||||
{FP32_FP16_FP16_FP32, {16, 8, 16}},
|
||||
{FP32_BF16_BF16_FP32, {16, 8, 16}},
|
||||
{FP32_TF32_TF32_FP32, {16, 8, 8}},
|
||||
|
||||
{INT32_INT1_INT1_INT32, {16, 8, 256}},
|
||||
{INT32_INT4_INT4_INT32, {16, 8, 64}},
|
||||
{INT32_INT8_INT8_INT32, {16, 8, 32}},
|
||||
};
|
||||
|
||||
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
|
||||
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
|
||||
{FP32_FP16_FP16_FP32, {8, 8, 8}},
|
||||
{FP32_BF16_BF16_FP32, {8, 8, 8}},
|
||||
{FP32_TF32_TF32_FP32, {8, 8, 4}},
|
||||
|
||||
{INT32_INT1_INT1_INT32, {8, 8, 64}},
|
||||
{INT32_INT4_INT4_INT32, {8, 8, 32}},
|
||||
{INT32_INT8_INT8_INT32, {8, 8, 16}},
|
||||
};
|
||||
|
||||
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
|
||||
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
|
||||
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
|
||||
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
|
||||
|
||||
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
|
||||
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
|
||||
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
|
||||
};
|
||||
|
||||
// vector length per ldmatrix (16*8/elelment_size_in_bits)
|
||||
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
|
||||
{FP32_FP16_FP16_FP32, 8},
|
||||
{FP32_BF16_BF16_FP32, 8},
|
||||
{FP32_TF32_TF32_FP32, 4},
|
||||
|
||||
{INT32_INT1_INT1_INT32, 128},
|
||||
{INT32_INT4_INT4_INT32, 32},
|
||||
{INT32_INT8_INT8_INT32, 16},
|
||||
};
|
||||
|
||||
public:
|
||||
mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
@@ -116,13 +174,25 @@ public:
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target *tgt,
|
||||
shared_layout* layout_a,
|
||||
shared_layout* layout_b);
|
||||
shared_layout* layout_b,
|
||||
ir::value *dot);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
|
||||
// accessor
|
||||
int fpw(size_t k) { return fpw_.at(k); }
|
||||
int wpt(size_t k) { return wpt_.at(k); }
|
||||
int spw(size_t k) { return spw_.at(k); }
|
||||
int rep(size_t k) { return rep_.at(k); }
|
||||
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
|
||||
|
||||
// helpers for generator.cc
|
||||
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
|
||||
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
|
||||
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
|
||||
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
|
||||
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
|
||||
|
||||
// setter
|
||||
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
|
||||
|
||||
private:
|
||||
// fragment per warp
|
||||
@@ -135,6 +205,10 @@ private:
|
||||
std::vector<int> spt_;
|
||||
// repetitions
|
||||
std::vector<int> rep_;
|
||||
// contiguous per thread
|
||||
std::vector<int> contig_per_thread_;
|
||||
|
||||
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
|
||||
};
|
||||
|
||||
class scanline_layout: public distributed_layout {
|
||||
@@ -149,7 +223,9 @@ public:
|
||||
// accessor
|
||||
int mts(size_t k) { return mts_.at(k); }
|
||||
int nts(size_t k) { return nts_.at(k); }
|
||||
int contig_per_thread(size_t k) { return nts_.at(k); }
|
||||
|
||||
int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);}
|
||||
private:
|
||||
// micro tile size. The size of a tile held by a thread block.
|
||||
std::vector<int> mts_;
|
||||
@@ -170,7 +246,7 @@ struct N_buffer_info_t {
|
||||
std::map<ir::value*, int> firsts_idx;
|
||||
};
|
||||
|
||||
// abstract for dot and coresponding smem values
|
||||
// abstract for dot and corresponding smem values
|
||||
class shared_layout: public data_layout {
|
||||
private:
|
||||
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
|
||||
@@ -183,7 +259,8 @@ public:
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values_,
|
||||
ir::type *ty,
|
||||
analysis::align* align);
|
||||
analysis::align* align, target *tgt,
|
||||
bool is_tmp = false);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||
// accessors
|
||||
size_t get_size() { return size_; }
|
||||
@@ -198,7 +275,10 @@ public:
|
||||
ir::value* hmma_dot_b() { return hmma_dot_b_; }
|
||||
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
|
||||
int get_mma_vec() { return mma_vec_;}
|
||||
int get_mma_strided() { return mma_strided_; }
|
||||
bool allow_swizzle() const { return allow_swizzle_; }
|
||||
data_layout* get_arg_layout() { return arg_layout_; }
|
||||
bool is_tmp() const { return is_tmp_; }
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
@@ -210,6 +290,10 @@ private:
|
||||
ir::value* hmma_dot_b_;
|
||||
data_layout* arg_layout_;
|
||||
int mma_vec_;
|
||||
int mma_strided_;
|
||||
bool allow_swizzle_ = true;
|
||||
target *tgt_;
|
||||
bool is_tmp_;
|
||||
};
|
||||
|
||||
|
||||
@@ -228,13 +312,20 @@ private:
|
||||
|
||||
void create(size_t id, const std::vector<ir::value*>& values);
|
||||
|
||||
public:
|
||||
void create_tmp_layout(size_t id, data_layout* arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
ir::instruction* i,
|
||||
bool is_index = false);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
|
||||
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||
size_t num_layouts() const { return values_.size();}
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
@@ -242,7 +333,19 @@ public:
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
|
||||
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
|
||||
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
|
||||
|
||||
// layout checkers
|
||||
bool is_scanline(ir::instruction* i);
|
||||
|
||||
bool is_coalesced_scanline(ir::instruction* i);
|
||||
|
||||
bool is_mma(ir::instruction* i);
|
||||
|
||||
bool is_a100_mma(ir::instruction* i);
|
||||
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
@@ -256,6 +359,7 @@ private:
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
std::map<size_t, data_layout*> layouts_;
|
||||
std::map<ir::value*, size_t> tmp_;
|
||||
std::map<ir::value*, size_t> tmp_index_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -1,12 +1,14 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
@@ -42,14 +44,14 @@ struct segment {
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<shared_layout*, segment> intervals_map_t;
|
||||
typedef llvm::MapVector<shared_layout*, segment> intervals_map_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(layouts *l): layouts_(l){ }
|
||||
// accessors
|
||||
const intervals_map_t& get() const { return intervals_; }
|
||||
segment get(shared_layout* v) const { return intervals_.at(v); }
|
||||
segment get(shared_layout* v) const { return intervals_.lookup(v); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
|
||||
|
90
include/triton/codegen/extern_lib.h
Normal file
90
include/triton/codegen/extern_lib.h
Normal file
@@ -0,0 +1,90 @@
|
||||
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
|
||||
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
|
||||
///
|
||||
/// \brief ExternLib is a class that represents a library of external functions.
|
||||
///
|
||||
class ExternLib {
|
||||
public:
|
||||
ExternLib(const std::string &name, const std::string &path)
|
||||
: name_(name), path_(path) {}
|
||||
|
||||
virtual ~ExternLib() = default;
|
||||
|
||||
virtual const std::string &name() const { return name_; }
|
||||
|
||||
virtual const std::string &path() const { return path_; }
|
||||
|
||||
///
|
||||
/// \brief Load the library and return the module.
|
||||
///
|
||||
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
|
||||
|
||||
///
|
||||
/// \brief Link the module into the given module.
|
||||
///
|
||||
void link(std::unique_ptr<llvm::Module> &llvm,
|
||||
std::unique_ptr<llvm::Module> &mod);
|
||||
|
||||
///
|
||||
/// \brief Run load, link, and opt on the module.
|
||||
///
|
||||
virtual void install(llvm::LLVMContext &ctx,
|
||||
std::unique_ptr<llvm::Module> &llvm) {
|
||||
auto mod = load(ctx);
|
||||
link(llvm, mod);
|
||||
opt(ctx, llvm);
|
||||
}
|
||||
|
||||
///
|
||||
/// \brief Run opt on the module.
|
||||
///
|
||||
virtual void opt(llvm::LLVMContext &ctx,
|
||||
std::unique_ptr<llvm::Module> &llvm) = 0;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::string path_;
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
|
||||
///
|
||||
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
|
||||
|
||||
///
|
||||
/// \brief Concrete class for NVIDIA's libdevice library.
|
||||
///
|
||||
class LibDevice final : public ExternLib {
|
||||
public:
|
||||
LibDevice(const std::string &name, const std::string &path)
|
||||
: ExternLib(name, path) {}
|
||||
|
||||
virtual ~LibDevice() = default;
|
||||
|
||||
virtual void opt(llvm::LLVMContext &ctx,
|
||||
std::unique_ptr<llvm::Module> &llvm) override;
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Create an ExternLib instance based on the name and path.
|
||||
///
|
||||
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
|
||||
const std::string &lib_path);
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
#include <memory>
|
||||
#include "extern_lib.h"
|
||||
|
||||
namespace llvm{
|
||||
class Module;
|
||||
@@ -30,12 +31,10 @@ namespace codegen{
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
|
||||
codegen::target* target,
|
||||
int sm, int num_warps,
|
||||
int num_stages, int &shared_static);
|
||||
|
||||
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
|
||||
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
|
||||
int num_warps, int num_stages, int &shared_static,
|
||||
const ExternLibMap &extern_libs);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -4,7 +4,9 @@
|
||||
#define _TRITON_SELECTION_GENERATOR_H_
|
||||
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/extern_lib.h"
|
||||
#include <functional>
|
||||
|
||||
// forward
|
||||
@@ -24,6 +26,7 @@ namespace llvm{
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
class StructType;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
@@ -114,8 +117,17 @@ private:
|
||||
private:
|
||||
Type *cvt(ir::type *ty);
|
||||
llvm::Attribute cvt(ir::attribute attr);
|
||||
void packed_type(ir::value* i);
|
||||
void forward_declare(ir::function* fn);
|
||||
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
|
||||
|
||||
public:
|
||||
private:
|
||||
typedef std::function<void(
|
||||
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
|
||||
std::function<Value *()> load_index_fn, bool is_first)>
|
||||
acc_fn_t;
|
||||
|
||||
public:
|
||||
generator(analysis::axes *a_axes,
|
||||
analysis::layouts *layouts,
|
||||
analysis::align *alignment,
|
||||
@@ -125,6 +137,8 @@ public:
|
||||
unsigned num_warps);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
void visit_call_inst(ir::call_inst*);
|
||||
void visit_launch_inst(ir::launch_inst *);
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
@@ -134,9 +148,19 @@ public:
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
Value* bf16_to_fp32(Value *in0);
|
||||
Value* fp32_to_bf16(Value *in0);
|
||||
|
||||
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
|
||||
Value *in0, Value *scale_x512, Value *shift
|
||||
);
|
||||
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
|
||||
Value *in0, Value *scale_x512, Value *shift
|
||||
);
|
||||
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
|
||||
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
|
||||
void visit_dequantize_inst(ir::dequantize_inst*);
|
||||
void visit_cast_inst(ir::cast_inst*);
|
||||
void visit_return_inst(ir::return_inst*);
|
||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||
@@ -148,6 +172,8 @@ public:
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
void visit_cat_inst(ir::cat_inst*);
|
||||
void visit_extract_value_inst(ir::extract_value_inst *);
|
||||
void visit_insert_value_inst(ir::insert_value_inst *);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
@@ -168,8 +194,8 @@ public:
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
Value* shfl_sync(Value* acc, int32_t i);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
|
||||
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
void visit_layout_convert(ir::value *out, ir::value *in);
|
||||
@@ -182,6 +208,9 @@ public:
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
void visit_clock_inst(ir::clock_inst*);
|
||||
void visit_globaltimer_inst(ir::globaltimer_inst*);
|
||||
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
@@ -197,12 +226,21 @@ public:
|
||||
void visit_layout_scanline(analysis::scanline_layout*);
|
||||
void visit_layout_shared(analysis::shared_layout*);
|
||||
|
||||
// Add a new external library based on given name and path if it doesn't exist
|
||||
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
|
||||
|
||||
private:
|
||||
// Get all external libraries
|
||||
const ExternLibMap &get_extern_lib_map() {
|
||||
return extern_lib_map_;
|
||||
}
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder* builder_;
|
||||
Module *mod_;
|
||||
|
||||
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
|
||||
|
||||
analysis::axes *a_axes_;
|
||||
analysis::swizzle *swizzle_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
@@ -235,10 +273,11 @@ private:
|
||||
/// idx for multi-stage pipeline
|
||||
std::map<analysis::data_layout*, Value*> read_smem_idx_;
|
||||
std::map<analysis::data_layout*, Value*> write_smem_idx_;
|
||||
|
||||
|
||||
/// triton bb -> llvm bb
|
||||
std::map<ir::value*, BasicBlock *> bbs_;
|
||||
std::map<ir::value*, std::vector<int>> ords_;
|
||||
std::map<ir::value*, Function*> fns_;
|
||||
|
||||
// helper for creating llvm values
|
||||
adder add;
|
||||
@@ -250,6 +289,9 @@ private:
|
||||
|
||||
/// Record prefetch instrs that needs to be moved
|
||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||
|
||||
// Eviction policies
|
||||
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -32,11 +32,12 @@ private:
|
||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
|
||||
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool has_sm80_;
|
||||
analysis::align* align_;
|
||||
analysis::layouts* layout_;
|
||||
};
|
||||
|
@@ -15,22 +15,30 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class layouts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class cts {
|
||||
private:
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
|
||||
bool is_shmem_op(ir::instruction* i, int op);
|
||||
bool is_shmem_res(ir::value* i);
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
|
||||
|
||||
public:
|
||||
cts(bool use_async = false): use_async_(use_async) {}
|
||||
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool use_async_;
|
||||
bool has_sm80_;
|
||||
analysis::layouts* layouts_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
31
include/triton/codegen/transform/inline.h
Normal file
31
include/triton/codegen/transform/inline.h
Normal file
@@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include <list>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class function;
|
||||
class call_inst;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
struct fncmp {
|
||||
bool operator()(ir::function* x, ir::function* y) const;
|
||||
};
|
||||
|
||||
class inliner {
|
||||
public:
|
||||
inliner() {}
|
||||
void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list<ir::call_inst*>& callsites);
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -30,12 +30,15 @@ private:
|
||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
|
||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
|
||||
public:
|
||||
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
|
||||
void run(ir::module &mod);
|
||||
|
@@ -88,6 +88,7 @@ public:
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
// link management
|
||||
static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||
@@ -214,6 +215,7 @@ private:
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
// link management
|
||||
static void* cuLinkAddFile_v2_;
|
||||
static void* cuLinkAddData_v2_;
|
||||
static void* cuLinkCreate_v2_;
|
||||
static void* cuLinkDestroy_;
|
||||
|
@@ -9,8 +9,9 @@ namespace triton{
|
||||
namespace driver{
|
||||
|
||||
void init_llvm();
|
||||
std::string path_to_ptxas(int& version);
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
||||
std::string ptx_to_cubin(const std::string& ptx, int cc);
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
||||
|
5999
include/triton/external/CUDA/cuda.h
vendored
Executable file → Normal file
5999
include/triton/external/CUDA/cuda.h
vendored
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
#pragma once
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BASIC_BLOCK_H_
|
||||
#define _TRITON_IR_BASIC_BLOCK_H_
|
||||
@@ -27,7 +27,7 @@ public:
|
||||
|
||||
private:
|
||||
// constructors
|
||||
basic_block(context &ctx, const std::string &name, function *parent);
|
||||
basic_block(context &ctx, const std::string &name, function *parent, basic_block *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
@@ -35,6 +35,7 @@ public:
|
||||
context& get_context() { return ctx_; }
|
||||
|
||||
// get iterator to first instruction that is not a phi
|
||||
void replace_phi_uses_with(basic_block* before, basic_block* after);
|
||||
iterator get_first_non_phi();
|
||||
|
||||
// get instruction list
|
||||
@@ -60,13 +61,16 @@ public:
|
||||
inline const instruction &back() const { return *inst_list_.back(); }
|
||||
inline instruction &back() { return *inst_list_.back(); }
|
||||
|
||||
void append_instruction(ir::instruction* i);
|
||||
// split
|
||||
basic_block* split_before(ir::instruction* loc, const std::string& name);
|
||||
|
||||
// predecessors
|
||||
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
|
||||
const std::vector<basic_block*>& get_successors() const { return succs_; }
|
||||
void add_predecessor(basic_block* pred);
|
||||
std::vector<basic_block*> get_predecessors() const;
|
||||
std::vector<basic_block*> get_successors() const;
|
||||
|
||||
// factory functions
|
||||
static basic_block* create(context &ctx, const std::string &name, function *parent);
|
||||
static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr);
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
|
@@ -22,13 +22,16 @@ class phi_node;
|
||||
|
||||
/* Builder */
|
||||
class builder{
|
||||
public:
|
||||
typedef basic_block::iterator iterator;
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
builder(context &ctx);
|
||||
// Getters
|
||||
const context& get_context() { return ctx_; }
|
||||
// const context& get_context() const { return ctx_; }
|
||||
context& get_context() { return ctx_; }
|
||||
|
||||
// Setters
|
||||
void set_insert_point(iterator instr);
|
||||
void set_insert_point(instruction* i);
|
||||
@@ -38,8 +41,8 @@ public:
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
value *get_int32(uint32_t val);
|
||||
value *get_int64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_float64(float val);
|
||||
@@ -51,7 +54,9 @@ public:
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_fp8_ty();
|
||||
type *get_half_ty();
|
||||
type *get_bf16_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
@@ -68,8 +73,13 @@ public:
|
||||
value* create_br(basic_block *dest);
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
value* create_ret(value *ret);
|
||||
// Dequantize instructions
|
||||
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
|
||||
// Cast instructions
|
||||
value* create_bitcast(value *src, type *dest_ty);
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_int_to_ptr(value *src, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
@@ -79,6 +89,9 @@ public:
|
||||
value* create_fp_trunc(value *src, type *dst_ty);
|
||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
||||
value *create_downcast(value *arg);
|
||||
// Call instruction
|
||||
value* create_call(function* fn, const std::vector<value*>& args);
|
||||
value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
|
||||
// Phi instruction
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved);
|
||||
// Binary instructions
|
||||
@@ -88,11 +101,11 @@ public:
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
@@ -131,25 +144,48 @@ public:
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
|
||||
// Struct instructions
|
||||
value *create_insert_value(value* val, value *elt, size_t idx);
|
||||
value *create_extract_value(value* val, size_t idx);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Atomic instruction
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_atomic_max(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umax(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_min(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umin(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_fadd(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_and(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_or(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xor(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xchg(value *ptr, value *val, value *msk);
|
||||
// Utilities
|
||||
value *create_clock();
|
||||
value *create_globaltimer();
|
||||
// Extern instruction
|
||||
value *create_extern_elementwise(const std::string &lib_name,
|
||||
const std::string &lib_path,
|
||||
const std::string &symbol_name,
|
||||
const std::vector<value *> &args,
|
||||
type *ret_ty);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C);
|
||||
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
value *create_sqrt(value *A);
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
@@ -158,7 +194,7 @@ public:
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
value *create_umulhi(value* lhs, value* rhs);
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
|
@@ -9,7 +9,6 @@
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class builder;
|
||||
class type;
|
||||
class context_impl;
|
||||
|
||||
@@ -21,7 +20,6 @@ public:
|
||||
context& operator=(const context&) = delete;
|
||||
|
||||
public:
|
||||
ir::builder* builder = nullptr;
|
||||
std::shared_ptr<context_impl> p_impl;
|
||||
};
|
||||
|
||||
|
@@ -3,17 +3,15 @@
|
||||
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
|
||||
#define _TRITON_IR_CONTEXT_IMPL_H_
|
||||
|
||||
#include <map>
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class constant;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
|
||||
/* Context impl */
|
||||
class context_impl {
|
||||
@@ -29,16 +27,17 @@ public:
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||
// Block types
|
||||
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
|
||||
|
||||
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
|
||||
// Struct types
|
||||
std::map<type::contained_tys_vec_t, struct_type*> struct_tys;
|
||||
// Int constants
|
||||
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
|
||||
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
|
||||
// Float constants
|
||||
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
|
||||
std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, undef_value*> uv_constants_;
|
||||
std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
|
||||
|
||||
};
|
||||
|
||||
|
@@ -1,110 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_DISPATCH_H_
|
||||
#define _TRITON_IR_DISPATCH_H_
|
||||
|
||||
#include "triton/ir/builder.h"
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/*----------------------------------------------
|
||||
higher level functions that follow the likely
|
||||
semantics of most expected frontends
|
||||
----------------------------------------------*/
|
||||
|
||||
struct semantic_error: public std::runtime_error {
|
||||
semantic_error(const std::string& msg):
|
||||
std::runtime_error(msg) { }
|
||||
};
|
||||
|
||||
struct dispatch{
|
||||
typedef ir::type::block_shapes_t shape_t;
|
||||
|
||||
|
||||
// programming model
|
||||
static ir::value *program_id(int axis, ir::builder *builder);
|
||||
static ir::value *num_programs(int axis, ir::builder *builder);
|
||||
|
||||
// binary operators
|
||||
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// unary operators
|
||||
static ir::value *plus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *minus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *invert(ir::value *input, ir::builder *builder);
|
||||
|
||||
// comparison operators
|
||||
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// block creation
|
||||
static ir::value* arange(int start, int end, ir::builder *builder);
|
||||
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
|
||||
|
||||
|
||||
// casting ops
|
||||
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
||||
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
|
||||
// memory operators
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
|
||||
// indexing
|
||||
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
|
||||
|
||||
// reduction
|
||||
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
|
||||
// math
|
||||
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *cos(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sin(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *debug_barrier(ir::builder *builder);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -95,6 +95,9 @@ enum value_id_t: unsigned {
|
||||
INSTRUCTIONS
|
||||
* ------------ */
|
||||
INST_BEGIN,
|
||||
// call
|
||||
INST_CALL,
|
||||
INST_LAUNCH,
|
||||
// phi
|
||||
INST_PHI,
|
||||
// arithmetic
|
||||
@@ -105,6 +108,8 @@ enum value_id_t: unsigned {
|
||||
// cmp
|
||||
INST_ICMP,
|
||||
INST_FCMP,
|
||||
// dequantize
|
||||
INST_DEQUANTIZE,
|
||||
// cast
|
||||
INST_CAST_TRUNC,
|
||||
INST_CAST_ZEXT,
|
||||
@@ -129,6 +134,9 @@ enum value_id_t: unsigned {
|
||||
INST_MASKED_LOAD_ASYNC,
|
||||
INST_UNMASKED_STORE,
|
||||
INST_MASKED_STORE,
|
||||
// struct
|
||||
INST_EXTRACT_VALUE,
|
||||
INST_INSERT_VALUE,
|
||||
// retile
|
||||
INST_RESHAPE,
|
||||
INST_SPLAT,
|
||||
@@ -148,6 +156,8 @@ enum value_id_t: unsigned {
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// extern
|
||||
INST_EXTERN_ELEMENTWISE,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
@@ -165,6 +175,8 @@ enum value_id_t: unsigned {
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE,
|
||||
INST_PREFETCH_S,
|
||||
INST_GLOBALTIMER,
|
||||
INST_CLOCK,
|
||||
};
|
||||
|
||||
|
||||
|
@@ -24,7 +24,7 @@ public:
|
||||
static argument* create(type *ty, const std::string &name,
|
||||
function *parent = nullptr, unsigned arg_no = 0);
|
||||
function* get_parent() const;
|
||||
unsigned get_arg_no() const;
|
||||
unsigned get_arg_no() const;
|
||||
|
||||
void accept(visitor *v);
|
||||
|
||||
@@ -112,7 +112,7 @@ public:
|
||||
static function *create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod);
|
||||
// blocks
|
||||
const blocks_t &blocks() { return blocks_; }
|
||||
blocks_t &blocks() { return blocks_; }
|
||||
const blocks_t &blocks() const { return blocks_; }
|
||||
void insert_block(basic_block* block, basic_block *next = nullptr);
|
||||
|
||||
@@ -121,6 +121,8 @@ public:
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
||||
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||
void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
|
||||
bool get_is_kernel() { return is_kernel_; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
@@ -134,6 +136,7 @@ private:
|
||||
args_t args_;
|
||||
blocks_t blocks_;
|
||||
attr_map_t attrs_;
|
||||
bool is_kernel_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -59,8 +59,8 @@ public:
|
||||
std::string repr() const { return repr_impl(); }
|
||||
// metadata
|
||||
void set_metadata(ir::metadata::kind_t kind,
|
||||
unsigned value) { metadatas_[kind] = value;}
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
std::vector<unsigned> value) { metadatas_[kind] = value;}
|
||||
std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
// cloning
|
||||
ir::instruction* clone() {
|
||||
ir::instruction* res = clone_impl();
|
||||
@@ -77,10 +77,55 @@ public:
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_;
|
||||
value_id_t id_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// call_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class call_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next);
|
||||
|
||||
public:
|
||||
static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "", instruction *next = nullptr);
|
||||
ir::function* get_fn() { return fn_; }
|
||||
|
||||
_TRITON_DEFINE_CLONE(call_inst)
|
||||
_TRITON_DEFINE_ACCEPT(call_inst)
|
||||
|
||||
private:
|
||||
ir::function* fn_;
|
||||
};
|
||||
|
||||
class launch_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "launch"; }
|
||||
launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
|
||||
const std::string& name = "", instruction* next = nullptr);
|
||||
|
||||
ir::function* get_fn();
|
||||
std::vector<ir::value*> get_values();
|
||||
std::vector<ir::value*> get_grid();
|
||||
ir::value* get_num_warps();
|
||||
|
||||
|
||||
_TRITON_DEFINE_CLONE(launch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(launch_inst)
|
||||
|
||||
private:
|
||||
unsigned val_begin;
|
||||
unsigned val_end;
|
||||
unsigned grid_begin;
|
||||
unsigned grid_end;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
@@ -117,6 +162,7 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class binary_operator: public instruction {
|
||||
public:
|
||||
typedef binary_op_t op_t;
|
||||
@@ -145,6 +191,10 @@ public:
|
||||
bool is_shl() const;
|
||||
bool is_shr() const;
|
||||
|
||||
// Approx
|
||||
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
|
||||
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
|
||||
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
@@ -163,6 +213,8 @@ public:
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
|
||||
bool fdiv_ieee_rnd_;
|
||||
};
|
||||
|
||||
|
||||
@@ -222,6 +274,24 @@ protected:
|
||||
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// dequantize_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class dequantize_inst: public instruction{
|
||||
private:
|
||||
std::string repr_impl() const override { return "dequantize"; }
|
||||
|
||||
protected:
|
||||
dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(dequantize_inst)
|
||||
_TRITON_DEFINE_ACCEPT(dequantize_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast_inst classes
|
||||
@@ -383,13 +453,31 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class io_inst: public instruction {
|
||||
public:
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
protected:
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops,
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
return "";
|
||||
}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
|
||||
protected:
|
||||
EVICTION_POLICY eviction_;
|
||||
};
|
||||
|
||||
// load
|
||||
@@ -399,33 +487,42 @@ public:
|
||||
NONE=0,
|
||||
CA,
|
||||
CG,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
bool get_is_volatile() const { return is_volatile_; }
|
||||
|
||||
protected:
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
std::string get_cache_modifier_repr() const {
|
||||
if (cache_ == CA) return ".ca";
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
return "";
|
||||
}
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
std::string get_volatile_repr() {
|
||||
return is_volatile_ ? ".volatile" : "";
|
||||
}
|
||||
bool is_volatile_;
|
||||
|
||||
private:
|
||||
static type *get_pointee_type(type *ty);
|
||||
|
||||
};
|
||||
|
||||
// unmasked load
|
||||
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static unmasked_load_inst* create(value *ptr,
|
||||
CACHE_MODIFIER cache,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
@@ -436,7 +533,7 @@ public:
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
@@ -445,7 +542,8 @@ public:
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
@@ -455,9 +553,10 @@ public:
|
||||
// masked load async
|
||||
class masked_load_async_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
|
||||
const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
@@ -466,6 +565,7 @@ public:
|
||||
// factory method
|
||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||
@@ -477,7 +577,7 @@ public:
|
||||
// store
|
||||
class store_inst: public io_inst {
|
||||
protected:
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
@@ -488,11 +588,11 @@ public:
|
||||
class unmasked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_store"; }
|
||||
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory method
|
||||
static unmasked_store_inst* create(value* ptr, value *v,
|
||||
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
@@ -502,20 +602,58 @@ public:
|
||||
class masked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_store"; }
|
||||
masked_store_inst(value *ptr, value *v, value *mask,
|
||||
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// struct classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// insert_value
|
||||
|
||||
class insert_value_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "insertvalue"; }
|
||||
insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr);
|
||||
size_t get_idx() { return idx_; }
|
||||
_TRITON_DEFINE_CLONE(insert_value_inst)
|
||||
_TRITON_DEFINE_ACCEPT(insert_value_inst)
|
||||
|
||||
private:
|
||||
size_t idx_;
|
||||
};
|
||||
|
||||
// extract_value
|
||||
|
||||
class extract_value_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "extractvalue"; }
|
||||
extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr);
|
||||
size_t get_idx() { return idx_; }
|
||||
_TRITON_DEFINE_CLONE(extract_value_inst)
|
||||
_TRITON_DEFINE_ACCEPT(extract_value_inst)
|
||||
|
||||
private:
|
||||
size_t idx_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -641,6 +779,8 @@ private:
|
||||
class atomic_inst: public io_inst {
|
||||
public:
|
||||
using io_inst::io_inst;
|
||||
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
|
||||
io_inst(ty, id, num_ops, NORMAL, name, next) {}
|
||||
};
|
||||
|
||||
class atomic_rmw_inst: public atomic_inst {
|
||||
@@ -728,24 +868,40 @@ public:
|
||||
class dot_inst: public builtin_inst {
|
||||
public:
|
||||
enum TransT { NoTrans, Trans };
|
||||
enum DataType {
|
||||
FP8, FP16, BF16, TF32, FP32,
|
||||
INT1, INT4, INT8, INT32,
|
||||
UNKNOWN,
|
||||
};
|
||||
|
||||
private:
|
||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
|
||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "dot"; }
|
||||
|
||||
bool is_prefetched_ = false;
|
||||
public:
|
||||
bool is_prefetched() const { return is_prefetched_; }
|
||||
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
|
||||
bool allow_tf32() const { return allow_tf32_; }
|
||||
bool is_trans_a() const { return AT_ == Trans; }
|
||||
bool is_trans_b() const { return BT_ == Trans; }
|
||||
|
||||
public:
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(dot_inst)
|
||||
_TRITON_DEFINE_ACCEPT(dot_inst)
|
||||
|
||||
private:
|
||||
bool is_prefetched_ = false;
|
||||
bool allow_tf32_ = false;
|
||||
DataType C_type_ = DataType::FP32;
|
||||
DataType A_type_ = DataType::FP16;
|
||||
DataType B_type_ = DataType::FP16;
|
||||
TransT AT_;
|
||||
TransT BT_;
|
||||
};
|
||||
|
||||
//class outer_inst: public builtin_inst {
|
||||
@@ -787,8 +943,11 @@ public:
|
||||
class reduce_inst: public builtin_inst {
|
||||
public:
|
||||
enum op_t{
|
||||
ADD, SUB, MAX, MIN,
|
||||
FADD, FSUB, FMAX, FMIN
|
||||
ADD, SUB, MAX, MIN, UMAX, UMIN,
|
||||
ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
|
||||
FADD, FSUB, FMAX, FMIN,
|
||||
ARGFMAX, ARGFMIN,
|
||||
XOR
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -805,12 +964,19 @@ public:
|
||||
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
op_t get_op() const { return op_; }
|
||||
bool with_index() const {
|
||||
return with_index_ops_.find(op_) != with_index_ops_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
const static inline std::set<op_t> with_index_ops_ = {
|
||||
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
|
||||
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
@@ -898,11 +1064,11 @@ class prefetch_s_inst : public instruction {
|
||||
std::string repr_impl() const { return "prefetch_s"; }
|
||||
_TRITON_DEFINE_CLONE(prefetch_s_inst)
|
||||
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
|
||||
|
||||
|
||||
/// inc_: 0->first, 1->latch
|
||||
int inc_ = 0;
|
||||
public:
|
||||
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
|
||||
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
|
||||
set_operand(0, arg);
|
||||
}
|
||||
@@ -928,7 +1094,53 @@ private:
|
||||
constant_int* last_;
|
||||
};
|
||||
|
||||
/* timing utilities */
|
||||
class clock_inst: public instruction{
|
||||
clock_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "clock"; }
|
||||
_TRITON_DEFINE_CLONE(clock_inst)
|
||||
_TRITON_DEFINE_ACCEPT(clock_inst)
|
||||
|
||||
public:
|
||||
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class globaltimer_inst: public instruction{
|
||||
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "globaltimer"; }
|
||||
_TRITON_DEFINE_CLONE(globaltimer_inst)
|
||||
_TRITON_DEFINE_ACCEPT(globaltimer_inst)
|
||||
|
||||
public:
|
||||
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class extern_elementwise_inst : public instruction {
|
||||
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
|
||||
type *dst_ty, const std::string &lib_name,
|
||||
const std::string &extern_lib_path,
|
||||
const std::string &symbol_name,
|
||||
const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "extern_elementwise"; }
|
||||
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
|
||||
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
|
||||
|
||||
public:
|
||||
static extern_elementwise_inst *create(
|
||||
context &ctx, const std::vector<value *> &args, type *dst_ty,
|
||||
const std::string &lib_name = "", const std::string &lib_path = "",
|
||||
const std::string &symbol_name = "", const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
|
||||
const std::string &get_lib_name() const { return lib_name_; }
|
||||
const std::string &get_lib_path() const { return lib_path_; }
|
||||
const std::string &get_symbol_name() const { return symbol_name_; }
|
||||
|
||||
private:
|
||||
std::string lib_name_;
|
||||
std::string lib_path_;
|
||||
std::string symbol_name_;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -3,6 +3,8 @@
|
||||
#ifndef _TRITON_IR_METADATA_H_
|
||||
#define _TRITON_IR_METADATA_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
@@ -16,14 +18,14 @@ public:
|
||||
};
|
||||
|
||||
private:
|
||||
metadata(kind_t kind, unsigned value);
|
||||
metadata(kind_t kind, std::vector<unsigned> value);
|
||||
|
||||
public:
|
||||
static metadata* get(kind_t kind, unsigned value);
|
||||
static metadata* get(kind_t kind, std::vector<unsigned> value);
|
||||
|
||||
private:
|
||||
kind_t kind_;
|
||||
unsigned value_;
|
||||
std::vector<unsigned> value_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -34,50 +34,74 @@ class constant;
|
||||
class global_value;
|
||||
class alloc_const;
|
||||
|
||||
/* Module */
|
||||
|
||||
class module {
|
||||
class value_constructor {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
friend class function;
|
||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
||||
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
typedef std::vector<function*> functions_list_t;
|
||||
struct current_iteration_info_t{
|
||||
lang::iteration_statement *statement;
|
||||
basic_block *block;
|
||||
};
|
||||
|
||||
private:
|
||||
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
||||
value *try_remove_trivial_phis(ir::phi_node *&phi);
|
||||
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
||||
value *get_value_recursive(const std::string& name, basic_block *block);
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder& builder);
|
||||
builder& get_builder();
|
||||
// Setters
|
||||
value_constructor(builder &builder);
|
||||
|
||||
void set_value(const std::string& name, basic_block* block, value *x);
|
||||
void set_value(const std::string& name, value* x);
|
||||
void set_const(const std::string& name);
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
const std::map<val_key_t, value*>& get_values() { return values_; }
|
||||
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
value *get_value(const std::string& name);
|
||||
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
||||
const std::string& get_name();
|
||||
std::function<ir::value*()> get_continue_fn();
|
||||
// Seal block -- no more predecessors will be added
|
||||
void seal_block(basic_block *block);
|
||||
// Metadata
|
||||
|
||||
private:
|
||||
ir::builder& builder_;
|
||||
std::map<val_key_t, value*> values_;
|
||||
std::map<std::string, type*> types_;
|
||||
std::set<basic_block*> sealed_blocks_;
|
||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
};
|
||||
|
||||
/* Module */
|
||||
|
||||
class module {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
typedef std::pair<ir::metadata::kind_t, std::vector<unsigned>> md_pair_t;
|
||||
friend class function;
|
||||
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
typedef std::vector<function*> functions_list_t;
|
||||
|
||||
private:
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
|
||||
builder &get_builder() { return builder_; };
|
||||
const std::string& get_name() { return name_; };
|
||||
|
||||
// Functions
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
function *get_function(const std::string& name) {
|
||||
if(symbols_.find(name) == symbols_.end())
|
||||
throw std::runtime_error("function " + name + " is not declared");
|
||||
return (function*)symbols_.at(name);
|
||||
}
|
||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||
bool has_function(const std::string& name){
|
||||
return symbols_.find(name) != symbols_.end();
|
||||
}
|
||||
void remove_function(ir::function* fn){
|
||||
functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end());
|
||||
}
|
||||
|
||||
void reset_ret_ty(const std::string& name, type* ty);
|
||||
|
||||
// Const allocation
|
||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||
@@ -85,22 +109,15 @@ public:
|
||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
builder& builder_;
|
||||
std::map<val_key_t, value*> values_;
|
||||
std::map<std::string, type*> types_;
|
||||
std::set<std::string> const_;
|
||||
std::set<basic_block*> sealed_blocks_;
|
||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||
builder &builder_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::function<ir::value*()> continue_fn_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#pragma once
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_TYPE_H_
|
||||
#define _TRITON_IR_TYPE_H_
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
@@ -20,7 +21,6 @@ class type {
|
||||
public:
|
||||
typedef std::vector<unsigned> block_shapes_t;
|
||||
|
||||
protected:
|
||||
typedef std::vector<type*> contained_tys_vec_t;
|
||||
typedef contained_tys_vec_t::iterator ty_iterator;
|
||||
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
|
||||
@@ -68,23 +68,24 @@ public:
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
type *get_pointer_element_ty() const;
|
||||
unsigned get_struct_numel() const { return contained_tys_.size(); }
|
||||
type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
|
||||
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||
bool is_fp16_ty() const { return id_ == FP16TyID; }
|
||||
bool is_bf16_ty() const { return id_ == BF16TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
|
||||
get_integer_bitwidth() == bitwidth;}
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
bool is_struct_ty() const { return id_ == StructTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
@@ -128,6 +129,7 @@ public:
|
||||
switch(id_) {
|
||||
case VoidTyID: return "void";
|
||||
case FP8TyID: return "fp8";
|
||||
case BF16TyID: return "bf16";
|
||||
case FP16TyID: return "f16";
|
||||
case BF16TyID: return "bf16";
|
||||
case FP32TyID: return "f32";
|
||||
@@ -135,15 +137,14 @@ public:
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
|
||||
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
case BlockTyID: return tile_repr();
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
|
||||
};
|
||||
|
||||
private:
|
||||
@@ -160,7 +161,7 @@ class integer_type: public type {
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
@@ -182,6 +183,16 @@ public:
|
||||
type* get_type_at_index(value *idx) const;
|
||||
};
|
||||
|
||||
class struct_type: public composite_type {
|
||||
public:
|
||||
struct_type(const contained_tys_vec_t& tys, bool is_packed);
|
||||
unsigned get_num_types() const { return contained_tys_.size(); }
|
||||
static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
|
||||
|
||||
private:
|
||||
bool is_packed_;
|
||||
};
|
||||
|
||||
class block_type: public composite_type {
|
||||
private:
|
||||
block_type(type *ty, const block_shapes_t &shapes);
|
||||
@@ -230,6 +241,7 @@ public:
|
||||
ty_iterator params_end() { return contained_tys_.end(); }
|
||||
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
|
||||
type* get_return_ty() const { return contained_tys_.at(0); }
|
||||
void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
|
||||
// factory methods
|
||||
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
|
||||
};
|
||||
|
@@ -22,6 +22,7 @@ public:
|
||||
};
|
||||
|
||||
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
|
||||
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
|
||||
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
|
||||
|
||||
}
|
||||
|
@@ -21,7 +21,7 @@ class visitor;
|
||||
|
||||
class value {
|
||||
public:
|
||||
typedef std::set<user*> users_t;
|
||||
typedef std::vector<user*> users_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
@@ -30,7 +30,7 @@ public:
|
||||
// uses
|
||||
void add_use(user* arg);
|
||||
users_t::iterator erase_use(user* arg);
|
||||
const std::set<user*> &get_users() { return users_; }
|
||||
const std::vector<user*> &get_users() { return users_; }
|
||||
void replace_all_uses_with(value *target);
|
||||
// name
|
||||
void set_name(const std::string &name);
|
||||
|
@@ -11,12 +11,16 @@ class value;
|
||||
|
||||
class instruction;
|
||||
|
||||
class call_inst;
|
||||
class launch_inst;
|
||||
|
||||
class phi_node;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
|
||||
class icmp_inst;
|
||||
class fcmp_inst;
|
||||
class dequantize_inst;
|
||||
class cast_inst;
|
||||
class trunc_inst;
|
||||
class z_ext_inst;
|
||||
@@ -42,6 +46,9 @@ class masked_load_inst;
|
||||
class unmasked_store_inst;
|
||||
class masked_store_inst;
|
||||
|
||||
class extract_value_inst;
|
||||
class insert_value_inst;
|
||||
|
||||
class retile_inst;
|
||||
class reshape_inst;
|
||||
class splat_inst;
|
||||
@@ -75,6 +82,10 @@ class async_wait_inst;
|
||||
class make_range_dyn;
|
||||
class make_range;
|
||||
class prefetch_s_inst;
|
||||
class clock_inst;
|
||||
class globaltimer_inst;
|
||||
|
||||
class extern_elementwise_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
@@ -103,6 +114,8 @@ public:
|
||||
virtual ~visitor() {}
|
||||
|
||||
virtual void visit_value(ir::value*);
|
||||
virtual void visit_call_inst(ir::call_inst*) = 0;
|
||||
virtual void visit_launch_inst(ir::launch_inst*) = 0;
|
||||
|
||||
virtual void visit_basic_block(basic_block*) = 0;
|
||||
virtual void visit_argument(argument*) = 0;
|
||||
@@ -112,6 +125,7 @@ public:
|
||||
|
||||
virtual void visit_icmp_inst(icmp_inst*) = 0;
|
||||
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
|
||||
virtual void visit_dequantize_inst(dequantize_inst*) = 0;
|
||||
virtual void visit_cast_inst(cast_inst*) = 0;
|
||||
|
||||
virtual void visit_return_inst(return_inst*) = 0;
|
||||
@@ -130,6 +144,9 @@ public:
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
virtual void visit_log_inst(log_inst*) = 0;
|
||||
|
||||
virtual void visit_extract_value_inst(extract_value_inst*) = 0;
|
||||
virtual void visit_insert_value_inst(insert_value_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||
@@ -157,11 +174,15 @@ public:
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
virtual void visit_function(function*) = 0;
|
||||
virtual void visit_clock_inst(clock_inst*) = 0;
|
||||
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;
|
||||
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||
|
||||
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -3,8 +3,9 @@
|
||||
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
#define _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
@@ -13,21 +14,21 @@ namespace tools{
|
||||
|
||||
template<class node_t>
|
||||
class graph {
|
||||
typedef std::map<node_t, std::set<node_t>> edges_t;
|
||||
typedef std::map<node_t, llvm::SetVector<node_t>> edges_t;
|
||||
|
||||
public:
|
||||
typedef std::map<size_t, std::vector<node_t>> cmap_t;
|
||||
typedef std::map<node_t, size_t> nmap_t;
|
||||
|
||||
private:
|
||||
void connected_components_impl(node_t x, std::set<node_t> &nodes,
|
||||
void connected_components_impl(node_t x, llvm::SetVector<node_t> &nodes,
|
||||
nmap_t* nmap, cmap_t* cmap, int id) const {
|
||||
if(nmap)
|
||||
(*nmap)[x] = id;
|
||||
if(cmap)
|
||||
(*cmap)[id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()) {
|
||||
nodes.erase(x);
|
||||
if (nodes.count(x)) {
|
||||
nodes.remove(x);
|
||||
for(const node_t &y: edges_.at(x))
|
||||
connected_components_impl(y, nodes, nmap, cmap, id);
|
||||
}
|
||||
@@ -39,7 +40,7 @@ public:
|
||||
cmap->clear();
|
||||
if(nmap)
|
||||
nmap->clear();
|
||||
std::set<node_t> nodes = nodes_;
|
||||
llvm::SetVector<node_t> nodes = nodes_;
|
||||
unsigned id = 0;
|
||||
while(!nodes.empty()){
|
||||
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||
@@ -59,7 +60,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::set<node_t> nodes_;
|
||||
llvm::SetVector<node_t> nodes_;
|
||||
edges_t edges_;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user