[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copy, scheduling and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks
This commit is contained in:
Philippe Tillet
2021-02-21 15:19:39 -08:00
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions

View File

@@ -2,6 +2,9 @@
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
namespace triton {
@@ -9,6 +12,7 @@ namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
@@ -29,18 +33,15 @@ namespace transform{
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
bool intersect(const interval_vec_t &X, interval_t x);
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
bool intersect(const val_set_t &X, const val_set_t &Y);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -16,6 +16,10 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
@@ -33,11 +37,12 @@ private:
private:
public:
peephole(target* tgt): tgt_(tgt) {}
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};

View File

@@ -0,0 +1,28 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
void run(ir::module &module);
private:
bool has_copy_async_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@@ -29,7 +29,7 @@ public:
static driver::stream* create(backend_t backend);
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
public:
host_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -55,7 +55,7 @@ public:
cu_stream(CUstream str, bool take_ownership);
cu_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};

View File

@@ -35,6 +35,7 @@ public:
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
// Types
@@ -149,7 +150,7 @@ public:
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
value *create_async_wait();
value *create_async_wait(int N);
private:
context &ctx_;

View File

@@ -92,6 +92,7 @@ private:
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_value_for_block(basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
@@ -803,14 +804,18 @@ public:
class async_wait_inst: public instruction{
private:
async_wait_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait"; }
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
_TRITON_DEFINE_CLONE(async_wait_inst)
_TRITON_DEFINE_ACCEPT(async_wait_inst)
public:
static async_wait_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
private:
int N_;
};
// On NVIDIA, implementation is such that

View File

@@ -98,6 +98,8 @@ private:
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
std::shared_ptr<driver::kernel> ker_;
// shared mem
size_t shared_mem_;
};
class function {

View File

@@ -30,11 +30,8 @@ private:
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
{
// const driver::device * device = stream->context()->device();
size_t warmup = 10;
size_t repeat = 50;
timer tmr;
std::vector<size_t> times;
double total_time = 0;