[CODEGEN] Performance improvement on A100 (#125)

Improved code generation for Ampere GPUs.

    * Make the layout pass recognize the multi-stage pipelined pattern.
    * Teach the pipeline pass to apply the multi-stage pipelining transformation automatically (see the sketch below).
    * Remove extra barriers (from the prefetch pass & WAR hazards) on Ampere.
    * Update the code generator (generator.cc) so that Triton emits n-buffered shared-memory loads/stores.
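
For readers unfamiliar with the transformation: multi-stage (n-buffered) pipelining splits the shared-memory allocation into num_stages buffers and rotates through them, so the loads for upcoming loop iterations overlap with compute on the current one. The following is a minimal, self-contained C++ sketch of that index rotation only, with assumed illustrative values (num_stages, num_tiles); it is not Triton's generated code, but it mirrors the rotating read_smem_idx_ / write_smem_idx_ indices and the size_ / get_num_stages() per-stage split introduced in this commit.

#include <cstdio>

// Sketch only: models the buffer-index rotation of an n-buffered pipeline.
int main() {
  const int num_stages = 3;  // assumed stage count (the new num_stages knob)
  const int num_tiles  = 6;  // assumed number of loop iterations (K-tiles)
  int write_idx = 0;         // buffer being filled   (cf. write_smem_idx_)
  int read_idx  = 0;         // buffer being consumed (cf. read_smem_idx_)

  // Prologue: prefetch the first (num_stages - 1) tiles into distinct buffers.
  for (int t = 0; t < num_stages - 1 && t < num_tiles; ++t) {
    std::printf("load tile %d -> smem buffer %d\n", t, write_idx);
    write_idx = (write_idx + 1) % num_stages;
  }

  // Steady state: issue the load for tile t + num_stages - 1 (if any),
  // then compute on tile t, whose load was issued num_stages - 1 steps earlier.
  for (int t = 0; t < num_tiles; ++t) {
    if (t + num_stages - 1 < num_tiles) {
      std::printf("load tile %d -> smem buffer %d\n", t + num_stages - 1, write_idx);
      write_idx = (write_idx + 1) % num_stages;
    }
    std::printf("compute tile %d from smem buffer %d\n", t, read_idx);
    read_idx = (read_idx + 1) % num_stages;
  }
  return 0;
}

On Ampere, the "load" slots map to asynchronous copies (cp.async) into the per-stage shared-memory buffers, which is why the prefetch and barrier-placement (membar) passes also change in this commit.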
daadaada
2021-06-21 14:25:13 +08:00
committed by Philippe Tillet
parent 5a51f3e529
commit d8d6b715c8
21 changed files with 855 additions and 174 deletions

View File

@@ -141,10 +141,19 @@ struct double_buffer_info_t {
ir::phi_node* phi;
};
struct N_buffer_info_t {
std::vector<ir::value*> firsts; // not necessarily in input order
ir::value* latch;
ir::phi_node* phi;
std::map<ir::value*, int> firsts_idx;
};
// abstraction for dot and corresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
public:
shared_layout(data_layout *arg,
@@ -158,6 +167,10 @@ public:
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
int get_num_stages() const;
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
size_t get_per_stage_elements() const;
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
@@ -169,6 +182,7 @@ private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
std::shared_ptr<N_buffer_info_t> N_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;

View File

@@ -21,7 +21,7 @@ namespace codegen{
// TODO:
// There should be a proper pass manager here!
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps,
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);

View File

@@ -223,6 +223,10 @@ private:
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;

View File

@@ -32,6 +32,8 @@ class shared_layout;
namespace transform{
class prefetch;
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
@@ -40,6 +42,7 @@ private:
private:
bool intersect(const val_set_t &X, const val_set_t &Y);
bool check_safe_war(ir::instruction* i);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
@@ -47,14 +50,16 @@ private:
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {}
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
transform::prefetch *prefetch, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
transform::prefetch *prefetch_;
target* tgt_;
};

View File

@@ -14,11 +14,13 @@ namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
pipeline(bool has_copy_async, int num_stages)
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
void run(ir::module &module);
private:
bool has_copy_async_;
int num_stages_;
};
} // namespace transform

View File

@@ -1,9 +1,12 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#include <set>
// forward declaration
namespace triton::ir{
class module;
class value;
}
namespace triton::codegen {
@@ -13,9 +16,11 @@ class target;
namespace triton::codegen::transform {
class prefetch {
target* tgt_;
std::set<ir::value*> prefetched_vals_;
public:
prefetch(target *tgt) : tgt_(tgt) {}
void run(ir::module &module);
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}

View File

@@ -832,6 +832,7 @@ public:
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
void set_N(int n) { N_ = n; }
private:
int N_;

View File

@@ -1,5 +1,3 @@
#pragma once
#ifndef _TRITON_IR_PRINT_H_
#define _TRITON_IR_PRINT_H_
@@ -9,8 +7,14 @@ namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
void print(module &mod, std::ostream& os);
void print(function &func, std::ostream& os);
void print(basic_block &bb, std::ostream& os);
void print(instruction &instr, std::ostream& os);
}
}