[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copies, scheduling, and synchronization for A100. Now achieves CUTLASS-like performance on large, square, dense matrix multiplication tasks.
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
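For context, a minimal CUDA C++ sketch of the Ampere async-copy pattern this commit teaches the code generator to emit (illustrative only: the helper names are hypothetical, and Triton actually builds the equivalent PTX strings through LLVM inline asm, as in generator.cc below):

    // Requires sm_80+. Mirrors the strings built in generator.cc:
    //   "@$0 cp.async.cg.shared.global [$1 + off], [$2 + off], 16;"
    //   "cp.async.commit_group;" / "cp.async.wait_group N;"
    #include <cstdint>

    __device__ void async_copy_16B(void* smem_dst, const void* gmem_src, bool pred) {
      // cp.async takes a 32-bit shared-state-space address as its destination
      uint32_t dst = (uint32_t)__cvta_generic_to_shared(smem_dst);
      asm volatile("{\n\t"
                   " .reg .pred p;\n\t"
                   " setp.ne.b32 p, %0, 0;\n\t"
                   " @p cp.async.cg.shared.global [%1], [%2], 16;\n\t" // .cg bypasses L1
                   "}"
                   :: "r"((int)pred), "r"(dst), "l"(gmem_src));
    }

    __device__ void commit_and_wait_all() {
      asm volatile("cp.async.commit_group;"); // seal in-flight copies into one group
      asm volatile("cp.async.wait_group 0;"); // 0 = wait for every committed group
      __syncthreads();                        // make the shared-memory tile visible
    }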
@@ -2,6 +2,9 @@
 #define TDL_INCLUDE_CODEGEN_BARRIERS_H
 
 #include <vector>
+#include <map>
+#include <list>
+#include <set>
 
 namespace triton {
 
@@ -9,6 +12,7 @@ namespace ir {
 class module;
 class basic_block;
 class instruction;
+class masked_load_async_inst;
 class value;
 class builder;
 }
@@ -29,18 +33,15 @@ namespace transform{
 class membar {
 private:
   typedef std::pair<unsigned, unsigned> interval_t;
-  typedef std::vector<interval_t> interval_vec_t;
+  typedef std::set<ir::value*> val_set_t;
+  typedef std::vector<ir::value*> val_vec_t;
 
 private:
-  interval_vec_t join(const std::vector<interval_vec_t>& intervals);
-  void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
-  bool intersect(const interval_vec_t &X, interval_t x);
-  bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
-  void add_reference(ir::value *v, interval_vec_t &res);
-  void get_read_intervals(ir::instruction *i, interval_vec_t &res);
-  void get_written_intervals(ir::instruction *i, interval_vec_t &res);
-  std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
-                                                     std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
+  bool intersect(const val_set_t &X, const val_set_t &Y);
+  int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
+  val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
+  void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
+                std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
 
 public:
   membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
@@ -16,6 +16,10 @@ namespace ir {
 }
 
 namespace codegen{
+namespace analysis{
+class layouts;
+}
+
 namespace transform{
 
 class peephole {
@@ -33,11 +37,12 @@ private:
 private:
 
 public:
-  peephole(target* tgt): tgt_(tgt) {}
+  peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
   void run(ir::module &mod);
 
 private:
   target* tgt_;
+  analysis::layouts* layouts_;
 };
 
 
include/triton/codegen/transform/pipeline.h (new file, 28 lines)
@@ -0,0 +1,28 @@
+#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
+#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
+
+// forward declaration
+namespace triton {
+namespace ir {
+class module;
+}
+} // namespace triton
+
+namespace triton {
+namespace codegen {
+namespace transform {
+
+class pipeline {
+public:
+  pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
+  void run(ir::module &module);
+
+private:
+  bool has_copy_async_;
+};
+
+} // namespace transform
+} // namespace codegen
+} // namespace triton
+
+#endif
@@ -29,7 +29,7 @@ public:
   static driver::stream* create(backend_t backend);
   // methods
   virtual void synchronize() = 0;
-  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
+  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
   virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
   virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
   // template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
 public:
   host_stream();
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
@@ -55,7 +55,7 @@ public:
   cu_stream(CUstream str, bool take_ownership);
   cu_stream();
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
@@ -35,6 +35,7 @@ public:
   basic_block* get_insert_block() { return block_; }
   iterator get_insert_point() { return insert_point_;}
   // Constants
+  value *get_int1(bool val);
   value *get_int32(int32_t val);
   value *get_int64(int64_t val);
   // Types
@@ -149,7 +150,7 @@ public:
   value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
   value *create_copy_from_shared(value *arg, const std::string &name = "");
   value *create_barrier(const std::string &name = "");
-  value *create_async_wait();
+  value *create_async_wait(int N);
 
 private:
   context &ctx_;
@@ -92,6 +92,7 @@ private:
 public:
   void set_incoming_value(unsigned i, value *v);
   void set_incoming_block(unsigned i, basic_block *block);
+  value *get_value_for_block(basic_block *block);
   value *get_incoming_value(unsigned i) { return get_operand(i); }
   basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
   unsigned get_num_incoming() { return get_num_operands(); }
@@ -803,14 +804,18 @@ public:
 
 class async_wait_inst: public instruction{
 private:
-  async_wait_inst(context &ctx, const std::string &name, instruction *next);
-  std::string repr_impl() const { return "async_wait"; }
+  async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
   _TRITON_DEFINE_CLONE(async_wait_inst)
   _TRITON_DEFINE_ACCEPT(async_wait_inst)
 
 public:
-  static async_wait_inst* create(context &ctx, const std::string &name = "",
-                                 instruction *next = nullptr);
+  static async_wait_inst* create(context &ctx, int N,
+                                 const std::string &name = "", instruction *next = nullptr);
+  int get_N() { return N_; }
+
+private:
+  int N_;
 };
 
 // On NVIDIA, implementation is such that
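For reference (my reading of the PTX ISA, not stated in this diff): `cp.async.wait_group N` blocks until at most N of the most recently committed copy groups are still in flight, so `N_` is a budget of outstanding prefetches, not a count of copies to wait for. A hypothetical use of the new builder API:

    // Two prefetch groups already committed; let the newest stay in flight
    // while the data of the older group is consumed:
    builder.create_async_wait(1);  // lowered to "cp.async.wait_group 1;"
    builder.create_barrier();      // the membar pass pairs the wait with a barrier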
@@ -98,6 +98,8 @@ private:
   std::shared_ptr<ir::module> ir_;
   std::shared_ptr<driver::module> mod_;
   std::shared_ptr<driver::kernel> ker_;
+  // shared mem
+  size_t shared_mem_;
 };
 
 class function {
@@ -30,11 +30,8 @@ private:
   high_resolution_clock::time_point _start;
 };
 
-inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
+inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
 {
-  // const driver::device * device = stream->context()->device();
-  size_t warmup = 10;
-  size_t repeat = 50;
   timer tmr;
   std::vector<size_t> times;
   double total_time = 0;
@@ -312,7 +312,6 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
     if(rhs_cst_info[d].num_cst)
       rvalue = lhs_max_contiguous[d];
     result[d] = std::max(lvalue, rvalue);
-    // std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl;
   }
   return add_to_cache(x, result, max_contiguous_);
 }
@@ -527,8 +526,7 @@ void align::run(ir::module &mod) {
   ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
   // ir::for_each_value(mod, [this](ir::value* v) {
   //   if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
-  //     std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
-  //               << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
+  //     std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
   // });
 }
@@ -118,15 +118,6 @@ data_layout::data_layout(id_t id,
     // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
     // std::cout << order_[0] << " " << order_[1] << std::endl;
   }
-  if(is_recoalesce){
-    if(ptr.size() > 0){
-      // std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
-      // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
-      // if(order_[0] == 0)
-      //   exit(1);
-    }
-  }
-  // std::cout << "---" << std::endl;
 }
 
 int data_layout::find_axis(int to_find) const {
@@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps,
   ir::value *ptr = nullptr;
   for(ir::value *v: values)
     for(ir::user *usr: v->get_users())
-      if(auto *st = dynamic_cast<ir::io_inst*>(usr))
-        ptr = st->get_pointer_operand();
+      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
+        if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
+          ptr = io->get_pointer_operand();
+      }
 
   unsigned i = order_[0];
   int contiguous = 1;
   if(ptr){
     int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
-    contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
+    contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
   }
 
   nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||||
unsigned vector = 1;
|
unsigned in_vec = 1;
|
||||||
ir::value *ptrs = x->get_pointer_operand();
|
ir::value *arg = x->get_pointer_operand();
|
||||||
ir::value *msks = x->get_mask_operand();
|
|
||||||
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
|
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
|
||||||
analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
|
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
|
||||||
auto out_order = out_layout->get_order();
|
auto out_order = out_layout->get_order();
|
||||||
auto in_order = in_layout->get_order();
|
auto in_order = in_layout->get_order();
|
||||||
// tiles
|
// tiles
|
||||||
if(out_order == in_order)
|
if(out_order == in_order)
|
||||||
vector = in_layout->nts(in_order[0]);
|
in_vec = in_layout->nts(in_order[0]);
|
||||||
|
int out_vec = swizzle_->get_vec(out_layout);
|
||||||
|
int min_vec = std::min<int>(out_vec, in_vec);
|
||||||
|
int s = std::max<int>(out_vec / in_vec, 1);
|
||||||
//
|
//
|
||||||
int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
int per_phase = swizzle_->get_per_phase(out_layout);
|
||||||
int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
|
int max_phase = swizzle_->get_max_phase(out_layout);
|
||||||
Value *max_phase = i32(8 / num_per_phase);
|
|
||||||
//
|
//
|
||||||
|
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
|
||||||
|
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
|
||||||
|
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
|
||||||
auto shapes = x->get_type()->get_tile_shapes();
|
auto shapes = x->get_type()->get_tile_shapes();
|
||||||
//
|
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||||
int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
|
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||||
int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
|
std::map<std::pair<int, int>, Value*> tmp;
|
||||||
std::vector<Value*> shared;
|
std::vector<std::pair<Value*, int>> shared;
|
||||||
for(size_t i = 0; i < n_shared; i++){
|
for(int i = 0; i < idxs_.at(arg).size(); i++){
|
||||||
indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
|
unsigned id = i / min_vec;
|
||||||
// phase
|
|
||||||
Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
|
|
||||||
phase = urem(phase, max_phase);
|
|
||||||
// off
|
|
||||||
Value* off_0 = idx[in_order[0]];
|
|
||||||
off_0 = udiv(off_0, i32(vector));
|
|
||||||
off_0 = xor_(off_0, phase);
|
|
||||||
off_0 = mul(off_0 , i32(vector));
|
|
||||||
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
|
|
||||||
Value* off = add(off_0, off_1);
|
|
||||||
//
|
|
||||||
shared.push_back(gep(shmems_[x], {off}));
|
|
||||||
}
|
|
||||||
//
|
|
||||||
for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
|
|
||||||
auto idx = idxs_[ptrs][i];
|
|
||||||
// input ptr info
|
// input ptr info
|
||||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
|
int id_0 = id % (in_ld/min_vec);
|
||||||
Value *in_base = in_gep->getPointerOperand();
|
int id_1 = id / (in_ld/min_vec);
|
||||||
size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
|
int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
|
||||||
Value* out_base = shared[(i / per_thread_ld) % n_shared];
|
int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
|
||||||
int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
|
int off = (off_1*shapes[in_order[0]] + off_0);
|
||||||
int out_off_1 = i % per_thread_ld;
|
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
|
||||||
int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
|
if(tmp.find(key) == tmp.end()){
|
||||||
// asm
|
if(CurrBB != FirstBB)
|
||||||
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
|
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||||
std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
|
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
|
||||||
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
|
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
|
||||||
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
|
phase = urem(phase, i32(max_phase));
|
||||||
call(iasm, {vals_[msks][idx], out_base, in_base});
|
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
|
||||||
|
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
|
||||||
|
off_0 = udiv(off_0, i32(min_vec));
|
||||||
|
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
|
||||||
|
off_0 = mul(off_0 , i32(min_vec));
|
||||||
|
Value* off = add(off_0, off_1);
|
||||||
|
if(CurrBB != FirstBB)
|
||||||
|
builder_->SetInsertPoint(CurrBB);
|
||||||
|
tmp[key] = gep(shmems_[x], {off});
|
||||||
|
}
|
||||||
|
shared.push_back({tmp[key], off});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
|
||||||
|
auto idx = idxs_[arg][i];
|
||||||
|
// input ptr info
|
||||||
|
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||||
|
Value *in_base = in_gep->getPointerOperand();
|
||||||
|
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||||
|
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
|
||||||
|
in_base = cst ? in_base : in_gep;
|
||||||
|
// output ptr info
|
||||||
|
Value* out_base = shared[i].first;
|
||||||
|
int out_off = shared[i].second*2;
|
||||||
|
// asm
|
||||||
|
FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
|
||||||
|
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
|
||||||
|
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
|
||||||
|
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
|
||||||
|
call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string asm_str = "cp.async.commit_group;";
|
||||||
|
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
|
||||||
|
call(iasm);
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||||
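A scalar sketch of the xor-based swizzle computed above with IR ops (plain C++ for illustration, not generated code; assumes `per_phase`, `max_phase`, `s` and `min_vec` take the values produced by the swizzle analysis):

    // Maps a row and column of the tile to a bank-conflict-free shared-memory
    // column, mirroring the IR sequence above:
    //   off_0 = (((off_0/min_vec)/s ^ phase)*s + (off_0/min_vec)%s) * min_vec
    int swizzled_off0(int row, int col, int per_phase, int max_phase, int s, int min_vec) {
      int phase = (row / per_phase) % max_phase;  // the udiv + urem above
      int v = col / min_vec;                      // vector index within the row
      v = ((v / s) ^ phase) * s + (v % s);        // the xor_ + mul + urem above
      return v * min_vec;                         // back to element units
    }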
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
   auto shapes = cts->get_type()->get_tile_shapes();
 
-  // default implementation
+  // store to shared
   Value *current = nullptr;
   std::map<std::pair<int, int>, Value*> ptrs;
   for(int i = 0; i < idxs_.at(arg).size(); i++){
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
   add_barrier();
 }
 
-void generator::visit_async_wait_inst(ir::async_wait_inst*) {
-  std::string asm_str = "cp.async.wait_all;";
+void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
+  std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
   InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
   call(iasm);
-  add_barrier();
 }
 
 void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
   if(unsigned alloc_size = alloc_->allocated_size()){
     Type *int_8_ty = Type::getInt8Ty(*ctx_);
     Type *int_32_ty = Type::getInt32Ty(*ctx_);
-    ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
+    ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
     Type *ptr_ty = ptr_ty(int_8_ty, 3);
     GlobalVariable *sh_mem_array =
-      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
+      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
                          nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
     shmem_ = bit_cast(sh_mem_array, ptr_ty);
   }
@@ -15,114 +15,105 @@ namespace triton {
 namespace codegen{
 namespace transform{
 
-bool membar::intersect(const interval_vec_t &X, interval_t x) {
-  return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
-    bool left_intersect = y.first <= x.first && x.first < y.second;
-    bool right_intersect = y.first <= x.second && x.second < y.second;
-    return left_intersect || right_intersect;
-  });
-}
-
-bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
-  return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
-    return intersect(X, y);
-  });
-}
-
-void membar::add_reference(ir::value *v, interval_vec_t &res){
-  auto *i = dynamic_cast<ir::instruction*>(v);
-  if(!i)
-    return;
-  if(!i->get_type()->is_tile_ty())
-    return;
-  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
-  if(!layout)
-    return;
-  if(alloc_->has_offset(layout)){
-    unsigned offset = alloc_->offset(layout);
-    res.push_back(interval_t(offset, offset + layout->get_size()));
-  }
-}
-
-void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
-  for(ir::value *op: i->ops())
-    add_reference(op, res);
-}
-
-void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
-  if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
-    add_reference(i, res);
-}
-
-void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
-  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
-    std::set<ir::value*> incoming;
-    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
-      ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
-      assert(inc_val);
-      if(incoming.insert(inc_val).second){
-        ir::basic_block *block = inc_val->get_parent();
-        builder.set_insert_point(block->get_inst_list().back());
-        if(type.first)
-          builder.create_async_wait();
-        if(type.second)
-          builder.create_barrier();
-      }
-    }
-  }
-  else {
-    builder.set_insert_point(instr);
-    builder.create_barrier();
-  }
-}
-
-membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
-  membar::interval_vec_t result;
-  for(auto x: intervals)
-    for(interval_t i: x)
-      result.push_back(i);
-  return result;
-}
-
-std::pair<membar::interval_vec_t,
-          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
-                                                   const interval_vec_t &written_to,
-                                                   const interval_vec_t &read_from,
-                                                   std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
-                                                   std::set<ir::value*>& safe_war,
-                                                   std::vector<ir::instruction*>& to_sync) {
-  ir::basic_block::inst_list_t instructions = block->get_inst_list();
-  interval_vec_t new_written_to = written_to;
-  interval_vec_t new_read_from = read_from;
-
-  for(ir::instruction *i: instructions){
-    interval_vec_t read, written;
-    get_read_intervals(i, read);
-    get_written_intervals(i, written);
-    if(written.size())
-      to_sync.push_back(i);
-    bool read_after_write = intersect(new_written_to, read);
-    bool write_after_read = intersect(new_read_from, written);
-    // double buffering
-    if(safe_war.find(i) != safe_war.end()){
-      write_after_read = false;
-      read_after_write = false;
-    }
-    // record hazards
-    if(read_after_write || write_after_read) {
-      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
-      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
-      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
-      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
-      insert_loc.insert({i, {copy_async_wait, barrier}});
-      new_written_to.clear();
-      new_read_from.clear();
-      to_sync.clear();
-    }
-    std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
-    std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
-  }
-  return std::make_pair(new_written_to, new_read_from);
-}
+int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
+    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
+    analysis::double_buffer_info_t* info = layout->get_double_buffer();
+    if(info)
+      return group_of(info->first, async_write);
+    std::vector<int> groups(phi->get_num_operands());
+    std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+    return *std::max_element(groups.begin(), groups.end());
+  }
+  else{
+    auto it = std::find(async_write.begin(), async_write.end(), v);
+    return std::distance(async_write.begin(), it);
+  }
+}
+
+membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
+  val_set_t ret;
+  for(ir::value* a: as){
+    if(!a->get_type()->is_tile_ty())
+      continue;
+    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
+    if(!a_layout)
+      continue;
+    int a_start = alloc_->offset(a_layout);
+    int a_end = a_start + a_layout->get_size();
+    for(ir::value* b: bs){
+      if(!b->get_type()->is_tile_ty())
+        continue;
+      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
+      if(!b_layout)
+        continue;
+      int b_start = alloc_->offset(b_layout);
+      int b_end = b_start + b_layout->get_size();
+      if(a_start < b_end || b_start < a_end)
+        ret.insert(b);
+    }
+  }
+  return ret;
+}
+
+void membar::transfer(ir::basic_block *block,
+                      val_vec_t& async_write,
+                      val_set_t& sync_write,
+                      val_set_t& sync_read,
+                      std::set<ir::value*>& safe_war,
+                      bool& inserted, ir::builder& builder) {
+  ir::basic_block::inst_list_t instructions = block->get_inst_list();
+  for(ir::instruction *i: instructions){
+    if(dynamic_cast<ir::phi_node*>(i))
+      continue;
+    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
+       dynamic_cast<ir::masked_load_async_inst*>(i)){
+      async_write.push_back(i);
+    }
+    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
+      sync_write.insert(i);
+    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
+    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
+    // Get shared memory reads
+    std::set<ir::value*> read;
+    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
+                 [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
+    // RAW (async)
+    val_set_t tmp;
+    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
+    if(intersect_with(read, tmp).size()){
+      std::vector<int> groups(read.size());
+      std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+      int N = *std::max_element(groups.begin(), groups.end());
+      if(N < async_write.size()){
+        builder.set_insert_point(i);
+        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
+        barrier = (ir::barrier_inst*)builder.create_barrier();
+        inserted = true;
+      }
+    }
+    // RAW, WAR
+    if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
+      builder.set_insert_point(i);
+      barrier = (ir::barrier_inst*)builder.create_barrier();
+      inserted = true;
+    }
+    // update state of asynchronous copies
+    if(async_wait){
+      int N = async_write.size() - async_wait->get_N();
+      async_write.erase(async_write.begin(), async_write.begin() + N);
+    }
+    // all the copy_to_shared and read from shared are synchronized after barrier
+    if(barrier){
+      sync_write.clear();
+      sync_read.clear();
+    }
+    sync_read.insert(read.begin(), read.end());
+  }
+}
 
 void membar::run(ir::module &mod) {
@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {
 
   for(ir::function *fn: mod.get_function_list()){
     std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
-    std::map<ir::basic_block*, interval_vec_t> written_to;
-    std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::vector<ir::instruction*> to_sync;
-    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
-    size_t n_inserted_im1 = 0;
-    bool done = false;
+    std::map<ir::basic_block*, val_vec_t> async_writes;
+    std::map<ir::basic_block*, val_set_t> sync_writes;
+    std::map<ir::basic_block*, val_set_t> sync_reads;
+    std::list<ir::value *> pipelined;
+    bool inserted;
     do{
+      inserted = false;
       // find barrier location
       for(ir::basic_block *block: rpo){
-        // written to
-        std::vector<interval_vec_t> pred_written_to;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_written_to.push_back(written_to[pred]);
-        // read from
-        std::vector<interval_vec_t> pred_read_from;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_read_from.push_back(read_from[pred]);
-        // apply transfer function
-        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
-        written_to[block] = result.first;
-        read_from[block] = result.second;
+        // join inputs
+        val_vec_t async_write;
+        val_set_t sync_write;
+        val_set_t sync_read;
+        val_set_t tmp;
+        for(ir::basic_block* pred: block->get_predecessors()){
+          for(ir::value* v: async_writes[pred])
+            if(tmp.insert(v).second)
+              async_write.push_back(v);
+          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
+          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
+        }
+        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
+        async_writes[block] = async_write;
+        sync_writes[block] = sync_write;
+        sync_reads[block] = sync_read;
       }
-      size_t n_inserted_i = insert_locs.size();
-      done = (n_inserted_im1 == n_inserted_i);
-      n_inserted_im1 = n_inserted_i;
-    }while(!done);
-    for(auto x: insert_locs){
-      insert_barrier(x.first, x.second, builder);
-    }
+    }while(inserted);
   }
 }
 
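A worked example of the wait-group arithmetic in the new `transfer` (hypothetical numbers): if four async copies are pending and the deepest group a shared-memory read depends on is N = 1, the pass emits `create_async_wait(4 - 1 - 1)`, i.e. the two newer prefetches may stay in flight while groups 0 and 1 must have landed.

    // Same computation as in membar::transfer, extracted for clarity:
    int pending = async_write.size();            // e.g. 4 committed cp.async groups
    int newest_needed = N;                       // e.g. 1: highest group a read depends on
    int wait_arg = pending - 1 - newest_needed;  // -> create_async_wait(2)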
@@ -1,7 +1,9 @@
 #include <algorithm>
+#include <iostream>
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/codegen/transform/peephole.h"
+#include "triton/codegen/analysis/layout.h"
 
 namespace triton {
 namespace codegen{
@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
   ir::value *ptr = ld->get_pointer_operand();
   ir::value *msk = ld->get_mask_operand();
   ir::value *val = ld->get_false_value_operand();
-  ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
-  copy_to_shared->replace_all_uses_with(new_load);
-  return true;
+  analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+  int nts = layout->nts(layout->get_order()[0]);
+  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+  if(nts*dtsize >= 4){
+    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+    copy_to_shared->replace_all_uses_with(new_load);
+    return true;
+  }
+  return false;
+  // analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+  // std::cout << layout->nts(layout->get_order(0)) << std::endl;
+  // return true;
 
 }
 
|
|||||||
bool was_modified = false;
|
bool was_modified = false;
|
||||||
was_modified = was_modified || rewrite_mult(i, builder);
|
was_modified = was_modified || rewrite_mult(i, builder);
|
||||||
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||||
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
// was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||||
// if(tgt_->as_nvidia()->sm() >= 80)
|
if(tgt_->as_nvidia()->sm() >= 80)
|
||||||
// was_modified = was_modified || rewrite_load_to_shared(i, builder);
|
was_modified = was_modified || rewrite_load_to_shared(i, builder);
|
||||||
if(was_modified)
|
if(was_modified)
|
||||||
seen.insert(i);
|
seen.insert(i);
|
||||||
}
|
}
|
||||||
|
lib/codegen/transform/pipeline.cc (new file, 116 lines)
@@ -0,0 +1,116 @@
+#include <iostream>
+#include <algorithm>
+#include "triton/codegen/transform/pipeline.h"
+#include "triton/ir/module.h"
+#include "triton/ir/function.h"
+#include "triton/ir/basic_block.h"
+#include "triton/ir/instructions.h"
+#include "triton/ir/utils.h"
+
+namespace triton {
+namespace codegen{
+namespace transform{
+
+
+void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
+  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
+  if(!i || i->get_parent() != block)
+    return;
+  if(i->get_id()==ir::INST_PHI)
+    return;
+  ret.push_back(i);
+  for(ir::user* u: i->get_users())
+    recursive_deps(u, block, ret);
+}
+
+void pipeline::run(ir::module &mod) {
+  // *Very* conservative heuristics for pre-fetching.
+  // A load instruction can be pipelined if:
+  //   - the pointer is a phi node that references a value
+  //     in its basic block (i.e., pointer induction variable)
+  //   - the load has only a single use in a dot instruction
+  // As more use cases become apparent, this pass will be improved
+  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
+  ir::for_each_instruction(mod, [&](ir::instruction *i){
+    if(auto* load = dynamic_cast<ir::load_inst*>(i)){
+      ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
+      auto users = load->get_users();
+      if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
+         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
+        to_pipeline.push_back({load, ptr});
+    }});
+  // do the pipelining
+  std::vector<ir::phi_node*> new_loads;
+  ir::builder &builder = mod.get_builder();
+  for(auto info: to_pipeline){
+    ir::load_inst* load = info.first;
+    ir::phi_node* ptr = info.second;
+    ir::basic_block* block = load->get_parent();
+    ir::basic_block* header = block->get_predecessors()[0];
+    auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
+    auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
+    assert(block_br);
+    assert(header_br);
+    ir::type* ty = load->get_type();
+    // pre-fetch first iteration
+    builder.set_insert_point(header->get_inst_list().back());
+    ir::value* first_ptr = ptr->get_value_for_block(header);
+    ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
+    ir::value* false_value;
+    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
+      first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
+      false_value = masked_load->get_false_value_operand();
+    }
+    else
+      false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
+    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
+    // pre-fetch next iteration
+    builder.set_insert_point(block->get_inst_list().back());
+    ir::value* next_ptr = ptr->get_value_for_block(block);
+    ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
+    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
+      next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
+    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
+    // phi node
+    builder.set_insert_point(block->get_first_non_phi());
+    ir::phi_node* new_load = builder.create_phi(ty, 2);
+    new_load->add_incoming(first_load, header);
+    new_load->add_incoming(next_load, block);
+    load->replace_all_uses_with(new_load);
+    new_loads.push_back(new_load);
+  }
+
+  // try to move dot_inst after loads
+  // for better overlap of io and compute
+  struct move_config_t{
+    std::vector<ir::instruction*> insts;
+    ir::load_inst* dst;
+  };
+  std::map<ir::basic_block*, move_config_t> to_move;
+
+  if(has_copy_async_){
+    for(ir::function* fn: mod.get_function_list())
+    for(ir::basic_block* bb: fn->blocks())
+    for(ir::instruction* inst: bb->get_inst_list()){
+      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
+        recursive_deps(i, bb, to_move[bb].insts);
+      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
+        to_move[bb].dst = i;
+    }
+
+    for(auto& x: to_move){
+      builder.set_insert_point_after(x.second.dst);
+      for(ir::instruction* i: x.second.insts){
+        x.first->erase(i);
+        builder.insert(i);
+      }
+    }
+  }
+
+}
+
+}
+}
+}
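A hypothetical before/after of the loop the pipeline pass above rewrites (pseudo-IR for a GEMM-style loop; all names are illustrative):

    // before pipeline::run            // after pipeline::run
    // header:                         // header:
    //   ...                           //   x0 = masked_load(p0, splat(header_cond))
    // loop:                           // loop:
    //   p  = phi [p0, header],        //   x  = phi [x0, header], [x1, loop]
    //            [p1, loop]           //   p  = phi [p0, header], [p1, loop]
    //   x  = load p                   //   acc = dot(x, ..., acc)
    //   acc = dot(x, ..., acc)        //   p1 = p + stride
    //   p1 = p + stride               //   x1 = masked_load(p1, splat(loop_cond))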
@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
 inline bool is_cst(ir::value *x) {
   if(dynamic_cast<ir::constant*>(x))
     return true;
+  if(dynamic_cast<ir::make_range*>(x))
+    return true;
   if(auto *v = dynamic_cast<ir::retile_inst*>(x))
     return is_cst(v->get_operand(0));
   return false;
@@ -70,7 +70,21 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
 
 cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
   dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
-  // dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  // properties
+  int shared_total, shared_optin, shared_static;
+  int n_spills, n_reg;
+  CUdevice dev;
+  dispatch::cuCtxGetDevice(&dev);
+  dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
+  dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
+  dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
+  if (shared_optin > 49152){
+    // std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
+    dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
+  }
 }
 
 }
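Usage note: 49152 bytes is the default 48 KB per-block shared-memory limit; anything beyond it requires the per-function opt-in performed above. A CUDA runtime-API sketch of the same opt-in (hedged: `my_kernel` is a hypothetical kernel, and the driver-API call in the diff is the authoritative version):

    #include <cuda_runtime.h>

    extern __global__ void my_kernel();  // hypothetical kernel

    void enable_large_smem(int dyn_bytes) {
      // Runtime equivalent of cuFuncSetAttribute(CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, ...)
      cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dyn_bytes);
      my_kernel<<<1, 256, dyn_bytes>>>();  // dynamic size passed at launch, as in stream.cc
    }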
@@ -282,6 +282,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 
 void cu_module::init_from_ptx(const std::string& ptx) {
   // JIT compile source-code
+  // std::cout << ptx << std::endl;
 
   try{
     // // compile ptx with ptxas
@@ -76,7 +76,7 @@ void host_stream::synchronize() {
   hst_->args.clear();
 }
 
-void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
+void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t) {
   auto hst = kernel->module()->hst();
   hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
   char* params = new char[args_size];
@@ -113,13 +113,13 @@ void cu_stream::synchronize() {
   dispatch::cuStreamSynchronize(*cu_);
 }
 
-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem) {
   void *config[] = {
     CU_LAUNCH_PARAM_BUFFER_POINTER, args,
     CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
     CU_LAUNCH_PARAM_END
   };
-  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
+  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config);
 }
 
 void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
@@ -45,6 +45,9 @@ void builder::set_insert_point(basic_block *block){
 // convenience functions
 //===----------------------------------------------------------------------===//
 
+value *builder::get_int1(bool val)
+{ return constant_int::get(type::get_int1_ty(ctx_), val); }
+
 value *builder::get_int32(int32_t val)
 { return constant_int::get(type::get_int32_ty(ctx_), val);}
 
@@ -372,8 +375,8 @@ value *builder::create_barrier(const std::string &name) {
   return insert(barrier_inst::create(ctx_, name));
 }
 
-value *builder::create_async_wait() {
-  return insert(async_wait_inst::create(ctx_));
+value *builder::create_async_wait(int N) {
+  return insert(async_wait_inst::create(ctx_, N));
 }
 
 }
@@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins
   blocks_.reserve(num_reserved);
 }
 
+value* phi_node::get_value_for_block(basic_block * block) {
+  auto it = std::find(blocks_.begin(), blocks_.end(), block);
+  size_t n = std::distance(blocks_.begin(), it);
+  return get_incoming_value(n);
+}
+
 // Set incoming value
 void phi_node::set_incoming_value(unsigned i, value *v){
   assert(v && "PHI node got a null value!");
@@ -818,12 +824,11 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
   return new barrier_inst(ctx, name, next);
 }
 
-async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
-                                 instruction *next)
-  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
+async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
+  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
 
-async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
-  return new async_wait_inst(ctx, name, next);
+async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
+  return new async_wait_inst(ctx, N, name, next);
 }
 
@@ -15,10 +15,10 @@
|
|||||||
#include "triton/codegen/transform/peephole.h"
|
#include "triton/codegen/transform/peephole.h"
|
||||||
#include "triton/codegen/transform/membar.h"
|
#include "triton/codegen/transform/membar.h"
|
||||||
#include "triton/codegen/transform/reassociate.h"
|
#include "triton/codegen/transform/reassociate.h"
|
||||||
#include "triton/codegen/transform/reorder.h"
|
|
||||||
#include "triton/codegen/transform/cts.h"
|
#include "triton/codegen/transform/cts.h"
|
||||||
#include "triton/codegen/transform/disassociate.h"
|
#include "triton/codegen/transform/disassociate.h"
|
||||||
#include "triton/codegen/selection/generator.h"
|
#include "triton/codegen/selection/generator.h"
|
||||||
|
#include "triton/codegen/transform/pipeline.h"
|
||||||
#include "triton/runtime/function.h"
|
#include "triton/runtime/function.h"
|
||||||
#include "triton/lang/cpp.h"
|
#include "triton/lang/cpp.h"
|
||||||
#include "triton/lang/parser.h"
|
#include "triton/lang/parser.h"
|
||||||
@@ -149,6 +149,7 @@ void kernel::init_ker(){
|
|||||||
codegen::analysis::align align;
|
codegen::analysis::align align;
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
codegen::transform::cts cts(cts_use_async);
|
codegen::transform::cts cts(cts_use_async);
|
||||||
|
codegen::transform::pipeline pipeline(cts_use_async);
|
||||||
codegen::transform::disassociate disassociate;
|
codegen::transform::disassociate disassociate;
|
||||||
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
|
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
|
||||||
codegen::analysis::liveness liveness(&layouts);
|
codegen::analysis::liveness liveness(&layouts);
|
||||||
@@ -156,19 +157,24 @@ void kernel::init_ker(){
|
|||||||
codegen::analysis::allocation allocation(&liveness);
|
codegen::analysis::allocation allocation(&liveness);
|
||||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
codegen::transform::peephole peephole(target.get());
|
codegen::transform::peephole peephole(target.get(), &layouts);
|
||||||
codegen::transform::reassociate reassociate;
|
codegen::transform::reassociate reassociate;
|
||||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
|
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
dce.run(*ir_);
|
dce.run(*ir_);
|
||||||
|
pipeline.run(*ir_);
|
||||||
|
dce.run(*ir_);
|
||||||
disassociate.run(*ir_);
|
disassociate.run(*ir_);
|
||||||
dce.run(*ir_);
|
dce.run(*ir_);
|
||||||
|
align.run(*ir_);
|
||||||
|
axes.run(*ir_);
|
||||||
|
layouts.run(*ir_);
|
||||||
peephole.run(*ir_);
|
peephole.run(*ir_);
|
||||||
dce.run(*ir_);
|
dce.run(*ir_);
|
||||||
align.run(*ir_);
|
|
||||||
if(target->is_gpu())
|
if(target->is_gpu())
|
||||||
cts.run(*ir_);
|
cts.run(*ir_);
|
||||||
|
align.run(*ir_);
|
||||||
axes.run(*ir_);
|
axes.run(*ir_);
|
||||||
layouts.run(*ir_);
|
layouts.run(*ir_);
|
||||||
coalesce.run(*ir_);
|
coalesce.run(*ir_);
|
||||||
@@ -179,6 +185,11 @@ void kernel::init_ker(){
|
|||||||
reassociate.run(*ir_);
|
reassociate.run(*ir_);
|
||||||
cts.run(*ir_);
|
cts.run(*ir_);
|
||||||
}
|
}
|
||||||
|
dce.run(*ir_);
|
||||||
|
// ir::print(*ir_, std::cout);
|
||||||
|
align.run(*ir_);
|
||||||
|
axes.run(*ir_);
|
||||||
|
layouts.run(*ir_);
|
||||||
peephole.run(*ir_);
|
peephole.run(*ir_);
|
||||||
dce.run(*ir_);
|
dce.run(*ir_);
|
||||||
align.run(*ir_);
|
align.run(*ir_);
|
||||||
@@ -187,8 +198,9 @@ void kernel::init_ker(){
   swizzle.run(*ir_);
   liveness.run(*ir_);
   allocation.run(*ir_);
-  if(allocation.allocated_size() > dev_->max_shared_memory())
-    throw exception::out_of_shared_memory();
+  shared_mem_ = allocation.allocated_size();
+  // if(allocation.allocated_size() > dev_->max_shared_memory())
+  //   throw exception::out_of_shared_memory();
   barriers.run(*ir_);
   isel.visit(*ir_, *llvm);
   //if(res->spilled() > 256)
@@ -224,7 +236,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
+  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
 }

 std::string kernel::get_asm(asm_mode_t mode) {
@@ -348,7 +360,7 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
     while(grid.size() < 3)
       grid.push_back(1);
     double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
-                             stream, true);
+                             stream, 5, 20);
     ret = (ts < best_ts) ? current : ret;
     best_ts = std::min(ts, best_ts);
   }
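
The autotuner now times each candidate with an explicit warmup count (5) and repetition count (20) instead of a boolean flag. A minimal sketch of the same warmup-then-average pattern in Python (torch-only; the timed `fn` is a stand-in, not Triton's `tools::bench`):

import torch

def bench(fn, warmup=5, rep=20):
    # warm-up iterations: absorb compilation and caching effects outside the timed region
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / rep  # mean milliseconds per call
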
@@ -2,58 +2,74 @@ import triton
 import torch

 # square benchmarks
-nt = {False: 'n', True: 't'}
+nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
-        x_names = ['M', 'N', 'K'],
-        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'provider',
-        y_vals = ['torch', 'triton', 'cutlass'],
-        y_lines = ['Torch', 'Triton', 'CUTLASS'],
-        ylabel = 'TFLOPS',
-        loglog = False,
-        plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
-        args = {'AT': False, 'BT': False, 'dtype': torch.float16}
-    )\
-    for AT in [False, True] for BT in [False, True]
+        x_names=["M", "N", "K"],
+        x_vals=[512 * i for i in range(1, 16)],
+        y_name="provider",
+        y_vals=["torch", "triton", "cutlass"],
+        y_lines=["Torch", "Triton", "CUTLASS"],
+        ylabel="TFLOPS",
+        loglog=False,
+        plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
+        args={"AT": AT, "BT": BT, "dtype": torch.float16},
+    ) for AT in [False, True] for BT in [False, True]
 ]

 @triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
     import os
-    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
-    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
-    if AT: a = a.t()
-    if BT: b = b.t()
+    a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
+    b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
+    if AT:
+        a = a.t()
+    if BT:
+        b = b.t()
     num_flops = 2 * M * N * K
-    if provider == 'torch':
+    if provider == "torch":
         torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
         torch_tflops = num_flops / torch_ms * 1e-9
         return torch_tflops
-    if provider == 'triton':
+    if provider == "triton":
         triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
         triton_tflops = num_flops / triton_ms * 1e-9
         return triton_tflops
-    if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
+    if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
         import subprocess
         import tempfile
         import pandas as pd

         # run program specified by CUTLASS_PROFILER env variable
-        layout_a = 'column' if AT else 'row'
-        layout_b = 'column' if BT else 'row'
+        layout_a = "column" if AT else "row"
+        layout_b = "column" if BT else "row"
         # create temporary file name
         fd, fname = tempfile.mkstemp()
         # run program and gets its output
-        cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
-               '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
-               f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
+        cmd = [
+            os.environ["CUTLASS_PROFILER"],
+            f"--m={M}",
+            f"--n={N}",
+            f"--k={K}",
+            f"--A=f16:{layout_a}",
+            f"--B=f16:{layout_b}",
+            "--C=f16:column",
+            "--accum=f32",
+            "--operation=gemm",
+            "--verification-enabled=false",
+            f"--warmup-iterations={warmup}",
+            f"--profiling-iterations={rep}",
+            f"--output={fname}",
+            "--verbose=false",
+        ]
         # run cmd
         subprocess.run(cmd, stdout=subprocess.PIPE)
         # read CSV output
-        df_c = pd.read_csv(f'{fname}.gemm.csv')
-        cutlass_tflops = max(df_c['GFLOPs']) / 1e3
+        df_c = pd.read_csv(f"{fname}.gemm.csv")
+        cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
         return cutlass_tflops
     return None

-if __name__ == '__main__':
+if __name__ == "__main__":
     bench_op.run()
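
The TFLOPS conversion used above is plain arithmetic on the benchmarked latency; a worked example (the latency value is illustrative):

M = N = K = 4096
num_flops = 2 * M * N * K          # 2*4096^3 ≈ 1.374e11 FLOPs per matmul
ms = 1.2                           # hypothetical measured latency in milliseconds
tflops = num_flops / ms * 1e-9     # FLOPs/ms * 1e-9 == TFLOP/s
print(f"{tflops:.1f} TFLOPS")      # ≈ 114.5
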
python/setup.py
@@ -15,21 +15,21 @@ import distutils.spawn
 import torch

 def find_llvm():
-    versions = ['-10', '-10.0', '']
-    supported = ['llvm-config{v}'.format(v=v) for v in versions]
+    versions = ["-10", "-10.0", ""]
+    supported = ["llvm-config{v}".format(v=v) for v in versions]
     paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
     paths = [p for p in paths if p is not None]
     if paths:
         return paths[0]
-    config = distutils.spawn.find_executable('llvm-config')
-    instructions = 'Please install llvm-10-dev'
+    config = distutils.spawn.find_executable("llvm-config")
+    instructions = "Please install llvm-10-dev"
     if config is None:
-        raise RuntimeError('Could not find llvm-config. ' + instructions)
-    version = os.popen('{config} --version'.format(config=config)).read()
-    raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
+        raise RuntimeError("Could not find llvm-config. " + instructions)
+    version = os.popen("{config} --version".format(config=config)).read()
+    raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)

 class CMakeExtension(Extension):
-    def __init__(self, name, path, sourcedir=''):
+    def __init__(self, name, path, sourcedir=""):
         Extension.__init__(self, name, sources=[])
         self.sourcedir = os.path.abspath(sourcedir)
         self.path = path
@@ -37,84 +37,84 @@ class CMakeExtension(Extension):
 class CMakeBuild(build_ext):
     def run(self):
         try:
-            out = subprocess.check_output(['cmake', '--version'])
+            out = subprocess.check_output(["cmake", "--version"])
         except OSError:
             raise RuntimeError("CMake must be installed to build the following extensions: " +
                                ", ".join(e.name for e in self.extensions))

         if platform.system() == "Windows":
-            cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
-            if cmake_version < '3.1.0':
+            cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
+            if cmake_version < "3.1.0":
                 raise RuntimeError("CMake >= 3.1.0 is required on Windows")

         for ext in self.extensions:
             self.build_extension(ext)

     def build_extension(self, ext):
-        #self.debug = True
+        # self.debug = True
+        self.debug = False
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # python directories
         python_include_dirs = distutils.sysconfig.get_python_inc()
-        python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
+        python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
         torch_include_dirs = include_paths(True)
         torch_library_dirs = library_paths(True)
         cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
         cmake_args = [
-            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
-            '-DBUILD_TUTORIALS=OFF',
-            '-DBUILD_PYTHON_MODULE=ON',
+            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
+            "-DBUILD_TUTORIALS=OFF",
+            "-DBUILD_PYTHON_MODULE=ON",
             #'-DPYTHON_EXECUTABLE=' + sys.executable,
             #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
-            '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
-            '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
-            '-DTORCH_CXX11_ABI=' + cxx11abi,
-            '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
-            '-DLLVM_CONFIG=' + find_llvm()
+            "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
+            "-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
+            "-DTORCH_CXX11_ABI=" + cxx11abi,
+            "-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
+            "-DLLVM_CONFIG=" + find_llvm(),
         ]
         # configuration
-        cfg = 'Debug' if self.debug else 'Release'
-        cfg = 'Release'
-        build_args = ['--config', cfg]
+        cfg = "Debug" if self.debug else "Release"
+        build_args = ["--config", cfg]

         if platform.system() == "Windows":
-            cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
+            cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
             if sys.maxsize > 2**32:
-                cmake_args += ['-A', 'x64']
-            build_args += ['--', '/m']
+                cmake_args += ["-A", "x64"]
+            build_args += ["--", "/m"]
         else:
-            cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
-            build_args += ['--', '-j4']
+            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
+            build_args += ["--", "-j4"]

         env = os.environ.copy()
         if not os.path.exists(self.build_temp):
             os.makedirs(self.build_temp)
-        sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
-        subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
-        subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
+        sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
+        subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
+        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)

 setup(
-    name='triton',
-    version='1.0.0',
-    author='Philippe Tillet',
-    author_email='phil@openai.com',
-    description='A language and compiler for custom Deep Learning operations',
-    long_description='',
-    packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
-    install_requires=['numpy', 'torch'],
-    package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
+    name="triton",
+    version="1.0.0",
+    author="Philippe Tillet",
+    author_email="phil@openai.com",
+    description="A language and compiler for custom Deep Learning operations",
+    long_description="",
+    packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
+    install_requires=["numpy", "torch"],
+    package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
     include_package_data=True,
-    ext_modules=[CMakeExtension('triton', 'triton/_C/')],
-    cmdclass={'build_ext': CMakeBuild},
+    ext_modules=[CMakeExtension("triton", "triton/_C/")],
+    cmdclass={"build_ext": CMakeBuild},
     zip_safe=False,
     # for PyPI
-    keywords=['Compiler', 'Deep Learning'],
-    url='https://github.com/ptillet/triton/',
-    download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
+    keywords=["Compiler", "Deep Learning"],
+    url="https://github.com/ptillet/triton/",
+    download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
     classifiers=[
-        'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
-        'Intended Audience :: Developers', # Define that your audience are developers
-        'Topic :: Software Development :: Build Tools',
-        'License :: OSI Approved :: MIT License', # Again, pick a license
-        'Programming Language :: Python :: 3.6',
+        "Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
+        "Intended Audience :: Developers", # Define that your audience are developers
+        "Topic :: Software Development :: Build Tools",
+        "License :: OSI Approved :: MIT License", # Again, pick a license
+        "Programming Language :: Python :: 3.6",
     ],
 )
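
find_llvm probes a few llvm-config names in PATH order and falls back to an error with install instructions. The same probing pattern, sketched with the standard library (shutil.which is the modern equivalent of distutils.spawn.find_executable):

import shutil

def probe(names=("llvm-config-10", "llvm-config-10.0", "llvm-config")):
    # return the first candidate found on PATH, or None if none resolve
    for name in names:
        path = shutil.which(name)
        if path is not None:
            return path
    return None
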
@@ -2,29 +2,17 @@ import torch
 import triton
 import pytest

 @pytest.mark.parametrize(
     "MODE, TRANS_A, TRANS_B, BLOCK",
-    [
-        (mode, at, bt, block)
-        for mode in ["sdd", "dsd", "dds"]
-        for at in [False, True]
-        for bt in [False, True]
-        for block in [16, 32, 64]
-    ],
+    [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
+     for block in [16, 32, 64]],
 )
-def test_matmul(
-    MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
-):
+def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
     # set seed
     torch.random.manual_seed(0)
     # create inputs
-    a = torch.randn(
-        (Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
-    )
-    b = torch.randn(
-        (Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
-    )
+    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
+    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
     shape = {
         "sdd": (M, N),
         "dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
     }[MODE]
     layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
     # triton result
-    op = triton.ops.blocksparse.matmul(
-        layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
-    )
+    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
     rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
     # compare
     assert triton.testing.allclose(rc, tc)

 @pytest.mark.parametrize(
     "BLOCK, WIDTH",
     [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     # create inputs
     layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
     x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
-    at_mask = torch.randint(
-        low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
-    )
-    kp_mask = torch.randint(
-        low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
-    )
+    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
+    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
     kp_mask[kp_mask == 1.0] = float("-inf")
     # triton result
     op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
     # compare
     assert triton.testing.allclose(ry, ty)

 def test_attention_fwd_bwd(
     input_scale=1.0,
     tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
     # inputs
     qkv_shape = (batch_size, n_heads, n_ctx, 64)
     qkvs = [
-        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
-        .to(dtype)
-        .cuda()
-        for _ in range(3)
+        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
     ]
     attn_mask = torch.tril(
         torch.ones(
@@ -129,11 +106,9 @@ def test_attention_fwd_bwd(
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    attn_out = triton_attention(
-        layout, block, attn_mask, query=query, key=key, value=value, scale=scale
-    )
+    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
     # ad hoc loss
-    loss = (attn_out ** 2).mean()
+    loss = (attn_out**2).mean()
     loss.backward()
     grads = [query.grad, key.grad, value.grad]

@@ -148,17 +123,16 @@ def test_attention_fwd_bwd(
     probs = torch.softmax(scores, dim=-1)
     torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
     # ad hoc loss
-    torch_loss = (torch_attn_out ** 2).mean()
+    torch_loss = (torch_attn_out**2).mean()
     torch_loss.backward()
     torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]

     # comparison
-    print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
+    # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
     torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
     for g1, g2 in zip(grads, torch_grads):
         torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)

 def triton_attention(
     layout,
     block: int,
@@ -168,12 +142,8 @@ def triton_attention(
     value: torch.Tensor,
     scale: float,
 ):
-    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
-        layout, block, "sdd", trans_a=False, trans_b=True
-    )
-    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
-        layout, block, "dsd", trans_a=False, trans_b=False
-    )
+    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
+    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
     sparse_softmax = triton.ops.blocksparse.softmax(
         layout,
         block,
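
The dsd/dds tests sparsify one operand according to a random block layout; a dense reference can be built by expanding that layout into an elementwise mask. A sketch of the expansion (plain PyTorch; this mirrors what a mask-based reference does, not the repo's test helper):

import torch

Z, H, M, N, BLOCK = 2, 3, 64, 32, 16
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn(Z, H, M, N)
# expand each layout bit into a BLOCK x BLOCK patch, then mask elementwise
mask = torch.kron(layout, torch.ones(BLOCK, BLOCK, dtype=layout.dtype))
dense_ref = x * mask[None, :, :, :]
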
@@ -4,7 +4,7 @@ import triton
 import torch

 @pytest.mark.parametrize(
-    "TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
+    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
     itertools.chain(*[
         [
             # 1 warp
@@ -17,14 +17,14 @@ import torch
             (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
             (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
             (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-            # 2 warp
+            # # 2 warp
             (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
             (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
             (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
             (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
             (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
             (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
-            # 4 warp
+            # # 4 warp
             (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
             (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
             (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,24 +40,28 @@ import torch
             (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
             (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
             # variable input
-            (128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
+            (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
             (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
             (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
-            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
-        ] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
-    ]))
-def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
-    DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
+            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
+        ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
+    ]),
+)
+def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
+    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     torch.manual_seed(0)
     triton.ops._matmul._kernels = dict()
-    triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
-    if M is None: M = TM
-    if N is None: N = TN
-    if K is None: K = TK * TZ
-    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
-    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
+    triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
+    if M is None:
+        M = TM
+    if N is None:
+        N = TN
+    if K is None:
+        K = TK * SPLITK
+    a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
+    b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
     a = a.t() if AT else a
     b = b.t() if BT else b
     th_c = torch.matmul(a, b)
     tt_c = triton.ops.matmul(a, b)
     assert triton.testing.allclose(th_c, tt_c)
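
The TZ knob is renamed SPLITK: the K dimension is split into SPLITK independent chunks whose partial products are reduced at the end. A toy dense illustration of the same decomposition (plain PyTorch, not the kernel itself):

import torch

def splitk_matmul(a, b, splitk=4):
    # split the reduction dimension into `splitk` chunks and sum the partial
    # products; each chunk is what one split-K slice computes independently
    K = a.shape[1]
    assert K % splitk == 0
    step = K // splitk
    partial = [a[:, i:i + step] @ b[i:i + step, :] for i in range(0, K, step)]
    return sum(partial)

a, b = torch.randn(64, 128), torch.randn(128, 32)
assert torch.allclose(splitk_matmul(a, b), a @ b, atol=1e-4)
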
@@ -1,198 +1,199 @@
-__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
-                      TYPE* B __readonly __noalias __aligned(16),
-                      TYPE* C __noalias __aligned(16),
+__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
+                     TYPE *B __readonly __noalias __aligned(16),
+                     TYPE *C __noalias __aligned(16),
                      int lda __multipleof(8),
                      int ldb __multipleof(8),
                      int ldc __multipleof(8),
                      long stride_za __multipleof(8),
                      long stride_zb __multipleof(8),
                      long stride_zc __multipleof(8),
                      long stride_ha __multipleof(8),
                      long stride_hb __multipleof(8),
                      long stride_hc __multipleof(8),
                      int DS0, int DS1,
                      int SDD_K __multipleof(16),
                      int SDD_off_width,
-                     int* lut, int* locks, int nlocks) {
+                     int *lut, int *locks, int nlocks) {
   /* ---------------- */
   /* Prologue */
   /* ---------------- */
   // program ids
   int pid0 = get_program_id(0);
   int pid1 = get_program_id(1);
   int pidz = get_program_id(2);
 #ifdef SDD
   // load LUT header
   pid1 = pid1 + SDD_off_width;
   int blockidm[TM] = (0 ... TM) / BLOCK;
   int blockidn[TN] = (0 ... TN) / BLOCK;
-  int offlutm[TM] = blockidm*(TN/BLOCK)*4;
-  int offlutn[TN] = blockidn*4;
-  int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
+  int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
+  int offlutn[TN] = blockidn * 4;
+  int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
   int z = *(header + 0);
   int i[TM] = *(header + 1 + offlutm);
   int j[TN] = *(header + 2 + offlutn);
   int AS1 = SDD_K / TZ;
   int lockid = select(TZ > 1, 1, 0);
   int offka = pid0 * AS1;
   int offkb = pid0 * AS1;
   int offmc = 0;
   int offnc = 0;
   int offpa = 0;
   int offpb = 0;
   int maxid = TZ;
   int offhc = 0;
   int offha = z;
   int offhb = z;
-  int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
-  int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
+  int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
+  int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
 #else
   // load LUT header
   int *header = lut + pid0 * 6;
   int offset = *(header + 0);
   int AS1 = *(header + 1);
   int column = *(header + 2);
   int depth = *(header + 3);
   int lockid = *(header + 4);
   int maxid = *(header + 5);
   int *pinc = lut + offset;
   int offhc = depth;
 #ifdef DSD
   // output offset
   int offnc = pid1 * TN;
   int offmc = column * TM;
   int offpc = 0;
   // dense input offset
   int offnb = pid1 * TN;
   int offkb __multipleof(8) = *pinc;
   int offpb = 0;
   // sparse input offset
   int offma = 0;
   int offka = 0;
   long offpa __multipleof(8) = *(pinc + 1);
   offpa = offpa * BLOCK * BLOCK;
   int offha = 0;
   int offhb = depth;
 #endif
 #ifdef DDS
   // output offset
   int offmc = pid1 * TM;
   int offnc = column * TN;
   int offpc = 0;
   // dense input offset
   int offma = pid1 * TM;
   int offka __multipleof(8) = *pinc;
   int offpa = 0;
   // sparse input offset
   int offnb = 0;
   int offkb = 0;
   long offpb __multipleof(8) = *(pinc + 1);
   offpb = offpb * BLOCK * BLOCK;
   int offha = depth;
   int offhb = 0;
 #endif
   int ram[TM] = offma + 0 ... TM;
   int rbn[TN] = offnb + 0 ... TN;
 #endif
   // initialize a, b pointers
   int rka[TK] = offka + 0 ... TK;
   int rkb[TK] = offkb + 0 ... TK;
-  TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
-  TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
+  TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK;
+  TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
+  // pre-fetch
+#ifdef DDS
+  bool checkam[TM, TK] = ram[:, newaxis] < DS0;
+#else
+  bool checkam[TM, TK] = AS1 > 0;
+#endif
+#ifdef DSD
+  bool checkbn[TK, TN] = rbn [newaxis, :] < DS0;
+#else
+  bool checkbn[TK, TN] = AS1 > 0;
+#endif
+  TYPE a[TM, TK] = checkam ? *pa : 0;
+  TYPE b[TK, TN] = checkbn ? *pb : 0;
+
+  /* ---------------- */
+  /* Inner Loop */
+  /* ---------------- */
+  // create result tile
+  float acc[TM, TN] = 0;
+  int step = TK;
+  for (int k = AS1; k > 0; k -= step) {
+    acc += a @b;
+    // update pointers
+#ifdef SDD
+    int inc_a = TK * STRIDE_AK;
+    int inc_b = TK * STRIDE_BK;
+#else
+    pinc += 2;
+#ifdef DSD
+    int inc_b __multipleof(8) = *pinc;
+    int inc_a __multipleof(8) = *(pinc + 1);
+    inc_b = inc_b * STRIDE_BK;
+#endif
+#ifdef DDS
+    int inc_a __multipleof(8) = *pinc;
+    int inc_b __multipleof(8) = *(pinc + 1);
+    inc_a = inc_a * STRIDE_AK;
+#endif
+#endif
+    pa += inc_a;
+    pb += inc_b;
     // pre-fetch
-#ifdef DDS
-  bool checkam[TM, TK] = ram[:, newaxis] < DS0;
-#else
-  bool checkam[TM, TK] = AS1 > 0;
-#endif
-#ifdef DSD
-  bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
-#else
-  bool checkbn[TK, TN] = AS1 > 0;
-#endif
-  TYPE a[TM, TK] = checkam ? *pa : 0;
-  TYPE b[TK, TN] = checkbn ? *pb : 0;
-
-  /* ---------------- */
-  /* Inner Loop */
-  /* ---------------- */
-  // create result tile
-  float acc[TM, TN] = 0;
-  int step = TK;
-  for(int k = AS1; k > 0; k -= step) {
-    acc += a @ b;
-    // update pointers
-#ifdef SDD
-    int inc_a = TK * STRIDE_AK;
-    int inc_b = TK * STRIDE_BK;
-#else
-    pinc += 2;
-#ifdef DSD
-    int inc_b __multipleof(8) = *pinc;
-    int inc_a __multipleof(8) = *(pinc + 1);
-    inc_b = inc_b * STRIDE_BK;
-#endif
-#ifdef DDS
-    int inc_a __multipleof(8) = *pinc;
-    int inc_b __multipleof(8) = *(pinc + 1);
-    inc_a = inc_a * STRIDE_AK;
-#endif
-#endif
-    pa += inc_a;
-    pb += inc_b;
-    // pre-fetch
-    bool checkak[TM, TK] = k > TK;
-    bool checkbk[TK, TN] = k > TK;
-    bool checka[TM, TK] = checkam && checkak;
-    bool checkb[TK, TN] = checkbk && checkbn;
-    a = *?(checka)pa;
-    b = *?(checkb)pb;
-  }
-  TYPE c[TM, TN] = acc;
-
-  /* ---------------- */
-  /* Epilogue */
-  /* ---------------- */
-  // initialize c pointers
-#ifdef SDD
-  bool checkc[TM, TN] = 1;
-  // rematerialize
-  int rr_blockidm[TM] = (0 ... TM) / BLOCK;
-  int rr_blockidn[TN] = (0 ... TN) / BLOCK;
-  int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
-  int rr_offlutn[TN] = rr_blockidn*4;
-  int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
-  int bkid[TM, TN] = *(header + off_bkid);
-  long offpc[TM, TN] = bkid * BLOCK * BLOCK;
-  // range within blocks
-  int rcm[TM] = (0 ... TM) % BLOCK;
-  int rcn[TN] = (0 ... TN) % BLOCK;
-#else
-  int rcm[TM] = offmc + 0 ... TM;
-  int rcn[TN] = offnc + 0 ... TN;
-#ifdef DSD
-  bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
-#endif
-#ifdef DDS
-  bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
-#endif
-#endif
-  TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
-  // write-back directly
-  if(lockid == 0) {
-    *?(checkc) pc = c;
-  }
-  // accumulate partial result using spin-locks
-  else {
-    int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
-    int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
-    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
-    int count = *pcount;
-    if(count == 0)
-      *?(checkc) pc = c;
-    else
-      *?(checkc) pc = c + *?(checkc)pc;
-    atomic_xchg(pcount, (count + 1) % maxid);
-    atomic_xchg(plock, 0);
-  }
-}
+    bool checkak[TM, TK] = k > TK;
+    bool checkbk[TK, TN] = k > TK;
+    bool checka[TM, TK] = checkam && checkak;
+    bool checkb[TK, TN] = checkbk && checkbn;
+    a = *? (checka)pa;
+    b = *? (checkb)pb;
+  }
+  TYPE c[TM, TN] = acc;
+
+  /* ---------------- */
+  /* Epilogue */
+  /* ---------------- */
+  // initialize c pointers
+#ifdef SDD
+  bool checkc[TM, TN] = 1;
+  // rematerialize
+  int rr_blockidm[TM] = (0 ... TM) / BLOCK;
+  int rr_blockidn[TN] = (0 ... TN) / BLOCK;
+  int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
+  int rr_offlutn[TN] = rr_blockidn * 4;
+  int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :];
+  int bkid[TM, TN] = *(header + off_bkid);
+  long offpc[TM, TN] = bkid * BLOCK * BLOCK;
+  // range within blocks
+  int rcm[TM] = (0 ... TM) % BLOCK;
+  int rcn[TN] = (0 ... TN) % BLOCK;
+#else
+  int rcm[TM] = offmc + 0 ... TM;
+  int rcn[TN] = offnc + 0 ... TN;
+#ifdef DSD
+  bool checkc[TM, TN] = rcn [newaxis, :] < DS0;
+#endif
+#ifdef DDS
+  bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
+#endif
+#endif
+  TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN;
+  // write-back directly
+  if (lockid == 0) {
+    *? (checkc)pc = c;
+  }
+  // accumulate partial result using spin-locks
+  else {
+    int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
+    int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
+    for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
+      ;
+    int count = *pcount;
+    if (count == 0)
+      *? (checkc)pc = c;
+    else
+      *? (checkc)pc = c + *? (checkc)pc;
+    atomic_xchg(pcount, (count + 1) % maxid);
+    atomic_xchg(plock, 0);
+  }
+}
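
The epilogue above serializes partial-sum accumulation with a spin lock (atomic_cas/atomic_xchg) plus a counter: the first writer stores its tile, later writers accumulate into it, and the counter wraps at maxid. A Python sketch of the same protocol using threads (illustrative only; the class and names are made up for this note):

import threading

class SplitKAccumulator:
    """First writer stores its partial result; the rest add to it."""
    def __init__(self, maxid):
        self.lock = threading.Lock()   # plays the role of atomic_cas(plock, 0, 1)
        self.count = 0                 # plays the role of *pcount
        self.maxid = maxid
        self.out = None

    def write(self, partial):
        with self.lock:                            # spin-lock acquire in the kernel
            if self.count == 0:
                self.out = partial                 # first slice: plain store
            else:
                self.out = self.out + partial      # later slices: accumulate
            self.count = (self.count + 1) % self.maxid  # atomic_xchg(pcount, ...)

acc = SplitKAccumulator(maxid=4)
threads = [threading.Thread(target=acc.write, args=(i,)) for i in range(4)]
for t in threads: t.start()
for t in threads: t.join()
assert acc.out == 0 + 1 + 2 + 3
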
@@ -10,454 +10,416 @@ src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
 # MAIN API #
 ##############
 class _matmul(torch.autograd.Function):

-    sdd_cache = dict()
-    dsd_cache = dict()
-    dds_cache = dict()
-    locks = dict()

-    # Given an array sizes representing reduction size for each
-    # column of a block-mode matrix multiplication,
-    # performs load-balancing to achieve more smaller reductions
-    # between `seg_size` elements
-    @staticmethod
-    def load_balance(sizes, block):
-        # segment size
-        # heuristics taken from OpenAI blocksparse code
-        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
-        max_size = sizes.max()
-        min_size = sizes[sizes != 0].min()
-        #if max_size > min_size * 2.0:
-        #  seg_max = max(triton.cdiv(max_size, 4), min_size*2)
-        #else:
-        #  seg_max = max_size
-        seg_max = max_size
-        seg_min = max(triton.cdiv(seg_max, 4), 4)
-        # split reduction into segments
-        div = sizes // seg_max
-        rem = sizes % seg_max
-        packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
-        width = packs.sum()
-        segments = torch.empty(width, dtype=sizes.dtype)
-        column = torch.empty_like(segments)
-        lockid = torch.zeros_like(segments)
-        maxid = torch.zeros_like(segments)
-        nlocks = 0
-        current = 0
-        col_idx = 0
-        for i in range(len(sizes)):
-            d, r = div[i], rem[i]
-            isempty = sizes[i] < seg_min
-            last = current + d + (r >= seg_min) + isempty
-            # column id
-            column[current:last] = col_idx
-            # lock id
-            if d > 1 or (d == 1 and r >= seg_min):
-                nlocks += 1
-                lockid[current:last] = nlocks
-                maxid[current:last] = last - current
-            # segment size
-            segments[current:current+d] = seg_max
-            if r < seg_min and not isempty:
-                segments[current+d-1] += r
-            if r >= seg_min or isempty:
-                segments[current+d] = r
-            current = last
-            col_idx += 1
-        offsets = torch.zeros_like(segments)
-        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
-        return segments, column, lockid, maxid, offsets
+    sdd_cache = dict()
+    dsd_cache = dict()
+    dds_cache = dict()
+    locks = dict()

-    @staticmethod
-    def get_locks(size, dev):
-        if dev not in _matmul.locks or \
-            size > _matmul.locks[dev].size(0):
-            _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
-        return _matmul.locks[dev]

-    ##########################
-    # SPARSE = DENSE x DENSE #
-    ##########################
+    # Given an array sizes representing reduction size for each
+    # column of a block-mode matrix multiplication,
+    # performs load-balancing to achieve more smaller reductions
+    # between `seg_size` elements
+    @staticmethod
+    def load_balance(sizes, block):
+        # segment size
+        # heuristics taken from OpenAI blocksparse code
+        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
+        max_size = sizes.max()
+        min_size = sizes[sizes != 0].min()
+        #if max_size > min_size * 2.0:
+        #  seg_max = max(triton.cdiv(max_size, 4), min_size*2)
+        #else:
+        #  seg_max = max_size
+        seg_max = max_size
+        seg_min = max(triton.cdiv(seg_max, 4), 4)
+        # split reduction into segments
+        div = sizes // seg_max
+        rem = sizes % seg_max
+        packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
+        width = packs.sum()
+        segments = torch.empty(width, dtype=sizes.dtype)
+        column = torch.empty_like(segments)
+        lockid = torch.zeros_like(segments)
+        maxid = torch.zeros_like(segments)
+        nlocks = 0
+        current = 0
+        col_idx = 0
+        for i in range(len(sizes)):
+            d, r = div[i], rem[i]
+            isempty = sizes[i] < seg_min
+            last = current + d + (r >= seg_min) + isempty
+            # column id
+            column[current:last] = col_idx
+            # lock id
+            if d > 1 or (d == 1 and r >= seg_min):
+                nlocks += 1
+                lockid[current:last] = nlocks
+                maxid[current:last] = last - current
+            # segment size
+            segments[current:current + d] = seg_max
+            if r < seg_min and not isempty:
+                segments[current + d - 1] += r
+            if r >= seg_min or isempty:
+                segments[current + d] = r
+            current = last
+            col_idx += 1
+        offsets = torch.zeros_like(segments)
+        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
+        return segments, column, lockid, maxid, offsets

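
load_balance splits each column's reduction of sizes[i] blocks into segments of at most seg_max, and columns that end up with more than one segment get a lock so their partial results can be combined. A toy re-implementation of the packing rule on plain ints (seg_max/seg_min are fixed here for readability; the real code derives them from sizes.max() and triton.cdiv):

def pack(sizes, seg_max=4, seg_min=2):
    segments = []
    for col, size in enumerate(sizes):
        d, r = divmod(size, seg_max)
        segs = [seg_max] * d
        if r >= seg_min or size < seg_min:
            segs.append(r)        # remainder large enough (or empty column): own segment
        elif r:
            segs[-1] += r         # small remainder: fold into the last segment
        # columns split into several segments need a lock to combine partials
        segments.append((col, segs, len(segs) > 1))
    return segments

print(pack([7, 3, 9]))
# [(0, [4, 3], True), (1, [3], False), (2, [4, 5], True)]
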
     @staticmethod
-    def make_sdd_lut(layout, block, dtype, device):
-        start_width = 128 // block
-        superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
-        luts, widths, packs = [], [], []
-        for size, nnz in superblocks:
-            width = nnz.shape[0] // (size*size)
-            h = nnz[:, 0]
-            i = nnz[:, 1]
-            j = nnz[:, 2]
-            b = nnz[:, 3]
-            lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
-            luts.append(lut.type(torch.int32).to(device))
-            widths.append(width)
-            packs.append(size)
-        # create locks
-        return luts, None, widths, packs
+    def get_locks(size, dev):
+        if dev not in _matmul.locks or \
+            size > _matmul.locks[dev].size(0):
+            _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
+        return _matmul.locks[dev]

-    @staticmethod
-    def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
-                    spdims, block, luts, num_locks, widths, packs):
+    ##########################
+    # SPARSE = DENSE x DENSE #
+    ##########################

-        if trans_c:
-            a, b = b, a
-            trans_a, trans_b = not trans_b, not trans_a
-        AS0 = a.size(0)
-        AS1 = a.size(1)
-        AS2 = a.size(3 if trans_a else 2)
-        AS3 = a.size(2 if trans_a else 3)
-        BS0 = b.size(0)
-        BS1 = b.size(1)
-        BS2 = b.size(3 if trans_b else 2)
-        BS3 = b.size(2 if trans_b else 3)
-        dtype = a.dtype
-        device = a.device
-        is_16_multiple = AS3 % 16 == 0
-        is_32_multiple = AS3 % 32 == 0
-        is_64_multiple = AS3 % 64 == 0
-        if not is_16_multiple:
-            raise ValueError('Reduction size for SDD must be a multiple of 16')
-        # create kernel
-        total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
-        c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
-        for lut, width, pack in zip(luts, widths, packs):
-            num_lock = 1
-            key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
-            if key not in _matmul.sdd_cache:
-                defines = {'TM': block*pack, 'TN': block*pack,
-                           'TMN': block*block*pack*pack,
-                           'BLOCK': block,
-                           'TK': 32,
-                           'TYPE': dtype,
-                           'STRIDE_AM': '1' if trans_a else 'lda',
-                           'STRIDE_AK': 'lda' if trans_a else '1',
-                           'STRIDE_BN': 'ldb' if trans_b else '1',
-                           'STRIDE_BK': '1' if trans_b else 'ldb',
-                           'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
-                           'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
-                _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)

-            kernel = _matmul.sdd_cache[key]
-            # create output
-            locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
-            # maximum grid size is 65535
-            # so operation might be decomposed into multiple
-            # kernel calls
-            max_width = 49152
-            for off_width in range(0, width, max_width):
-                kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
-                       a.stride(2), b.stride(2), block,
-                       a.stride(0), b.stride(0), c.stride(0),
-                       a.stride(1), b.stride(1), c.stride(0),
-                       AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
-                       grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
-        # save for backward pass
-        return c
+    @staticmethod
+    def make_sdd_lut(layout, block, dtype, device):
+        start_width = 128 // block
+        superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
+        luts, widths, packs = [], [], []
+        for size, nnz in superblocks:
+            width = nnz.shape[0] // (size * size)
+            h = nnz[:, 0]
+            i = nnz[:, 1]
+            j = nnz[:, 2]
+            b = nnz[:, 3]
+            lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
+            luts.append(lut.type(torch.int32).to(device))
+            widths.append(width)
+            packs.append(size)
+        # create locks
+        return luts, None, widths, packs

-    ##########################
-    # DENSE = DENSE x SPARSE #
-    # DENSE = SPARSE x DENSE #
-    ##########################
+    @staticmethod
+    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):

-    # Given a binary layout of 0s and 1s,
-    # Construct look-up table for efficient execution on GPUs
-    @staticmethod
-    def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
-        # load-balancing
-        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
-        segments = _empty.clone()
-        column = _empty.clone()
-        depth = _empty.clone()
-        lockid = _empty.clone()
-        maxid = _empty.clone()
-        offsets = _empty.clone()
-        current_offset = 0
-        current_maxid = 0
-        for z in range(layout.size(0)):
-            if trans:
-                sizes = torch.sum(layout[z, :, :], 1)
-            else:
-                sizes = torch.sum(layout[z, :, :], 0)
-            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
-            z_depth = z * torch.ones_like(z_segments)
-            z_lockid[z_lockid > 0] += current_maxid
-            current_maxid = z_lockid.max()
-            # concatenate depth
-            segments = torch.cat((segments, z_segments))
-            column = torch.cat((column, z_column))
-            depth = torch.cat((depth, z_depth))
-            maxid = torch.cat((maxid, z_maxid))
-            offsets = torch.cat((offsets, current_offset + z_offsets))
-            lockid = torch.cat((lockid, z_lockid))
-            current_offset += layout[z, :, :].sum()
-        segments *= step
-        # pointer increments
-        if trans:
-            nnz = layout.nonzero(as_tuple=False)
-        else:
-            nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
-        num_blocks = nnz.size(0)
-        offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
-        idx = transform(nnz[:, 2]*block)
-        xincs = idx.clone()
-        xincs[1:] -= idx[:-1]
-        # divide block into multiple steps
-        div = block // step
-        xincs = xincs.view(-1, 1).repeat(1, div)
-        xincs[:, 1:] = step
-        xincs[:, 0 ] -= (div-1)*step
-        # first increment for each reduction is actually the offset
-        xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
-        xincs = xincs.view(-1)
-        # block-mode input increments
-        if trans:
-            widx = torch.arange(num_blocks)
-        else:
-            widx = _empty.clone()
-            current_offset = 0
-            for z in range(layout.size(0)):
-                layoutw = layout[z, :, :].clone()
-                msum = layoutw.sum()
-                layoutw[layoutw > 0] = 1 + torch.arange(msum)
-                widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
-                current_offset += msum
-        widx = widx
-        wincs = widx*block*block
-        wincs[1:] -= widx[:-1]*block*block
-        wincs = wincs.view(-1, 1).repeat(1, div)
-        if trans:
-            wincs[:, 1:] = step
-            wincs[:, 0] -= (div-1)*step
-        else:
-            wincs[:, 1:] = step*block
-            wincs[:, 0] -= (div - 1)*step*block
-        wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
-        wincs = wincs.view(-1)
-        # adjust offset and segment size
-        offsets *= 2*div
-        segments *= div
-        # create header
-        width = column.size(0)
-        offsets += 6*width
-        header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
-        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
-        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
-        # create lut
-        lut = torch.cat((header, incs))
-        lut = lut.type(torch.int32).to(device)
-        # create locks
-        num_locks = max(1, lockid.max())
-        return lut, num_locks, width, None

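
The LUT builder shown above stores pointer *increments* rather than absolute offsets: a column of absolute block offsets is delta-encoded so the kernel only has to do `p += inc` each iteration. A minimal illustration of that encoding (standalone PyTorch, not the LUT layout itself):

import torch

idx = torch.tensor([0, 3, 4, 9])   # absolute block offsets along the reduction
incs = idx.clone()
incs[1:] -= idx[:-1]               # delta-encode: first entry absolute, rest relative
assert incs.tolist() == [0, 3, 1, 5]
assert torch.cumsum(incs, 0).tolist() == idx.tolist()  # kernel-side `p += inc` recovers idx
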
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
    if trans_c:
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
    # shapes / dtypes
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    device = a.device
    is_16_multiple = AS3 % 16 == 0
    is_32_multiple = AS3 % 32 == 0
    is_64_multiple = AS3 % 64 == 0
    if not is_16_multiple:
        raise ValueError('Reduction size for SDD must be a multiple of 16')
    # create kernel
    total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
    c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
    for lut, width, pack in zip(luts, widths, packs):
        num_lock = 1
        key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
        if key not in _matmul.sdd_cache:
            defines = {
                'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block,
                'TK': 32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
                'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
                'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
            }
            _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
        kernel = _matmul.sdd_cache[key]
        # create output
        locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
        # maximum grid size is 65535, so the operation might be
        # decomposed into multiple kernel calls
        max_width = 49152
        for off_width in range(0, width, max_width):
            kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
                   b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
                   lut.data_ptr(), locks.data_ptr(), num_lock,
                   grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
    # save for backward pass
    return c

##########################
# DENSE = DENSE x SPARSE #
# DENSE = SPARSE x DENSE #
##########################

# Given a binary layout of 0s and 1s,
# construct a look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
    # load-balancing
    _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
    segments = _empty.clone()
    column = _empty.clone()
    depth = _empty.clone()
    lockid = _empty.clone()
    maxid = _empty.clone()
    offsets = _empty.clone()
    current_offset = 0
    current_maxid = 0
    for z in range(layout.size(0)):
        if trans:
            sizes = torch.sum(layout[z, :, :], 1)
        else:
            sizes = torch.sum(layout[z, :, :], 0)
        z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
        z_depth = z * torch.ones_like(z_segments)
        z_lockid[z_lockid > 0] += current_maxid
        current_maxid = z_lockid.max()
        # concatenate depth
        segments = torch.cat((segments, z_segments))
        column = torch.cat((column, z_column))
        depth = torch.cat((depth, z_depth))
        maxid = torch.cat((maxid, z_maxid))
        offsets = torch.cat((offsets, current_offset + z_offsets))
        lockid = torch.cat((lockid, z_lockid))
        current_offset += layout[z, :, :].sum()
    segments *= step
    # pointer increments
    if trans:
        nnz = layout.nonzero(as_tuple=False)
    else:
        nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
    num_blocks = nnz.size(0)
    offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
    idx = transform(nnz[:, 2] * block)
    xincs = idx.clone()
    xincs[1:] -= idx[:-1]
    # divide block into multiple steps
    div = block // step
    xincs = xincs.view(-1, 1).repeat(1, div)
    xincs[:, 1:] = step
    xincs[:, 0] -= (div - 1) * step
    # first increment for each reduction is actually the offset
    xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
    xincs = xincs.view(-1)
    # block-mode input increments
    if trans:
        widx = torch.arange(num_blocks)
    else:
        widx = _empty.clone()
        current_offset = 0
        for z in range(layout.size(0)):
            layoutw = layout[z, :, :].clone()
            msum = layoutw.sum()
            layoutw[layoutw > 0] = 1 + torch.arange(msum)
            widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
            current_offset += msum
    wincs = widx * block * block
    wincs[1:] -= widx[:-1] * block * block
    wincs = wincs.view(-1, 1).repeat(1, div)
    if trans:
        wincs[:, 1:] = step
        wincs[:, 0] -= (div - 1) * step
    else:
        wincs[:, 1:] = step * block
        wincs[:, 0] -= (div - 1) * step * block
    wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
    wincs = wincs.view(-1)
    # adjust offset and segment size
    offsets *= 2 * div
    segments *= div
    # create header
    width = column.size(0)
    offsets += 6 * width
    header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
    incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
    incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
    # create lut
    lut = torch.cat((header, incs))
    lut = lut.type(torch.int32).to(device)
    # create locks
    num_locks = max(1, lockid.max())
    return lut, num_locks, width, None
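The xincs/wincs construction delta-encodes pointer positions so the kernel advances with cheap additions rather than recomputing absolute offsets each step. A worked example with invented numbers:

import torch

# absolute column offsets of three nonzero blocks (invented numbers)
idx = torch.tensor([0, 64, 96])
xincs = idx.clone()
xincs[1:] -= idx[:-1]
print(xincs)  # tensor([ 0, 64, 32]) -- the deltas the kernel adds to its pointer
# the view/repeat step then splits each block into block // step sub-steps,
# each advancing by `step`, with column 0 patched to jump between blocks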
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
    # shapes / dtypes
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = spdims[0]
    BS1 = block * spdims[2 if trans_b else 1]
    BS2 = block * spdims[1 if trans_b else 2]
    dtype = a.dtype
    # kernel
    key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _matmul.dds_cache:
        defines = {
            'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else 1,
            'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK': 1 if trans_b else block,
            'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dds_kernel', 'DDS': True
        }
        _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
    kernel = _matmul.dds_cache[key]
    # output
    CS0 = AS0
    CS1 = AS1
    CS2 = BS2 if trans_c else AS2
    CS3 = AS2 if trans_c else BS2
    locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
           c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
           num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
    return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
    # shapes / dtypes
    AS0 = spdims[0]
    AS1 = block * spdims[2 if trans_a else 1]
    AS2 = block * spdims[1 if trans_a else 2]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # kernel
    key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _matmul.dsd_cache:
        defines = {
            'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else block, 'STRIDE_AK': block if trans_a else 1,
            'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb',
            'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dsd_kernel', 'DSD': True
        }
        _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
    kernel = _matmul.dsd_cache[key]
    # output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
           c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
           num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
    return c
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs,
            da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
    c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
    # save for backward
    ctx.save_for_backward(a, b)
    ctx.da_num_locks = da_num_locks
    ctx.da_lut = da_lut
    ctx.da_width = da_width
    ctx.da_packs = da_packs
    ctx.db_lut = db_lut
    ctx.db_num_locks = db_num_locks
    ctx.db_width = db_width
    ctx.db_packs = db_packs
    ctx.mode = mode
    ctx.spdims = spdims
    ctx.block = block
    ctx.trans_a = trans_a
    ctx.trans_b = trans_b
    return c
@staticmethod
def backward(ctx, dc):
    # saved for backward
    a, b = ctx.saved_tensors
    da, db = None, None  # default gradients when an input does not require grad
    mode = ctx.mode
    # gradients w.r.t. a
    if ctx.needs_input_grad[0]:
        mode_da = mode[1] + mode[0] + mode[2]
        da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
                                 ctx.da_num_locks, ctx.da_width, ctx.da_packs)
    # gradients w.r.t. b
    if ctx.needs_input_grad[1]:
        mode_db = mode[2] + mode[1] + mode[0]
        db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
                                 ctx.db_num_locks, ctx.db_width, ctx.db_packs)
    return da, db, None, None, None,\
           None, None, None, None,\
           None, None, None, None, None, None,\
           None, None, None, None, None, None,\
           None, None, None, None, None, None
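backward() reuses the three forward kernels by permuting the mode string: the letters name the formats of (output, first input, second input), so dA = dC @ B^T, whose output has A's format and whose inputs are dC and B, is mode[1] + mode[0] + mode[2], and symmetrically for dB. A quick, purely illustrative check:

for mode in ['sdd', 'dsd', 'dds']:
    mode_da = mode[1] + mode[0] + mode[2]
    mode_db = mode[2] + mode[1] + mode[0]
    print(mode, '->', mode_da, mode_db)
# sdd -> dsd dds
# dsd -> sdd dsd
# dds -> dds sdd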
class matmul:
    def make_lut(self, dtype, device):
        key = (dtype, device)
        if key in self.lut_cache:
            return self.lut_cache[key]
        # C look-up table
        layout, block = self.layout, self.block
        step = 8 if dtype == torch.float32 else 16
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
        return self.lut_cache[key]

    def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        # look-up table cache
        self.lut_cache = dict()
        # attributes
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.mode = mode
        self.spdims = layout.shape
        self.block = block
        self.layout = layout

    # pad shapes of a tensor to make it
    # compatible with kernel calls
    @staticmethod
    def _pad_shape(x, is_sparse):
        max_dim = 3 if is_sparse else 4
        for i in range(max_dim - x.dim()):
            x = x.unsqueeze(0)
        return x

    def __call__(self, a, b):
        c_lut, c_num_locks, c_width, c_packs,\
        da_lut, da_num_locks, da_width, da_packs,\
        db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # pad shapes with ones
        a = matmul._pad_shape(a, self.mode == 'dsd')
        b = matmul._pad_shape(b, self.mode == 'dds')
        # execute
        c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
                          c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut,
                          db_num_locks, db_width, db_packs)
        return c
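End to end, the wrapper is used roughly as follows. The shapes below are our own illustration, not taken from this diff; consult the project's tests for the authoritative conventions:

import torch

# 'sdd' mode: block-sparse C = dense A x dense B
layout = torch.randint(0, 2, (1, 4, 4))     # (heads, M//block, N//block) mask
dot = matmul(layout, block=16, mode='sdd')
a = torch.randn((1, 1, 64, 32), dtype=torch.float16, device='cuda')
b = torch.randn((1, 1, 32, 64), dtype=torch.float16, device='cuda')
c = dot(a, b)                               # one (16, 16) tile per nonzero block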
@@ -1,9 +1,9 @@
#define STM 8
#define STN 8

__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
                       TYPE *B __noalias __readonly __aligned(16),
                       TYPE *C __noalias __aligned(16),
                       float alpha,
                       int M,
                       int N,
@@ -11,87 +11,88 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
                       int lda __multipleof(LDA_POW2_DIV),
                       int ldb __multipleof(LDB_POW2_DIV),
                       int ldc __multipleof(LDC_POW2_DIV),
                       int *locks) {
  // prologue
  int pid = get_program_id(0);
  int pidz = get_program_id(2);
  int gridm = (M + TM - 1) / TM;
  int gridn = (N + TN - 1) / TN;

  // swizzle for better L2 performance
  int width = STM * gridn;
  int stm = pid / width;
  int RSTM = min(gridm - stm * STM, STM);
  int stn = (pid % width) / (RSTM * STN);
  int RSTN = min(gridn - stn * STN, STN);
  int laneid = pid % (RSTM * RSTN);
  int lanem = laneid / RSTN;
  int lanen = laneid % RSTN;
  int pidm = stm * STM + lanem;
  int pidn = stn * STN + lanen;
  int rm[TM] = pidm * TM + 0 ... TM;
  int rn[TN] = pidn * TN + 0 ... TN;

  // split-k for better parallelism
  K = K / SPLITK;
  int rk[TK] = 0 ... TK;
  // pointers to operands
  int offa[TM, TK] = (pidz * K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
  int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
  TYPE *pa[TM, TK] = A + offa;
  TYPE *pb[TK, TN] = B + offb;

  // prefetches operands
  bool checka[TM, TK] = rk[newaxis, :] < K;
  bool checkb[TK, TN] = rk[:, newaxis] < K;
  TYPE a[TM, TK] = checka ? *pa : 0;
  TYPE b[TK, TN] = checkb ? *pb : 0;
  pa += TK * STRIDE_AK;
  pb += TK * STRIDE_BK;

  // reduction loop
  float acc[TM, TN] = 0;
  for (int k = K; k > 0; k -= TK) {
#if (IS_TK_DIV_K == 1)
    bool checkk[TK] = k > TK;
#else
    bool checkk[TK] = rk < k - TK;
#endif
    bool checka[TM, TK] = checkk[newaxis, :];
    bool checkb[TK, TN] = checkk[:, newaxis];
    acc += a @ b;
#if (IS_TK_DIV_K == 1)
    a = *?(checka)pa;
    b = *?(checkb)pb;
#else
    a = checka ? *pa : 0;
    b = checkb ? *pb : 0;
#endif
    pa += TK * STRIDE_AK;
    pb += TK * STRIDE_BK;
  }
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;

  // epilogue
  int rcm[TM] = pidm * TM + 0 ... TM;
  int rcn[TN] = pidn * TN + 0 ... TN;
  int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
  TYPE *pc[TM, TN] = C + offc;
  bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
#if (SPLITK == 1)
  *?(checkc)pc = c;
#else
  // accumulate partial result using spin-locks
  int *plock = locks + pid;
  int *pcount = plock + get_num_programs(0);
  for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
  int count = *pcount;
  if (count == 0)
    *?(checkc)pc = c;
  else
    *?(checkc)pc = c + *?(checkc)pc;
  atomic_xchg(pcount, (count + 1) % SPLITK);
  atomic_xchg(plock, 0);
#endif
}
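To make the swizzle concrete, here is a hypothetical Python transcription of the prologue arithmetic above; it remaps the linear program id onto STM x STN super-tiles of the C grid so that concurrently scheduled program instances reuse the same A/B stripes in L2:

def swizzle(pid, gridm, gridn, STM=8, STN=8):
    # mirrors the integer arithmetic in the kernel prologue
    width = STM * gridn
    stm = pid // width
    RSTM = min(gridm - stm * STM, STM)
    stn = (pid % width) // (RSTM * STN)
    RSTN = min(gridn - stn * STN, STN)
    laneid = pid % (RSTM * RSTN)
    pidm = stm * STM + laneid // RSTN
    pidn = stn * STN + laneid % RSTN
    return pidm, pidn

# e.g. with a 16x16 grid of output tiles, pids 0..63 map onto the top-left
# 8x8 super-tile instead of a single 16-tile-wide row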
@@ -3,29 +3,32 @@ import triton
import os

class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))

    _DEFAULT_CONFIGS = [
        ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "128", "TN": "64", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "64", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "32", "TN": "128", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "128", "TN": "32", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "32", "TK": "64", "SPLITK": "1"}, 2),
        ({"TM": "32", "TN": "64", "TK": "64", "SPLITK": "1"}, 2),
        # ({"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, 4),
        # ({"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, 4),
        # ({"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, 4),
        # ({"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, 4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS

    @staticmethod
    def largest_pow2_divisor(N):
        if N % 8 == 0:
            return 8
        if N % 4 == 0:
            return 4
        if N % 2 == 0:
            return 2
        return 1

    _locks = dict()

@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
        K, N = b.shape
        c = torch.empty((M, N), dtype=dtype, device=device)
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1

@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 64 == 0
        key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
        if key not in _matmul._kernels:
            defines = {
                "TYPE": dtype,
                "STRIDE_AM": "lda" if is_a_row else "1",
                "STRIDE_AK": "1" if is_a_row else "lda",
                "STRIDE_BK": "ldb" if is_b_row else "1",
                "STRIDE_BN": "1" if is_b_row else "ldb",
                "LDA_POW2_DIV": lda_pow2_div,
                "LDB_POW2_DIV": ldb_pow2_div,
                "LDC_POW2_DIV": ldc_pow2_div,
                "IS_TK_DIV_K": int(is_tk_div_k),
            }
            _matmul._kernels[key] = triton.kernel(
                _matmul.src,
                device,
                defines=defines,
                autotune_vals=_matmul._CONFIGS,
                autotune_key=["M", "N", "K"],
            )
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.0
        args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
        grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.SPLITK]
        kernel(*args, grid=grid)
        return c
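Two details worth unpacking. The cache key folds in largest_pow2_divisor(ld*) because the __multipleof hints in the kernel let the compiler assume pointer alignment and vectorize loads; and grid dimension 2 launches opt.SPLITK program instances per C tile, each reducing K / SPLITK elements and merging results through the spin-lock epilogue. A quick check of the divisor helper:

assert _matmul.largest_pow2_divisor(24) == 8   # 24 = 8 * 3
assert _matmul.largest_pow2_divisor(20) == 4
assert _matmul.largest_pow2_divisor(6) == 2
assert _matmul.largest_pow2_divisor(7) == 1
# note the result is capped at 8 by construction (e.g. 16 also returns 8)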
@@ -1,21 +1,33 @@
import torch

def sparsify_tensor(x, mask, block):
    ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret

def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
    return ret

def allclose(x, y):
    assert x.dtype == y.dtype
    diff = abs(x - y)
    x_max = torch.max(x)
    y_max = torch.max(y)
    tol = 1e-2
    err = torch.max(diff) / torch.max(x_max, y_max)
    return err < tol
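Note the changed tolerance check: instead of elementwise rtol/atol via torch.allclose, the helper now compares the largest absolute difference against the largest magnitude of either input. A quick illustration with invented values:

import torch
x = torch.tensor([100.0, 0.0])
y = torch.tensor([100.5, 0.4])
err = torch.max(abs(x - y)) / torch.max(torch.max(x), torch.max(y))
print(err)  # 0.5 / 100.5 ~= 0.005 < 1e-2, so allclose(x, y) is True
# the old torch.allclose(x, y, rtol=1e-4, atol=1e-5) would be False here:
# the second element differs by 0.4 against a near-zero reference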
def do_bench(fn, flops=0, warmup=10, rep=50):
    start_event = torch.cuda.Event(enable_timing=True)

@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms

class Benchmark:
    def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
        self.x_names = x_names
        self.x_vals = x_vals
        self.y_name = y_name

@@ -44,6 +59,7 @@ class Benchmark:
        self.plot_name = plot_name
        self.args = args

class Mark:
    def __init__(self, fn, benchmarks):
        self.fn = fn

@@ -53,26 +69,31 @@ class Mark:
        import matplotlib.pyplot as plt
        import pandas as pd
        import os

        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
            df.loc[len(df)] = [x] + row
        if with_plot and bench.plot_name:
            xlabel = " = ".join(bench.x_names)
            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
            plot.set_xlabel(xlabel)
            plot.set_ylabel(bench.ylabel)
            plot.set_title(bench.plot_name)
            plot.set_xscale("log" if bench.loglog else "linear")
            plot.set_yscale("log" if bench.loglog else "linear")
            plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

    def run(self, result_path, with_plot):
        for bench in self.benchmarks:
            self._run(bench, result_path, with_plot)

def perf_report(benchmarks):
    wrapper = lambda fn: Mark(fn, benchmarks)
    return wrapper
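For context, a sketch of how these helpers compose; the decorated function and its arguments are hypothetical (both providers are stubbed to torch.matmul here):

import torch

@perf_report([
    Benchmark(x_names=["M"], x_vals=[256, 512, 1024], y_name="provider",
              y_vals=["torch", "triton"], y_lines=["torch", "triton"],
              ylabel="TFLOPS", loglog=True, plot_name="matmul", args={})
])
def bench_matmul(M, provider):
    a = torch.randn((M, M), device="cuda", dtype=torch.float16)
    return do_bench(lambda: torch.matmul(a, a))

bench_matmul.run(result_path=".", with_plot=True)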
@@ -66,18 +66,19 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
  bool checkb[TK, TN] = rk[:, newaxis] < K;
  TYPE a[TM, TK] = checka ? *pa : 0;
  TYPE b[TK, TN] = checkb ? *pb : 0;
  // reduction loop
  float acc[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
    bool checka[TM, TK] = k > TK;
    bool checkb[TK, TN] = k > TK;
    pa += TK * STRIDE_AK;
    pb += TK * STRIDE_BK;
    TYPE anext[TM, TK] = *?(checka)pa;
    TYPE bnext[TK, TN] = *?(checkb)pb;
    acc += a @ b;
    a = anext;
    b = bnext;
    // __debug_barrier();
  }
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;
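The reordering above is classic double buffering: the loads for iteration k+1 are issued before the multiply for iteration k, so global-memory latency overlaps with compute, which is what the commit's improved asynchronous-copy scheduling exploits on A100. A minimal Python analogue of the schedule, purely illustrative:

def pipelined_reduction(tiles_a, tiles_b, matmul):
    acc = 0
    a, b = tiles_a[0], tiles_b[0]          # prologue: prefetch first tiles
    for k in range(len(tiles_a)):
        if k + 1 < len(tiles_a):           # issue next loads first...
            a_next, b_next = tiles_a[k + 1], tiles_b[k + 1]
        acc += matmul(a, b)                # ...then compute on current tiles
        if k + 1 < len(tiles_a):
            a, b = a_next, b_next
    return acc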
@@ -166,7 +167,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
  opt.defines["TYPE"] = ty;
  opt.defines["TM"] = "128";
  opt.defines["TN"] = "128";
  opt.defines["TK"] = "64";
  opt.defines["TZ"] = "1";
  opt.num_warps = 4;
  // arguments