[PYTHON] Cleaned up legacy code; added simple standalone compilation API (#22)
@@ -59,13 +59,13 @@ For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 w
[ASCII diagram garbled by extraction; the recoverable structure: along the columns,
 a thread tile size of 2 splits each group of 8 columns across threads T0-T3
 (A_{0,0}[T0] A_{0,1}[T0] ... A_{0,6}[T3] A_{0,7}[T3] A_{0,8}[T0] A_{0,9}[T0] ...
 A_{0,14}[T3] A_{0,15}[T3]), with the column axis labeled "block tile size 8" and
 "warp tile size 8"; along the rows, the margin labels read "block/warp tile size 32"
 and "thread tile size 2", with rows 30-31 owned by T60-T63
 (A_{31,0}[T60] ... A_{31,15}[T63]) and the pattern repeating from row 32
 (A_{32,0}[T0] A_{32,1}[T0] ... A_{32,15}[T3]).]
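Reading the surviving entries, a minimal sketch of the ownership rule they imply (an editorial illustration, not part of the commit; owner_thread is a hypothetical helper, and the tile sizes are taken from the labels above):

    // Which of the 64 threads (2 warps of 32) owns element (row, col) of the 64x16
    // tensor in the example, assuming a thread tile of 2 in both dimensions,
    // 4 threads along the columns and 16 thread-rows:
    int owner_thread(int row, int col) {
      int thread_col = (col / 2) % 4;    // columns cycle over T0..T3 every 8 columns
      int thread_row = (row / 2) % 16;   // rows cycle over 16 thread-rows every 32 rows
      return thread_row * 4 + thread_col; // e.g. (31, 15) -> T63, (32, 0) -> T0
    }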
@@ -1,80 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H

#include <map>
#include <vector>

namespace triton {

namespace ir {
  class value;
  class module;
  class phi_node;
  class splat_inst;
  class cast_inst;
  class reshape_inst;
  class broadcast_inst;
  class binary_operator;
  class getelementptr_inst;
}

namespace codegen{
namespace analysis{

class align {
private:
  struct cst_info {
    unsigned num_cst;
    unsigned value;
  };
  // helpers
  std::vector<unsigned> get_shapes(ir::value *v);
  // populate is_constant
  std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
  std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
  std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
  std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
  std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
  std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
  std::vector<cst_info> populate_is_constant_default(ir::value* v);
  std::vector<cst_info> populate_is_constant(ir::value *v);
  // populate max_contiguous
  std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
  std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
  std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
  std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
  std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
  std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
  std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
  std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
  std::vector<unsigned> populate_max_contiguous(ir::value *v);
  // populate starting_multiple
  std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
  std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
  std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
  std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
  std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
  std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
  std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
  std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
  std::vector<unsigned> populate_starting_multiple(ir::value *v);
  // populate all maps
  void populate(ir::value *v);

public:
  void run(ir::module &mod);
  unsigned get(ir::value* v, unsigned ax) const;
  std::vector<unsigned> contiguous(ir::value* v) const;

private:
  std::map<ir::value*, std::vector<cst_info>> is_constant_;
  std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
  std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};

}
}
}

#endif
@@ -1,47 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H

#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"

namespace triton{

namespace ir{
  class value;
  class function;
  class module;
}

namespace codegen{
namespace analysis{

class tiles;

class liveness;
class cts;

class allocation {
public:
  allocation(liveness *live)
    : liveness_(live) { }
  // accessors
  bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
  unsigned offset(const data_layout *x) const { return offsets_.at(x); }
  unsigned allocated_size() const { return allocated_size_; }
  // run
  void run(ir::module& mod);

private:
  std::map<const data_layout*, unsigned> offsets_;
  size_t allocated_size_;
  // dependences
  liveness *liveness_;
};

}
}
}

#endif
@@ -1,52 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_

#include "triton/tools/graph.h"
#include <map>
#include <vector>

namespace triton{

namespace ir{
  class value;
  class module;
  class instruction;
}

namespace codegen{
namespace analysis{

class axes {
  typedef std::pair<ir::value*, unsigned> node_t;

private:
  // update graph
  void update_graph_store(ir::instruction *i);
  void update_graph_reduce(ir::instruction *i);
  void update_graph_reshape(ir::instruction *i);
  void update_graph_trans(ir::instruction *i);
  void update_graph_broadcast(ir::instruction *i);
  void update_graph_dot(ir::instruction *i);
  void update_graph_elementwise(ir::instruction *i,
                                bool is_masked_load_async=false);
  void update_graph_no_edge(ir::instruction *i);
  void update_graph(ir::instruction *i);

public:
  axes();
  void run(ir::module &mod);
  // accessors
  int get(ir::value *value, unsigned dim);
  std::vector<int> get(ir::value *value);

private:
  tools::graph<node_t> graph_;
  std::map<node_t, size_t> axes_;
};

}
}
}

#endif
@@ -1,345 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_

#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"

namespace triton{

namespace ir{
  class value;
  class type;
  class module;
  class instruction;
  class phi_node;
}

namespace codegen{
namespace analysis{

class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;


class layout_visitor {
public:
  virtual void visit_layout(data_layout *);
  virtual void visit_layout_mma(mma_layout*) = 0;
  virtual void visit_layout_scanline(scanline_layout*) = 0;
  virtual void visit_layout_shared(shared_layout*) = 0;
};

class data_layout {
protected:
  enum id_t {
    MMA,
    SCANLINE,
    SHARED
  };

  typedef std::vector<int> axes_t;
  typedef std::vector<unsigned> shape_t;
  typedef std::vector<int> order_t;
  typedef std::vector<ir::value*> values_t;

private:
  template<typename T>
  T* downcast(id_t id) {
    if(id_ == id)
      return static_cast<T*>(this);
    return nullptr;
  }

public:
  data_layout(id_t id,
              const std::vector<int>& axes,
              const std::vector<unsigned> &shape,
              const std::vector<ir::value *> &values,
              analysis::align* align);
  // visitor
  virtual void accept(layout_visitor* vst) = 0;
  // downcast
  mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
  scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
  shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
  // accessors
  size_t get_rank() { return shape_.size(); }
  const shape_t& get_shape() const { return shape_; }
  const order_t& get_order() const { return order_; }
  const values_t& get_values() const { return values_;}
  int get_axis(size_t k) const { return axes_.at(k); }
  std::vector<int> get_axes() const { return axes_; }
  const int get_order(size_t k) const { return order_.at(k); }
  // find the position of given axis
  int find_axis(int to_find) const;

private:
  id_t id_;
  axes_t axes_;
  values_t values_;

protected:
  order_t order_;
  shape_t shape_;
};

class distributed_layout: public data_layout{
public:
  distributed_layout(id_t id,
                     const std::vector<int>& axes,
                     const std::vector<unsigned>& shape,
                     const std::vector<ir::value*>& values,
                     analysis::align* align);

  int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
  int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
  virtual int contig_per_thread(size_t k) = 0;

protected:
  std::vector<int> shape_per_cta_;
};

class mma_layout: public distributed_layout {
public:
  enum TensorCoreType : uint8_t {
    // floating-point tensor core instr
    FP32_FP16_FP16_FP32 = 0, // default
    FP32_BF16_BF16_FP32,
    FP32_TF32_TF32_FP32,
    // integer tensor core instr
    INT32_INT1_INT1_INT32, // Not implemented
    INT32_INT4_INT4_INT32, // Not implemented
    INT32_INT8_INT8_INT32, // Not implemented
    //
    NOT_APPLICABLE,
  };

  // Used on nvidia GPUs with sm >= 80
  inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
    {FP32_FP16_FP16_FP32, {16, 8, 16}},
    {FP32_BF16_BF16_FP32, {16, 8, 16}},
    {FP32_TF32_TF32_FP32, {16, 8, 8}},

    {INT32_INT1_INT1_INT32, {16, 8, 256}},
    {INT32_INT4_INT4_INT32, {16, 8, 64}},
    {INT32_INT8_INT8_INT32, {16, 8, 32}},
  };

  // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
  inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
    {FP32_FP16_FP16_FP32, {8, 8, 8}},
    {FP32_BF16_BF16_FP32, {8, 8, 8}},
    {FP32_TF32_TF32_FP32, {8, 8, 4}},

    {INT32_INT1_INT1_INT32, {8, 8, 64}},
    {INT32_INT4_INT4_INT32, {8, 8, 32}},
    {INT32_INT8_INT8_INT32, {8, 8, 16}},
  };

  inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
    {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
    {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
    {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},

    {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
    {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
    {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
  };

  // vector length per ldmatrix (16*8/element_size_in_bits)
  inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
    {FP32_FP16_FP16_FP32, 8},
    {FP32_BF16_BF16_FP32, 8},
    {FP32_TF32_TF32_FP32, 4},

    {INT32_INT1_INT1_INT32, 128},
    {INT32_INT4_INT4_INT32, 32},
    {INT32_INT8_INT8_INT32, 16},
  };
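  // Illustration only: the stated formula 16*8/element_size_in_bits reproduces the
  // table above: fp16/bf16 -> 128/16 = 8, tf32 -> 128/32 = 4, int8 -> 128/8 = 16,
  // int4 -> 128/4 = 32, int1 -> 128/1 = 128.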
public:
  mma_layout(size_t num_warps,
             const std::vector<int>& axes,
             const std::vector<unsigned>& shapes,
             const std::vector<ir::value *> &values,
             analysis::align* align, target *tgt,
             shared_layout* layout_a,
             shared_layout* layout_b,
             ir::value *dot);
  void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
  // accessor
  int fpw(size_t k) { return fpw_.at(k); }
  int wpt(size_t k) { return wpt_.at(k); }
  int spw(size_t k) { return spw_.at(k); }
  int rep(size_t k) { return rep_.at(k); }
  int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }

  // helpers for generator.cc
  std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
  std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
  std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
  int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
  int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }

  // setter
  void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }

private:
  // fragment per warp
  std::vector<int> fpw_;
  // shape per warp
  std::vector<int> spw_;
  // warp per tile
  std::vector<int> wpt_;
  // shape per tile
  std::vector<int> spt_;
  // repetitions
  std::vector<int> rep_;
  // contiguous per thread
  std::vector<int> contig_per_thread_;

  TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};

struct scanline_layout: public distributed_layout {
  scanline_layout(size_t num_warps,
                  const std::vector<int>& axes,
                  const std::vector<unsigned>& shape,
                  const std::vector<ir::value *> &values,
                  analysis::align* align,
                  target* tgt);
  void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
  // accessor
  int mts(size_t k) { return mts_.at(k); }
  int nts(size_t k) { return nts_.at(k); }
  int contig_per_thread(size_t k) { return nts_.at(k); }

public:
  // micro tile size. The size of a tile held by a thread block.
  std::vector<int> mts_;
  // nano tile size. The size of a tile held by a thread.
  std::vector<int> nts_;
};

struct double_buffer_info_t {
  ir::value* first;
  ir::value* latch;
  ir::phi_node* phi;
};

struct N_buffer_info_t {
  std::vector<ir::value*> firsts; // not necessarily ordered as input order
  ir::value* latch;
  ir::phi_node* phi;
  std::map<ir::value*, int> firsts_idx;
};

// abstraction for dot and corresponding smem values
class shared_layout: public data_layout {
private:
  static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
  static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
  static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);

public:
  shared_layout(data_layout *arg,
                const std::vector<int>& axes,
                const std::vector<unsigned>& shapes,
                const std::vector<ir::value *> &values_,
                ir::type *ty,
                analysis::align* align, target *tgt);
  void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
  // accessors
  size_t get_size() { return size_; }
  ir::type* get_type() { return ty_; }
  double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
  N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
  int get_num_stages() const;
  size_t get_per_stage_size() const { return size_ / get_num_stages(); }
  size_t get_per_stage_elements() const;
  size_t get_num_per_phase() { return num_per_phase_; }
  ir::value* hmma_dot_a() { return hmma_dot_a_; }
  ir::value* hmma_dot_b() { return hmma_dot_b_; }
  void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
  int get_mma_vec() { return mma_vec_;}
  int get_mma_strided() { return mma_strided_; }
  bool allow_swizzle() const { return allow_swizzle_; }
  data_layout* get_arg_layout() { return arg_layout_; }

private:
  size_t size_;
  ir::type *ty_;
  std::shared_ptr<double_buffer_info_t> double_buffer_;
  std::shared_ptr<N_buffer_info_t> N_buffer_;
  size_t num_per_phase_;
  ir::value* hmma_dot_a_;
  ir::value* hmma_dot_b_;
  data_layout* arg_layout_;
  int mma_vec_;
  int mma_strided_;
  bool allow_swizzle_ = true;
  target *tgt_;
};


class layouts {
  typedef ir::value* node_t;
  typedef std::map <node_t, std::set<node_t>> graph_t;

private:
  // graph creation
  void connect(ir::value *x, ir::value *y);
  void make_graph(ir::instruction *i);

  void init_hmma_tile(data_layout& layouts);
  void init_scanline_tile(data_layout &layouts);

  void create(size_t id, const std::vector<ir::value*>& values);

public:
  // constructor
  layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);

  // accessors
  unsigned layout_of(ir::value *value) const { return groups_.at(value); }
  bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
  const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
  size_t num_layouts() const { return values_.size();}
  data_layout* get(size_t id) { return layouts_.at(id); }
  data_layout* get(ir::value *v) { return get(layout_of(v));}
  std::map<size_t, data_layout*> &get_all() { return layouts_; }
  bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
  int tmp(ir::value* i) { return tmp_.at(i);}
  void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
  // execution
  void run(ir::module &mod);

private:
  analysis::axes* axes_;
  analysis::align* align_;
  size_t num_warps_;
  target* tgt_;
  tools::graph<ir::value*> graph_;
  std::map<ir::value*, size_t> groups_;
  std::map<size_t, std::vector<ir::value*>> values_;
  std::map<size_t, data_layout*> layouts_;
  std::map<ir::value*, size_t> tmp_;
};

}
}
}

#endif
@@ -1,67 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H

#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"

namespace triton{

namespace ir{
  class value;
  class phi_node;
  class function;
  class module;
  class instruction;
}

namespace codegen{
namespace analysis{

typedef unsigned slot_index;

class tiles;
class layouts;
class data_layout;

struct segment {
  slot_index start;
  slot_index end;

  bool contains(slot_index idx) const {
    return start <= idx && idx < end;
  }

  bool intersect(const segment &Other){
    return contains(Other.start) || Other.contains(start);
  }
};


class liveness {
private:
  typedef std::map<shared_layout*, segment> intervals_map_t;

public:
  // constructor
  liveness(layouts *l): layouts_(l){ }
  // accessors
  const intervals_map_t& get() const { return intervals_; }
  segment get(shared_layout* v) const { return intervals_.at(v); }
  // run
  void run(ir::module &mod);

private:
  // analysis
  layouts *layouts_;
  intervals_map_t intervals_;
};

}
}
}

#endif
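A segment is a half-open interval [start, end), which is what the contains/intersect pair above encodes; a minimal usage sketch (illustration only, not part of the diff):

    segment a{0, 4}, b{3, 7}, c{4, 7};
    bool ab = a.intersect(b);  // true: a.contains(3), since 0 <= 3 < 4
    bool ac = a.intersect(c);  // false: 4 is not in [0, 4) and 0 is not in [4, 7)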
@@ -1,43 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H

#include <map>

namespace triton{

namespace ir{
  class module;
}

namespace codegen{
class target;

namespace analysis{

class layouts;
class data_layout;

class swizzle {
public:
  // constructor
  swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
  // accessors
  int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
  int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
  int get_vec(data_layout* layout) { return vec_.at(layout); }
  // run
  void run(ir::module &mod);
private:
  layouts* layouts_;
  target* tgt_;
  std::map<data_layout*, int> per_phase_;
  std::map<data_layout*, int> max_phase_;
  std::map<data_layout*, int> vec_;
};

}
}
}

#endif
@@ -1,42 +0,0 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_

#include <memory>

namespace llvm{
  class Module;
  class LLVMContext;
}

namespace triton{

namespace codegen {
  class target;
}

namespace ir{
  class module;
}
namespace driver{
  class device;
  class module;
  class kernel;
}
}

namespace triton{
namespace codegen{

// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
                                                     codegen::target* target,
                                                     int sm, int num_warps,
                                                     int num_stages, int &shared_static);

}
}

#endif
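A hypothetical call site, sketched only from the declaration above (the ir module ir_mod and the sm/num_warps/num_stages values are assumptions for illustration; nvidia_cu_target is declared in target.h further down in this diff):

    llvm::LLVMContext ctx;
    triton::codegen::nvidia_cu_target tgt(/*sm=*/80);
    int shared_static = 0;
    std::unique_ptr<llvm::Module> llvm_mod =
        triton::codegen::add_passes_to_emit_bin(ir_mod, ctx, &tgt, /*sm=*/80,
                                                /*num_warps=*/4, /*num_stages=*/3,
                                                shared_static);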
@@ -1,258 +0,0 @@
#pragma once

#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_

#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include <functional>

// forward
namespace llvm{
  class Type;
  class Value;
  class PHINode;
  class BasicBlock;
  class Attribute;
  class Instruction;
  class Constant;
  class LLVMContext;
  class Module;
  class ConstantFolder;
  class IRBuilderDefaultInserter;
  template <typename T, typename Inserter>
  class IRBuilder;
  class ArrayType;
  class Function;
}

namespace triton{

namespace ir{
  class attribute;
  class load_inst;
  class store_inst;
}

namespace codegen{

// forward
namespace analysis{
  class liveness;
  class tiles;
  class align;
  class allocation;
  class cts;
  class axes;
  class layouts;
  class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
                        llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
class target;

}
}

namespace triton{
namespace codegen{

struct distributed_axis {
  int contiguous;
  std::vector<Value*> values;
  Value* thread_id;
};

class adder{
public:
  adder(Builder** builder): builder_(builder) { }
  Value* operator()(Value* x, Value* y, const std::string& name = "");

private:
  Builder** builder_;
};

class multiplier{
public:
  multiplier(Builder** builder): builder_(builder) { }
  Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
  Builder** builder_;
};

class geper{
public:
  geper(Builder** builder): builder_(builder) { }
  Value* operator()(Value *ptr, Value* off, const std::string& name = "");
  Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");

private:
  Builder** builder_;
};

class generator: public ir::visitor, public analysis::layout_visitor {
private:
  void init_idx(ir::value *x);
  Instruction* add_barrier();
  Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
  void finalize_shared_layout(analysis::shared_layout*);
  void finalize_function(ir::function*);
  void finalize_phi_node(ir::phi_node*);

private:
  Type *cvt(ir::type *ty);
  llvm::Attribute cvt(ir::attribute attr);

public:
  generator(analysis::axes *a_axes,
            analysis::layouts *layouts,
            analysis::align *alignment,
            analysis::allocation *alloc,
            analysis::swizzle *swizzle,
            target *tgt,
            unsigned num_warps);

  void visit_value(ir::value* v);
  void visit_phi_node(ir::phi_node*);
  void visit_binary_operator(ir::binary_operator*);
  void visit_getelementptr_inst(ir::getelementptr_inst*);
  void visit_icmp_inst(ir::icmp_inst*);
  void visit_fcmp_inst(ir::fcmp_inst*);
  std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
  std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
  std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
  std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
  Value* bf16_to_fp32(Value *in0);
  Value* fp32_to_bf16(Value *in0);

  void visit_cast_inst(ir::cast_inst*);
  void visit_return_inst(ir::return_inst*);
  void visit_cond_branch_inst(ir::cond_branch_inst*);
  void visit_uncond_branch_inst(ir::uncond_branch_inst*);
  void visit_load_inst(ir::load_inst*);
  void visit_unmasked_load_inst(ir::unmasked_load_inst*);
  void visit_masked_load_inst(ir::masked_load_inst*);
  void visit_store_inst(ir::store_inst*);
  void visit_unmasked_store_inst(ir::unmasked_store_inst*);
  void visit_masked_store_inst(ir::masked_store_inst*);
  void visit_cat_inst(ir::cat_inst*);
  void visit_reshape_inst(ir::reshape_inst*);
  void visit_splat_inst(ir::splat_inst*);
  void visit_broadcast_inst(ir::broadcast_inst*);
  void visit_downcast_inst(ir::downcast_inst*);
  void visit_exp_inst(ir::exp_inst*);
  void visit_cos_inst(ir::cos_inst*);
  void visit_umulhi_inst(ir::umulhi_inst* x);
  void visit_sin_inst(ir::sin_inst*);
  void visit_log_inst(ir::log_inst*);
  void visit_get_program_id_inst(ir::get_program_id_inst*);
  void visit_get_num_programs_inst(ir::get_num_programs_inst*);
  void visit_atomic_cas_inst(ir::atomic_cas_inst*);
  void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
  void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
  void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
  void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
  void visit_dot_inst(ir::dot_inst*);
  void visit_trans_inst(ir::trans_inst*);
  void visit_sqrt_inst(ir::sqrt_inst*);
  Value* shfl_sync(Value* acc, int32_t i);
  void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
  void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
  void visit_reduce_inst(ir::reduce_inst*);
  void visit_select_inst(ir::select_inst*);
  void visit_layout_convert(ir::value *out, ir::value *in);
  void visit_cvt_layout_inst(ir::cvt_layout_inst*);
  void visit_masked_load_async_inst(ir::masked_load_async_inst*);
  void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
  void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
  void visit_barrier_inst(ir::barrier_inst*);
  void visit_prefetch_s_inst(ir::prefetch_s_inst*);
  void visit_async_wait_inst(ir::async_wait_inst*);
//  void visit_make_range_dyn(ir::make_range_dyn*);
  void visit_make_range(ir::make_range*);
//  void visit_make_range_sta(ir::make_range_sta*);
  void visit_undef_value(ir::undef_value*);
  void visit_constant_int(ir::constant_int*);
  void visit_constant_fp(ir::constant_fp*);
  void visit_alloc_const(ir::alloc_const*);
  void visit_function(ir::function*);
  void visit_basic_block(ir::basic_block*);
  void visit_argument(ir::argument*);
  void visit(ir::module &, llvm::Module &);

  // layouts
  void visit_layout_mma(analysis::mma_layout*);
  void visit_layout_scanline(analysis::scanline_layout*);
  void visit_layout_shared(analysis::shared_layout*);


private:
  LLVMContext *ctx_;
  Builder* builder_;
  Module *mod_;

  analysis::axes *a_axes_;
  analysis::swizzle *swizzle_;
  std::map<unsigned, distributed_axis> axes_;
  target *tgt_;
  analysis::layouts *layouts_;
  analysis::align *alignment_;
  analysis::allocation *alloc_;
  Value *shmem_;
  std::set<ir::value*> seen_;

  unsigned num_warps_;

  std::map<analysis::data_layout*, Value*> offset_a_m_;
  std::map<analysis::data_layout*, Value*> offset_a_k_;
  std::map<analysis::data_layout*, Value*> offset_b_k_;
  std::map<analysis::data_layout*, Value*> offset_b_n_;

  /// layout -> base ptr
  std::map<analysis::data_layout*, Value*> shared_ptr_;
  std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
  std::map<analysis::data_layout*, Value*> shared_next_ptr_;
  /// offset for double-buffered layout
  std::map<analysis::data_layout*, Value*> shared_off_;

  /// Base shmem pointer of ir value
  std::map<ir::value*, Value*> shmems_;
  std::map<ir::value*, Value*> shoffs_;
  std::map<ir::value*, std::vector<indices_t>> idxs_;
  std::map<ir::value*, std::map<indices_t, Value*>> vals_;
  /// idx for multi-stage pipeline
  std::map<analysis::data_layout*, Value*> read_smem_idx_;
  std::map<analysis::data_layout*, Value*> write_smem_idx_;

  /// triton bb -> llvm bb
  std::map<ir::value*, BasicBlock *> bbs_;
  std::map<ir::value*, std::vector<int>> ords_;

  // helper for creating llvm values
  adder add;
  multiplier mul;
  geper gep;

  /// PHI nodes
  std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;

  /// Record prefetch instrs that need to be moved
  std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
};

}
}

#endif
@@ -1,105 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H

namespace llvm{
  class Type;
  class Value;
  class Instruction;
  class Constant;
  class LLVMContext;
  class Module;
  class ConstantFolder;
  class IRBuilderDefaultInserter;
  template <typename T, typename Inserter>
  class IRBuilder;
  class ArrayType;
  class Function;
}

// typedefs
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
                        llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}

namespace triton{
namespace codegen{

class nvidia_cu_target;

class target {
public:
  target(bool is_gpu): is_gpu_(is_gpu){}
  virtual ~target() {}
  virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
  virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
  virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
  virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
  virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
  virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
  virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
  virtual unsigned guaranteed_alignment() = 0;
  nvidia_cu_target* as_nvidia();
  bool is_gpu() const;

private:
  bool is_gpu_;
};

class amd_cl_target: public target {
public:
  amd_cl_target(): target(true){}
  void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
  Instruction* add_barrier(Module *module, Builder& builder);
  Instruction* add_memfence(Module *module, Builder& builder);
  Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
  Value* get_local_id(Module *module, Builder& builder, unsigned ax);
  Value* get_block_id(Module *module, Builder& builder, unsigned ax);
  Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
  unsigned guaranteed_alignment() { return 16; }
};

class nvidia_cu_target: public target {
public:
  nvidia_cu_target(int sm): target(true), sm_(sm){}
  void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
  Instruction* add_barrier(Module *module, Builder& builder);
  Instruction* add_memfence(Module *module, Builder& builder);
  Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
  Value* get_local_id(Module *module, Builder& builder, unsigned ax);
  Value* get_block_id(Module *module, Builder& builder, unsigned ax);
  Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
  int sm() { return sm_; }
  unsigned guaranteed_alignment() { return 16; }

private:
  int sm_;
};

class cpu_target: public target {
public:
  cpu_target(): target(false){}
  void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
  Instruction* add_barrier(Module *module, Builder& builder);
  Instruction* add_memfence(Module *module, Builder& builder);
  Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
  Value* get_local_id(Module *module, Builder& builder, unsigned ax);
  Value* get_block_id(Module *module, Builder& builder, unsigned ax);
  Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
  unsigned guaranteed_alignment() { return 1; }
};

}
}

#endif
@@ -1,48 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H

#include <map>
#include <set>
#include <vector>

namespace triton {

namespace ir {
  class module;
  class value;
  class io_inst;
  class instruction;
  class builder;
}

namespace codegen{

namespace analysis{
  class align;
  class layouts;
  class cts;
}

namespace transform{

class coalesce {
private:
  void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
  void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
  ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);

public:
  coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
  triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
  void run(ir::module &mod);

private:
  analysis::align* align_;
  analysis::layouts* layout_;
};

}
}
}

#endif
@@ -1,36 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H

#include <set>
#include <map>

namespace triton {

namespace ir {
  class module;
  class value;
  class phi_node;
  class instruction;
  class builder;
}

namespace codegen{
namespace transform{

class cts {
private:
  void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);

public:
  cts(bool use_async = false): use_async_(use_async) {}
  void run(ir::module &mod);

private:
  bool use_async_;
};

}
}
}

#endif
@@ -1,24 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H

namespace triton {

namespace ir {
  class module;
}

namespace codegen{
namespace transform{

class dce {
public:
  dce() {}
  void run(ir::module &mod);
};

}
}
}

#endif
@@ -1,22 +0,0 @@
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_

namespace triton {
namespace ir {
  class module;
}

namespace codegen{
namespace transform{

class disassociate {
public:
  void run(ir::module &mod);
};

}
}
}

#endif
@@ -1,72 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H

#include <vector>
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"

namespace triton {

namespace ir {
  class module;
  class basic_block;
  class instruction;
  class masked_load_async_inst;
  class value;
  class builder;
}

namespace codegen{

namespace analysis{

class allocation;
class liveness;
class layouts;
class cts;
class shared_layout;

}

namespace transform{

class prefetch;

class membar {
private:
  typedef std::pair<unsigned, unsigned> interval_t;
  typedef std::set<ir::value*> val_set_t;
  typedef std::vector<ir::value*> val_vec_t;

private:
  bool intersect(const val_set_t &X, const val_set_t &Y);
  bool check_safe_war(ir::instruction* i);
  int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
  bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
  val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
  void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
                std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);

public:
  membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
         transform::prefetch *prefetch, target* tgt):
    liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
  void run(ir::module &mod);

private:
  analysis::liveness *liveness_;
  analysis::layouts *layouts_;
  analysis::allocation *alloc_;
  transform::prefetch *prefetch_;

  target* tgt_;
};


}
}
}

#endif
@@ -1,53 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H

#include "triton/codegen/target.h"

namespace triton {

namespace ir {
  class module;
  class value;
  class instruction;
  class trans_inst;
  class builder;
  class constant_int;
  class dot_inst;
}

namespace codegen{
namespace analysis{
  class layouts;
}

namespace transform{

class peephole {
private:
//  bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
  bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
  bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
  bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
  bool rewrite_dot(ir::instruction *value, ir::builder& builder);
  bool rewrite_mult(ir::instruction *value, ir::builder& builder);
  bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
  bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
  bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
  bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
  bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);

public:
  peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
  void run(ir::module &mod);

private:
  target* tgt_;
  analysis::layouts* layouts_;
};


}
}
}

#endif
@@ -1,30 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H

// forward declaration
namespace triton {
namespace ir {
  class module;
}
} // namespace triton

namespace triton {
namespace codegen {
namespace transform {

class pipeline {
public:
  pipeline(bool has_copy_async, int num_stages)
    : has_copy_async_(has_copy_async), num_stages_(num_stages) {}
  void run(ir::module &module);

private:
  bool has_copy_async_;
  int num_stages_;
};

} // namespace transform
} // namespace codegen
} // namespace triton

#endif
@@ -1,27 +0,0 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H

#include <set>

// forward declaration
namespace triton::ir{
  class module;
  class value;
}

namespace triton::codegen {
  class target;
}

namespace triton::codegen::transform {
class prefetch {
  target* tgt_;
  std::set<ir::value*> prefetched_vals_;
public:
  prefetch(target *tgt) : tgt_(tgt) {}
  void run(ir::module &module);
  bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}

#endif
@@ -1,26 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H

namespace triton {

// forward declaration
namespace ir {
  class module;
}

namespace codegen{

namespace transform{

class reorder {
public:
  void run(ir::module& module);
};

}
}
}

#endif