[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copy, scheduling, and synchronization for A100. Now achieves CUTLASS-like performance on large square dense matrix multiplication tasks.
This commit is contained in:
committed by
Philippe Tillet
parent
045ab5d62a
commit
5b83259592
@@ -2,6 +2,9 @@
|
||||
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <set>
|
||||
|
||||
namespace triton {
|
||||
|
||||
@@ -9,6 +12,7 @@ namespace ir {
|
||||
class module;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class masked_load_async_inst;
|
||||
class value;
|
||||
class builder;
|
||||
}
|
||||
@@ -29,18 +33,15 @@ namespace transform{
|
||||
class membar {
|
||||
private:
|
||||
typedef std::pair<unsigned, unsigned> interval_t;
|
||||
typedef std::vector<interval_t> interval_vec_t;
|
||||
typedef std::set<ir::value*> val_set_t;
|
||||
typedef std::vector<ir::value*> val_vec_t;
|
||||
|
||||
private:
|
||||
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
|
||||
void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
|
||||
bool intersect(const interval_vec_t &X, interval_t x);
|
||||
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
|
||||
void add_reference(ir::value *v, interval_vec_t &res);
|
||||
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
|
||||
std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
|
||||
bool intersect(const val_set_t &X, const val_set_t &Y);
|
||||
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
|
||||
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
|
||||
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
|
||||
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
|
||||
|
||||
public:
|
||||
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
|
||||
|
@@ -16,6 +16,10 @@ namespace ir {
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class layouts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class peephole {
|
||||
@@ -33,11 +37,12 @@ private:
|
||||
private:
|
||||
|
||||
public:
|
||||
peephole(target* tgt): tgt_(tgt) {}
|
||||
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
target* tgt_;
|
||||
analysis::layouts* layouts_;
|
||||
};
|
||||
|
||||
|
||||
|
28
include/triton/codegen/transform/pipeline.h
Normal file
28
include/triton/codegen/transform/pipeline.h
Normal file
@@ -0,0 +1,28 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
|
||||
// forward declaration
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
namespace transform {
|
||||
|
||||
// Transformation pass that operates on an ir::module through run().
// NOTE(review): presumably implements software pipelining (prefetching) of
// loads -- only the declaration is visible here; confirm against the
// implementation file.
class pipeline {
|
||||
public:
|
||||
// Records the has_copy_async flag; no other work is done at construction.
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
|
||||
// Entry point of the pass. Declared only; no definition in this header.
void run(ir::module &module);
|
||||
|
||||
private:
|
||||
// Flag supplied at construction.
// NOTE(review): likely indicates whether the target supports asynchronous
// copies -- confirm with the caller that constructs this pass.
bool has_copy_async_;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -29,7 +29,7 @@ public:
|
||||
static driver::stream* create(backend_t backend);
|
||||
// methods
|
||||
virtual void synchronize() = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
|
||||
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
|
||||
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
|
||||
// template helpers
|
||||
@@ -44,7 +44,7 @@ class host_stream: public stream {
|
||||
public:
|
||||
host_stream();
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
@@ -55,7 +55,7 @@ public:
|
||||
cu_stream(CUstream str, bool take_ownership);
|
||||
cu_stream();
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
@@ -35,6 +35,7 @@ public:
|
||||
basic_block* get_insert_block() { return block_; }
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
// Types
|
||||
@@ -149,7 +150,7 @@ public:
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
|
||||
value *create_copy_from_shared(value *arg, const std::string &name = "");
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait();
|
||||
value *create_async_wait(int N);
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
|
@@ -92,6 +92,7 @@ private:
|
||||
public:
|
||||
void set_incoming_value(unsigned i, value *v);
|
||||
void set_incoming_block(unsigned i, basic_block *block);
|
||||
value *get_value_for_block(basic_block *block);
|
||||
value *get_incoming_value(unsigned i) { return get_operand(i); }
|
||||
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
|
||||
unsigned get_num_incoming() { return get_num_operands(); }
|
||||
@@ -803,14 +804,18 @@ public:
|
||||
|
||||
class async_wait_inst: public instruction{
|
||||
private:
|
||||
async_wait_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "async_wait"; }
|
||||
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
|
||||
_TRITON_DEFINE_CLONE(async_wait_inst)
|
||||
_TRITON_DEFINE_ACCEPT(async_wait_inst)
|
||||
|
||||
public:
|
||||
static async_wait_inst* create(context &ctx, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
static async_wait_inst* create(context &ctx, int N,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
int get_N() { return N_; }
|
||||
|
||||
private:
|
||||
int N_;
|
||||
};
|
||||
|
||||
// On NVIDIA, implementation is such that
|
||||
|
@@ -98,6 +98,8 @@ private:
|
||||
std::shared_ptr<ir::module> ir_;
|
||||
std::shared_ptr<driver::module> mod_;
|
||||
std::shared_ptr<driver::kernel> ker_;
|
||||
// shared mem
|
||||
size_t shared_mem_;
|
||||
};
|
||||
|
||||
class function {
|
||||
|
@@ -30,11 +30,8 @@ private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
|
||||
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
|
||||
{
|
||||
// const driver::device * device = stream->context()->device();
|
||||
size_t warmup = 10;
|
||||
size_t repeat = 50;
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
|
Reference in New Issue
Block a user