[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copy, scheduling, and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks.
Philippe Tillet authored and committed on 2021-02-21 15:19:39 -08:00
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions
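For context on the changes below: Ampere's cp.async instructions follow a commit-group/wait-group protocol. Copies are issued asynchronously into shared memory, batched into groups with cp.async.commit_group, and cp.async.wait_group N blocks until at most N groups are still in flight. A minimal CUDA sketch of the double-buffered pattern this protocol enables (illustrative names and sizes, not code from this commit; compile for sm_80+):

#include <cuda_runtime.h>

// Sketch only: each of 128 threads copies one 16-byte float4 per iteration,
// alternating between two shared-memory buffers.
__global__ void double_buffered_copy(const float4* __restrict__ src, int iters) {
  __shared__ float4 buf[2][128];
  int tid = threadIdx.x;  // assumes blockDim.x == 128
  // prologue: issue the asynchronous copy for iteration 0 and commit it as a group
  unsigned s0 = (unsigned)__cvta_generic_to_shared(&buf[0][tid]);
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(s0), "l"(src + tid));
  asm volatile("cp.async.commit_group;");
  for (int i = 1; i <= iters; ++i) {
    if (i < iters) {
      // issue the copy for iteration i while iteration i-1 is being consumed
      unsigned s = (unsigned)__cvta_generic_to_shared(&buf[i % 2][tid]);
      asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(s), "l"(src + i * 128 + tid));
      asm volatile("cp.async.commit_group;");
      // wait until at most one group is still in flight: buf[(i-1)%2] is now ready
      asm volatile("cp.async.wait_group 1;");
    } else {
      asm volatile("cp.async.wait_group 0;");  // drain everything before the last use
    }
    __syncthreads();
    // ... consume buf[(i - 1) % 2] here ...
  }
}

The diff below replaces the old cp.async.wait_all codegen with this parameterized wait_group form and teaches the barrier-placement pass to compute the wait depth N.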

View File

@@ -2,6 +2,9 @@
 #define TDL_INCLUDE_CODEGEN_BARRIERS_H
 #include <vector>
+#include <map>
+#include <list>
+#include <set>
 namespace triton {
@@ -9,6 +12,7 @@ namespace ir {
 class module;
 class basic_block;
 class instruction;
+class masked_load_async_inst;
 class value;
 class builder;
 }
@@ -29,18 +33,15 @@ namespace transform{
 class membar {
 private:
   typedef std::pair<unsigned, unsigned> interval_t;
-  typedef std::vector<interval_t> interval_vec_t;
+  typedef std::set<ir::value*> val_set_t;
+  typedef std::vector<ir::value*> val_vec_t;
 private:
-  interval_vec_t join(const std::vector<interval_vec_t>& intervals);
-  void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
-  bool intersect(const interval_vec_t &X, interval_t x);
-  bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
-  void add_reference(ir::value *v, interval_vec_t &res);
-  void get_read_intervals(ir::instruction *i, interval_vec_t &res);
-  void get_written_intervals(ir::instruction *i, interval_vec_t &res);
-  std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
-                                                     std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
+  bool intersect(const val_set_t &X, const val_set_t &Y);
+  int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
+  val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
+  void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
+                std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
 public:
   membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -16,6 +16,10 @@ namespace ir {
 }
 namespace codegen{
+namespace analysis{
+class layouts;
+}
 namespace transform{
 class peephole {
@@ -33,11 +37,12 @@ private:
 private:
 public:
-  peephole(target* tgt): tgt_(tgt) {}
+  peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
   void run(ir::module &mod);
 private:
   target* tgt_;
+  analysis::layouts* layouts_;
 };

View File

@@ -0,0 +1,28 @@
+#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
+#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
+
+// forward declaration
+namespace triton {
+namespace ir {
+class module;
+}
+} // namespace triton
+
+namespace triton {
+namespace codegen {
+namespace transform {
+
+class pipeline {
+public:
+  pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
+  void run(ir::module &module);
+
+private:
+  bool has_copy_async_;
+};
+
+} // namespace transform
+} // namespace codegen
+} // namespace triton
+
+#endif

View File

@@ -29,7 +29,7 @@ public:
   static driver::stream* create(backend_t backend);
   // methods
   virtual void synchronize() = 0;
-  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
+  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
   virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
   virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
   // template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
 public:
   host_stream();
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
@@ -55,7 +55,7 @@ public:
   cu_stream(CUstream str, bool take_ownership);
   cu_stream();
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };

View File

@@ -35,6 +35,7 @@ public:
   basic_block* get_insert_block() { return block_; }
   iterator get_insert_point() { return insert_point_;}
   // Constants
+  value *get_int1(bool val);
   value *get_int32(int32_t val);
   value *get_int64(int64_t val);
   // Types
@@ -149,7 +150,7 @@ public:
   value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
   value *create_copy_from_shared(value *arg, const std::string &name = "");
   value *create_barrier(const std::string &name = "");
-  value *create_async_wait();
+  value *create_async_wait(int N);
 private:
   context &ctx_;

View File

@@ -92,6 +92,7 @@ private:
 public:
   void set_incoming_value(unsigned i, value *v);
   void set_incoming_block(unsigned i, basic_block *block);
+  value *get_value_for_block(basic_block *block);
   value *get_incoming_value(unsigned i) { return get_operand(i); }
   basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
   unsigned get_num_incoming() { return get_num_operands(); }
@@ -803,14 +804,18 @@ public:
 class async_wait_inst: public instruction{
 private:
-  async_wait_inst(context &ctx, const std::string &name, instruction *next);
-  std::string repr_impl() const { return "async_wait"; }
+  async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
   _TRITON_DEFINE_CLONE(async_wait_inst)
   _TRITON_DEFINE_ACCEPT(async_wait_inst)
 public:
-  static async_wait_inst* create(context &ctx, const std::string &name = "",
-                                 instruction *next = nullptr);
+  static async_wait_inst* create(context &ctx, int N,
+                                 const std::string &name = "", instruction *next = nullptr);
+  int get_N() { return N_; }
+
+private:
+  int N_;
 };

 // On NVIDIA, implementation is such that

View File

@@ -98,6 +98,8 @@ private:
   std::shared_ptr<ir::module> ir_;
   std::shared_ptr<driver::module> mod_;
   std::shared_ptr<driver::kernel> ker_;
+  // shared mem
+  size_t shared_mem_;
 };

 class function {

View File

@@ -30,11 +30,8 @@ private:
   high_resolution_clock::time_point _start;
 };

-inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
+inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
 {
-//  const driver::device * device = stream->context()->device();
-  size_t warmup = 10;
-  size_t repeat = 50;
   timer tmr;
   std::vector<size_t> times;
   double total_time = 0;

View File

@@ -312,7 +312,6 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
     if(rhs_cst_info[d].num_cst)
       rvalue = lhs_max_contiguous[d];
     result[d] = std::max(lvalue, rvalue);
-    // std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl;
   }
   return add_to_cache(x, result, max_contiguous_);
 }
@@ -527,8 +526,7 @@ void align::run(ir::module &mod) {
   ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
 //  ir::for_each_value(mod, [this](ir::value* v) {
 //    if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
-//      std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
-//                << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
+//      std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
 //  });
 }

View File

@@ -118,15 +118,6 @@ data_layout::data_layout(id_t id,
 //    std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
 //    std::cout << order_[0] << " " << order_[1] << std::endl;
   }
-  if(is_recoalesce){
-    if(ptr.size() > 0){
-//      std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
-//      std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
-//      if(order_[0] == 0)
-//        exit(1);
-    }
-  }
-//  std::cout << "---" << std::endl;
 }

 int data_layout::find_axis(int to_find) const {
@@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps,
   ir::value *ptr = nullptr;
   for(ir::value *v: values)
     for(ir::user *usr: v->get_users())
-      if(auto *st = dynamic_cast<ir::io_inst*>(usr))
-        ptr = st->get_pointer_operand();
+      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
+        if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
+          ptr = io->get_pointer_operand();
+      }

   unsigned i = order_[0];
   int contiguous = 1;
   if(ptr){
     int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
-    contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
+    contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
   }

   nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));

View File

@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
 }

 void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
-  unsigned vector = 1;
-  ir::value *ptrs = x->get_pointer_operand();
-  ir::value *msks = x->get_mask_operand();
+  unsigned in_vec = 1;
+  ir::value *arg = x->get_pointer_operand();
   analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
-  analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
+  analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
   auto out_order = out_layout->get_order();
   auto in_order = in_layout->get_order();
   // tiles
   if(out_order == in_order)
-    vector = in_layout->nts(in_order[0]);
+    in_vec = in_layout->nts(in_order[0]);
+  int out_vec = swizzle_->get_vec(out_layout);
+  int min_vec = std::min<int>(out_vec, in_vec);
+  int s = std::max<int>(out_vec / in_vec, 1);
   //
-  int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
-  int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
-  Value *max_phase = i32(8 / num_per_phase);
+  int per_phase = swizzle_->get_per_phase(out_layout);
+  int max_phase = swizzle_->get_max_phase(out_layout);
   //
+  int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
+  int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
+  int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
   auto shapes = x->get_type()->get_tile_shapes();
-  //
-  int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
-  int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
-  std::vector<Value*> shared;
-  for(size_t i = 0; i < n_shared; i++){
-    indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
-    // phase
-    Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
-    phase = urem(phase, max_phase);
-    // off
-    Value* off_0 = idx[in_order[0]];
-    off_0 = udiv(off_0, i32(vector));
-    off_0 = xor_(off_0, phase);
-    off_0 = mul(off_0 , i32(vector));
-    Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
-    Value* off = add(off_0, off_1);
-    //
-    shared.push_back(gep(shmems_[x], {off}));
-  }
-  //
-  for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
-    auto idx = idxs_[ptrs][i];
+  BasicBlock* CurrBB = builder_->GetInsertBlock();
+  BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
+  std::map<std::pair<int, int>, Value*> tmp;
+  std::vector<std::pair<Value*, int>> shared;
+  for(int i = 0; i < idxs_.at(arg).size(); i++){
+    unsigned id = i / min_vec;
     // input ptr info
-    GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
-    Value *in_base = in_gep->getPointerOperand();
-    size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
-    Value* out_base = shared[(i / per_thread_ld) % n_shared];
-    int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
-    int out_off_1 = i % per_thread_ld;
-    int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
-    // asm
-    FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
-    std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
-    std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
-    InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
-    call(iasm, {vals_[msks][idx], out_base, in_base});
+    int id_0 = id % (in_ld/min_vec);
+    int id_1 = id / (in_ld/min_vec);
+    int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
+    int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
+    int off = (off_1*shapes[in_order[0]] + off_0);
+    std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
+    if(tmp.find(key) == tmp.end()){
+      if(CurrBB != FirstBB)
+        builder_->SetInsertPoint(FirstBB->getTerminator());
+      indices_t idx = idxs_.at(arg).at(key.first*in_ld);
+      Value* phase = udiv(idx[in_order[1]], i32(per_phase));
+      phase = urem(phase, i32(max_phase));
+      Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
+      Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
+      off_0 = udiv(off_0, i32(min_vec));
+      off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
+      off_0 = mul(off_0 , i32(min_vec));
+      Value* off = add(off_0, off_1);
+      if(CurrBB != FirstBB)
+        builder_->SetInsertPoint(CurrBB);
+      tmp[key] = gep(shmems_[x], {off});
+    }
+    shared.push_back({tmp[key], off});
   }
+  for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
+    auto idx = idxs_[arg][i];
+    // input ptr info
+    GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
+    Value *in_base = in_gep->getPointerOperand();
+    ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
+    size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
+    in_base = cst ? in_base : in_gep;
+    // output ptr info
+    Value* out_base = shared[i].first;
+    int out_off = shared[i].second*2;
+    // asm
+    FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
+    std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
+    std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
+    InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
+    call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
+  }
+  std::string asm_str = "cp.async.commit_group;";
+  InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
+  call(iasm);
 }

 void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
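The phase/xor arithmetic built through the IR builder above is easier to read as scalar code. A standalone sketch of the same swizzled shared-memory offset computation (illustrative variable names; the formula mirrors the generated IR):

#include <cstdio>

// phase  = (row / per_phase) % max_phase
// vec    = col / min_vec
// vec'   = ((vec / s) ^ phase) * s + vec % s
// offset = row * ld + vec' * min_vec
int swizzled_offset(int row, int col, int ld,
                    int per_phase, int max_phase, int min_vec, int s) {
  int phase = (row / per_phase) % max_phase;
  int vec = col / min_vec;                  // which vector within the row
  vec = ((vec / s) ^ phase) * s + vec % s;  // xor-permute vectors, row-dependent
  return row * ld + vec * min_vec;          // back to an element offset
}

int main() {
  // with per_phase=1, max_phase=8, min_vec=8, s=1: each row permutes its
  // 8-element vectors differently, so column 0 of successive rows lands in
  // different shared-memory banks (no conflicts on column accesses)
  for (int r = 0; r < 8; ++r)
    printf("row %d: col 0 -> element %d\n", r, swizzled_offset(r, 0, 64, 1, 8, 8, 1) - r * 64);
}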
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
   auto shapes = cts->get_type()->get_tile_shapes();

-  // default implementation
+  // store to shared
   Value *current = nullptr;
   std::map<std::pair<int, int>, Value*> ptrs;
   for(int i = 0; i < idxs_.at(arg).size(); i++){
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
   add_barrier();
 }

-void generator::visit_async_wait_inst(ir::async_wait_inst*) {
-  std::string asm_str = "cp.async.wait_all;";
+void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
+  std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
   InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
   call(iasm);
-  add_barrier();
 }

 void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
   if(unsigned alloc_size = alloc_->allocated_size()){
     Type *int_8_ty = Type::getInt8Ty(*ctx_);
     Type *int_32_ty = Type::getInt32Ty(*ctx_);
-    ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
+    ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
     Type *ptr_ty = ptr_ty(int_8_ty, 3);
     GlobalVariable *sh_mem_array =
-      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
+      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
                          nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
     shmem_ = bit_cast(sh_mem_array, ptr_ty);
   }

View File

@@ -15,114 +15,105 @@ namespace triton {
 namespace codegen{
 namespace transform{

-bool membar::intersect(const interval_vec_t &X, interval_t x) {
-  return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
-    bool left_intersect = y.first <= x.first && x.first < y.second;
-    bool right_intersect = y.first <= x.second && x.second < y.second;
-    return left_intersect || right_intersect;
-  });
-}
-
-bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
-  return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
-    return intersect(X, y);
-  });
-}
-
-void membar::add_reference(ir::value *v, interval_vec_t &res){
-  auto *i = dynamic_cast<ir::instruction*>(v);
-  if(!i)
-    return;
-  if(!i->get_type()->is_tile_ty())
-    return;
-  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
-  if(!layout)
-    return;
-  if(alloc_->has_offset(layout)){
-    unsigned offset = alloc_->offset(layout);
-    res.push_back(interval_t(offset, offset + layout->get_size()));
+int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
+    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
+    analysis::double_buffer_info_t* info = layout->get_double_buffer();
+    if(info)
+      return group_of(info->first, async_write);
+    std::vector<int> groups(phi->get_num_operands());
+    std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+    return *std::max_element(groups.begin(), groups.end());
+  }
+  else{
+    auto it = std::find(async_write.begin(), async_write.end(), v);
+    return std::distance(async_write.begin(), it);
   }
 }

-void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
-  for(ir::value *op: i->ops())
-    add_reference(op, res);
-}
-
-void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
-  if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
-    add_reference(i, res);
-}
-
-void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
-  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
-    std::set<ir::value*> incoming;
-    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
-      ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
-      assert(inc_val);
-      if(incoming.insert(inc_val).second){
-        ir::basic_block *block = inc_val->get_parent();
-        builder.set_insert_point(block->get_inst_list().back());
-        if(type.first)
-          builder.create_async_wait();
-        if(type.second)
-          builder.create_barrier();
-      }
+membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
+  val_set_t ret;
+  for(ir::value* a: as){
+    if(!a->get_type()->is_tile_ty())
+      continue;
+    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
+    if(!a_layout)
+      continue;
+    int a_start = alloc_->offset(a_layout);
+    int a_end = a_start + a_layout->get_size();
+    for(ir::value* b: bs){
+      if(!b->get_type()->is_tile_ty())
+        continue;
+      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
+      if(!b_layout)
+        continue;
+      int b_start = alloc_->offset(b_layout);
+      int b_end = b_start + b_layout->get_size();
+      if(a_start < b_end || b_start < a_end)
+        ret.insert(b);
     }
   }
-  else {
-    builder.set_insert_point(instr);
-    builder.create_barrier();
-  }
+  return ret;
 }

-membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
-  membar::interval_vec_t result;
-  for(auto x: intervals)
-    for(interval_t i: x)
-      result.push_back(i);
-  return result;
-}
-
-std::pair<membar::interval_vec_t,
-          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
-                                                   const interval_vec_t &written_to,
-                                                   const interval_vec_t &read_from,
-                                                   std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
-                                                   std::set<ir::value*>& safe_war,
-                                                   std::vector<ir::instruction*>& to_sync) {
+void membar::transfer(ir::basic_block *block,
+                      val_vec_t& async_write,
+                      val_set_t& sync_write,
+                      val_set_t& sync_read,
+                      std::set<ir::value*>& safe_war,
+                      bool& inserted, ir::builder& builder) {
   ir::basic_block::inst_list_t instructions = block->get_inst_list();
-  interval_vec_t new_written_to = written_to;
-  interval_vec_t new_read_from = read_from;
-
   for(ir::instruction *i: instructions){
-    interval_vec_t read, written;
-    get_read_intervals(i, read);
-    get_written_intervals(i, written);
-    if(written.size())
-      to_sync.push_back(i);
-    bool read_after_write = intersect(new_written_to, read);
-    bool write_after_read = intersect(new_read_from, written);
-    // double buffering
-    if(safe_war.find(i) != safe_war.end()){
-      write_after_read = false;
-      read_after_write = false;
+    if(dynamic_cast<ir::phi_node*>(i))
+      continue;
+    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
+       dynamic_cast<ir::masked_load_async_inst*>(i)){
+      async_write.push_back(i);
     }
-    // record hazards
-    if(read_after_write || write_after_read) {
-      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
-      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
-      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
-      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
-      insert_loc.insert({i, {copy_async_wait, barrier}});
-      new_written_to.clear();
-      new_read_from.clear();
-      to_sync.clear();
+    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
+      sync_write.insert(i);
+    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
+    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
+    // Get shared memory reads
+    std::set<ir::value*> read;
+    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
+                 [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
+    // RAW (async)
+    val_set_t tmp;
+    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
+    if(intersect_with(read, tmp).size()){
+      std::vector<int> groups(read.size());
+      std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+      int N = *std::max_element(groups.begin(), groups.end());
+      if(N < async_write.size()){
+        builder.set_insert_point(i);
+        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
+        barrier = (ir::barrier_inst*)builder.create_barrier();
+        inserted = true;
+      }
     }
-    std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
-    std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
+    // RAW, WAR
+    if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
+      builder.set_insert_point(i);
+      barrier = (ir::barrier_inst*)builder.create_barrier();
+      inserted = true;
+    }
+    // update state of asynchronous copies
+    if(async_wait){
+      int N = async_write.size() - async_wait->get_N();
+      async_write.erase(async_write.begin(), async_write.begin() + N);
+    }
+    // all the copy_to_shared and read from shared are synchronized after barrier
+    if(barrier){
+      sync_write.clear();
+      sync_read.clear();
+    }
+    sync_read.insert(read.begin(), read.end());
   }
-  return std::make_pair(new_written_to, new_read_from);
 }

 void membar::run(ir::module &mod) {
@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {
   for(ir::function *fn: mod.get_function_list()){
     std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
-    std::map<ir::basic_block*, interval_vec_t> written_to;
-    std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::vector<ir::instruction*> to_sync;
-    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
-    size_t n_inserted_im1 = 0;
-    bool done = false;
+    std::map<ir::basic_block*, val_vec_t> async_writes;
+    std::map<ir::basic_block*, val_set_t> sync_writes;
+    std::map<ir::basic_block*, val_set_t> sync_reads;
+    std::list<ir::value *> pipelined;
+    bool inserted;
     do{
+      inserted = false;
       // find barrier location
       for(ir::basic_block *block: rpo){
-        // written to
-        std::vector<interval_vec_t> pred_written_to;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_written_to.push_back(written_to[pred]);
-        // read from
-        std::vector<interval_vec_t> pred_read_from;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_read_from.push_back(read_from[pred]);
-        // apply transfer function
-        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
-        written_to[block] = result.first;
-        read_from[block] = result.second;
+        // join inputs
+        val_vec_t async_write;
+        val_set_t sync_write;
+        val_set_t sync_read;
+        val_set_t tmp;
+        for(ir::basic_block* pred: block->get_predecessors()){
+          for(ir::value* v: async_writes[pred])
+            if(tmp.insert(v).second)
+              async_write.push_back(v);
+          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
+          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
+        }
+        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
+        async_writes[block] = async_write;
+        sync_writes[block] = sync_write;
+        sync_reads[block] = sync_read;
       }
-      size_t n_inserted_i = insert_locs.size();
-      done = (n_inserted_im1 == n_inserted_i);
-      n_inserted_im1 = n_inserted_i;
-    }while(!done);
-    for(auto x: insert_locs){
-      insert_barrier(x.first, x.second, builder);
-    }
+    }while(inserted);
   }
 }
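The pass now tracks outstanding cp.async groups as an ordered list: when a shared-memory read depends on the group at index N (oldest first), it emits async_wait(size - 1 - N), allowing everything issued after N to stay in flight; on reaching an async_wait(N), the tracked list is pruned to its last N entries. A sketch of that bookkeeping with ints standing in for ir::value* (illustrative, not the pass itself):

#include <cassert>
#include <vector>

int main() {
  // outstanding cp.async groups, oldest first
  std::vector<int> async_write = {0, 1, 2, 3, 4};
  int n = 1;  // a shared-memory read depends on group #1
  // emit async_wait(size - 1 - n): at most 3 newer groups may stay in flight,
  // which forces groups 0 and 1 to complete before the read executes
  int wait_n = (int)async_write.size() - 1 - n;
  assert(wait_n == 3);
  // after the wait, prune the completed prefix from the tracked state
  async_write.erase(async_write.begin(),
                    async_write.begin() + ((int)async_write.size() - wait_n));
  assert(async_write.size() == 3);  // groups 2, 3, 4 remain
  return 0;
}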

View File

@@ -1,7 +1,9 @@
 #include <algorithm>
+#include <iostream>
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/codegen/transform/peephole.h"
+#include "triton/codegen/analysis/layout.h"

 namespace triton {
 namespace codegen{
@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
     ir::value *ptr = ld->get_pointer_operand();
     ir::value *msk = ld->get_mask_operand();
     ir::value *val = ld->get_false_value_operand();
-    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
-    copy_to_shared->replace_all_uses_with(new_load);
-    return true;
+    analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+    int nts = layout->nts(layout->get_order()[0]);
+    int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+    if(nts*dtsize >= 4){
+      ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+      copy_to_shared->replace_all_uses_with(new_load);
+      return true;
+    }
+    return false;
+//    analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+//    std::cout << layout->nts(layout->get_order(0)) << std::endl;
+//    return true;
   }
@@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) {
       bool was_modified = false;
       was_modified = was_modified || rewrite_mult(i, builder);
       // was_modified = was_modified || rewrite_cts_cfs(i, builder);
-      was_modified = was_modified || rewrite_trans_phi(i, builder);
+      // was_modified = was_modified || rewrite_trans_phi(i, builder);
       was_modified = was_modified || rewrite_unit_red(i, builder);
       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
-      // if(tgt_->as_nvidia()->sm() >= 80)
-      //   was_modified = was_modified || rewrite_load_to_shared(i, builder);
+      if(tgt_->as_nvidia()->sm() >= 80)
+        was_modified = was_modified || rewrite_load_to_shared(i, builder);
       if(was_modified)
         seen.insert(i);
     }
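The new nts*dtsize >= 4 guard reflects a hardware constraint: cp.async moves 4, 8, or 16 bytes per thread, so rewriting a load into an asynchronous copy only makes sense when each thread owns at least 4 contiguous bytes. A toy version of the check (hypothetical helper, same condition as the pass):

// nts: contiguous elements per thread; dtsize: element size in bytes
bool worth_rewriting_to_cp_async(int nts, int dtsize) {
  return nts * dtsize >= 4;  // e.g. 2 x fp16 = 4 bytes -> ok; 1 x fp16 -> no
}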

View File

@@ -0,0 +1,116 @@
+#include <iostream>
+#include <algorithm>
+#include "triton/codegen/transform/pipeline.h"
+#include "triton/ir/module.h"
+#include "triton/ir/function.h"
+#include "triton/ir/basic_block.h"
+#include "triton/ir/instructions.h"
+#include "triton/ir/utils.h"
+
+namespace triton {
+namespace codegen{
+namespace transform{
+
+void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
+  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
+  if(!i || i->get_parent() != block)
+    return;
+  if(i->get_id()==ir::INST_PHI)
+    return;
+  ret.push_back(i);
+  for(ir::user* u: i->get_users())
+    recursive_deps(u, block, ret);
+}
+
+void pipeline::run(ir::module &mod) {
+  // *Very* conservative heuristics for pre-fetching.
+  // A load instruction can be pipelined if:
+  //   - the pointer is a phi node that references a value
+  //     in its basic block (i.e., pointer induction variable)
+  //   - the load has only a single use in a dot instruction
+  // As more use cases become apparent, this pass will be improved
+  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
+  ir::for_each_instruction(mod, [&](ir::instruction *i){
+    if(auto* load = dynamic_cast<ir::load_inst*>(i)){
+      ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
+      auto users = load->get_users();
+      if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
+         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
+        to_pipeline.push_back({load, ptr});
+    }});
+  // do the pipelining
+  std::vector<ir::phi_node*> new_loads;
+  ir::builder &builder = mod.get_builder();
+  for(auto info: to_pipeline){
+    ir::load_inst* load = info.first;
+    ir::phi_node* ptr = info.second;
+    ir::basic_block* block = load->get_parent();
+    ir::basic_block* header = block->get_predecessors()[0];
+    auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
+    auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
+    assert(block_br);
+    assert(header_br);
+    ir::type* ty = load->get_type();
+    // pre-fetch first iteration
+    builder.set_insert_point(header->get_inst_list().back());
+    ir::value* first_ptr = ptr->get_value_for_block(header);
+    ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
+    ir::value* false_value;
+    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
+      first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
+      false_value = masked_load->get_false_value_operand();
+    }
+    else
+      false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
+    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
+    // pre-fetch next iteration
+    builder.set_insert_point(block->get_inst_list().back());
+    ir::value* next_ptr = ptr->get_value_for_block(block);
+    ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
+    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
+      next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
+    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
+    // phi node
+    builder.set_insert_point(block->get_first_non_phi());
+    ir::phi_node* new_load = builder.create_phi(ty, 2);
+    new_load->add_incoming(first_load, header);
+    new_load->add_incoming(next_load, block);
+    load->replace_all_uses_with(new_load);
+    new_loads.push_back(new_load);
+  }
+
+  // try to move dot_inst after loads
+  // for better overlap of io and compute
+  struct move_config_t{
+    std::vector<ir::instruction*> insts;
+    ir::load_inst* dst;
+  };
+  std::map<ir::basic_block*, move_config_t> to_move;
+  if(has_copy_async_){
+    for(ir::function* fn: mod.get_function_list())
+    for(ir::basic_block* bb: fn->blocks())
+    for(ir::instruction* inst: bb->get_inst_list()){
+      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
+        recursive_deps(i, bb, to_move[bb].insts);
+      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
+        to_move[bb].dst = i;
+    }
+
+    for(auto& x: to_move){
+      builder.set_insert_point_after(x.second.dst);
+      for(ir::instruction* i: x.second.insts){
+        x.first->erase(i);
+        builder.insert(i);
+      }
+    }
+  }
+}
+
+}
+}
+}
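A rough scalar analogy (not Triton IR) of what this pass does to the loop: the load feeding iteration i+1 is issued before the compute of iteration i, mirroring the inserted header pre-fetch and the rotating phi node:

#include <cstdio>
#include <vector>

float dot_pipelined(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.f;
  size_t n = a.size();
  if (n == 0) return acc;
  float cur_a = a[0], cur_b = b[0];  // header: pre-fetch first iteration
  for (size_t i = 0; i < n; ++i) {
    float next_a = 0.f, next_b = 0.f;
    if (i + 1 < n) { next_a = a[i + 1]; next_b = b[i + 1]; }  // pre-fetch next iteration
    acc += cur_a * cur_b;            // compute with the previously fetched values
    cur_a = next_a; cur_b = next_b;  // rotate, like the inserted phi node
  }
  return acc;
}

int main() {
  std::vector<float> a = {1, 2, 3}, b = {4, 5, 6};
  printf("%f\n", dot_pipelined(a, b));  // 32.0
}

On its own this only reorders scalar loads; combined with cp.async (see the sketch near the top), it is what lets global-memory traffic for the next tile overlap with the dot product on the current one.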

View File

@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
 inline bool is_cst(ir::value *x) {
   if(dynamic_cast<ir::constant*>(x))
     return true;
+  if(dynamic_cast<ir::make_range*>(x))
+    return true;
   if(auto *v = dynamic_cast<ir::retile_inst*>(x))
     return is_cst(v->get_operand(0));
   return false;

View File

@@ -70,7 +70,21 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
 cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
   dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
-//  dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  // properties
+  int shared_total, shared_optin, shared_static;
+  int n_spills, n_reg;
+  CUdevice dev;
+  dispatch::cuCtxGetDevice(&dev);
+  dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
+  dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
+  dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
+  if (shared_optin > 49152){
+//    std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
+    dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
+  }
 }
 }
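The driver-API calls above opt kernels into more than the default 48KB of dynamic shared memory, which large A100 tiles need. For reference, a sketch of the same opt-in through the CUDA runtime API (my_kernel is a placeholder):

#include <cuda_runtime.h>

__global__ void my_kernel() {}  // hypothetical kernel

void launch_with_big_smem(size_t shared_mem, dim3 grid, dim3 block) {
  int dev = 0, max_optin = 0;
  cudaGetDevice(&dev);
  // how far past 48KB this device allows blocks to opt in
  cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  if (shared_mem > 48 * 1024 && shared_mem <= (size_t)max_optin)
    cudaFuncSetAttribute((const void*)my_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shared_mem);
  my_kernel<<<grid, block, shared_mem>>>();  // launch with the raised limit
}

This pairs with the runtime changes below: the compiler now records the allocation size in shared_mem_ and passes it to stream::enqueue instead of rejecting kernels that exceed the static 48KB limit.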

View File

@@ -282,6 +282,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 void cu_module::init_from_ptx(const std::string& ptx) {
   // JIT compile source-code
+  // std::cout << ptx << std::endl;
   try{
     //  // compile ptx with ptxas

View File

@@ -76,7 +76,7 @@ void host_stream::synchronize() {
   hst_->args.clear();
 }

-void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
+void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t) {
   auto hst = kernel->module()->hst();
   hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
   char* params = new char[args_size];
@@ -113,13 +113,13 @@ void cu_stream::synchronize() {
   dispatch::cuStreamSynchronize(*cu_);
 }

-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem) {
   void *config[] = {
     CU_LAUNCH_PARAM_BUFFER_POINTER, args,
     CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
     CU_LAUNCH_PARAM_END
   };
-  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
+  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config);
 }

 void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {

View File

@@ -45,6 +45,9 @@ void builder::set_insert_point(basic_block *block){
 // convenience functions
 //===----------------------------------------------------------------------===//

+value *builder::get_int1(bool val)
+{ return constant_int::get(type::get_int1_ty(ctx_), val); }
+
 value *builder::get_int32(int32_t val)
 { return constant_int::get(type::get_int32_ty(ctx_), val);}

@@ -372,8 +375,8 @@ value *builder::create_barrier(const std::string &name) {
   return insert(barrier_inst::create(ctx_, name));
 }

-value *builder::create_async_wait() {
-  return insert(async_wait_inst::create(ctx_));
+value *builder::create_async_wait(int N) {
+  return insert(async_wait_inst::create(ctx_, N));
 }
} }

View File

@@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins
   blocks_.reserve(num_reserved);
 }

+value* phi_node::get_value_for_block(basic_block * block) {
+  auto it = std::find(blocks_.begin(), blocks_.end(), block);
+  size_t n = std::distance(blocks_.begin(), it);
+  return get_incoming_value(n);
+}
+
 // Set incoming value
 void phi_node::set_incoming_value(unsigned i, value *v){
   assert(v && "PHI node got a null value!");
@@ -818,12 +824,11 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
   return new barrier_inst(ctx, name, next);
 }

-async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
-                                 instruction *next)
-  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
+async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
+  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }

-async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
-  return new async_wait_inst(ctx, name, next);
+async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
+  return new async_wait_inst(ctx, N, name, next);
 }

View File

@@ -15,10 +15,10 @@
 #include "triton/codegen/transform/peephole.h"
 #include "triton/codegen/transform/membar.h"
 #include "triton/codegen/transform/reassociate.h"
-#include "triton/codegen/transform/reorder.h"
 #include "triton/codegen/transform/cts.h"
 #include "triton/codegen/transform/disassociate.h"
 #include "triton/codegen/selection/generator.h"
+#include "triton/codegen/transform/pipeline.h"
 #include "triton/runtime/function.h"
 #include "triton/lang/cpp.h"
 #include "triton/lang/parser.h"
@@ -149,6 +149,7 @@ void kernel::init_ker(){
   codegen::analysis::align align;
   codegen::analysis::axes axes;
   codegen::transform::cts cts(cts_use_async);
+  codegen::transform::pipeline pipeline(cts_use_async);
   codegen::transform::disassociate disassociate;
   codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
   codegen::analysis::liveness liveness(&layouts);
@@ -156,19 +157,24 @@ void kernel::init_ker(){
   codegen::analysis::allocation allocation(&liveness);
   codegen::transform::membar barriers(&liveness, &layouts, &allocation);
   codegen::transform::dce dce;
-  codegen::transform::peephole peephole(target.get());
+  codegen::transform::peephole peephole(target.get(), &layouts);
   codegen::transform::reassociate reassociate;
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
   // run passes
   dce.run(*ir_);
+  pipeline.run(*ir_);
+  dce.run(*ir_);
   disassociate.run(*ir_);
   dce.run(*ir_);
+  align.run(*ir_);
+  axes.run(*ir_);
+  layouts.run(*ir_);
   peephole.run(*ir_);
   dce.run(*ir_);
-  align.run(*ir_);
   if(target->is_gpu())
     cts.run(*ir_);
+  align.run(*ir_);
   axes.run(*ir_);
   layouts.run(*ir_);
   coalesce.run(*ir_);
@@ -179,6 +185,11 @@ void kernel::init_ker(){
     reassociate.run(*ir_);
     cts.run(*ir_);
   }
+  dce.run(*ir_);
+  // ir::print(*ir_, std::cout);
+  align.run(*ir_);
+  axes.run(*ir_);
+  layouts.run(*ir_);
   peephole.run(*ir_);
   dce.run(*ir_);
   align.run(*ir_);
@@ -187,8 +198,9 @@ void kernel::init_ker(){
   swizzle.run(*ir_);
   liveness.run(*ir_);
   allocation.run(*ir_);
-  if(allocation.allocated_size() > dev_->max_shared_memory())
-    throw exception::out_of_shared_memory();
+  shared_mem_ = allocation.allocated_size();
+//  if(allocation.allocated_size() > dev_->max_shared_memory())
+//    throw exception::out_of_shared_memory();
   barriers.run(*ir_);
   isel.visit(*ir_, *llvm);
   //if(res->spilled() > 256)
@@ -224,7 +236,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
+  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
 }

 std::string kernel::get_asm(asm_mode_t mode) {
@@ -348,7 +360,7 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
     while(grid.size() < 3)
       grid.push_back(1);
     double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
                              stream, 5, 20);
     ret = (ts < best_ts) ? current : ret;
     best_ts = std::min(ts, best_ts);
   }

View File

@@ -2,58 +2,74 @@ import triton
 import torch

 # square benchmarks
-nt = {False: 'n', True: 't'}
+nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
-        x_names = ['M', 'N', 'K'],
-        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'provider',
-        y_vals = ['torch', 'triton', 'cutlass'],
-        y_lines = ['Torch', 'Triton', 'CUTLASS'],
-        ylabel = 'TFLOPS',
-        loglog = False,
-        plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
-        args = {'AT': False, 'BT': False, 'dtype': torch.float16}
-    )\
-    for AT in [False, True] for BT in [False, True]
+        x_names=["M", "N", "K"],
+        x_vals=[512 * i for i in range(1, 16)],
+        y_name="provider",
+        y_vals=["torch", "triton", "cutlass"],
+        y_lines=["Torch", "Triton", "CUTLASS"],
+        ylabel="TFLOPS",
+        loglog=False,
+        plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
+        args={"AT": AT, "BT": BT, "dtype": torch.float16},
+    ) for AT in [False, True] for BT in [False, True]
 ]

 @triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
     import os
-    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
-    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
-    if AT: a = a.t()
-    if BT: b = b.t()
+
+    a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
+    b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
+    if AT:
+        a = a.t()
+    if BT:
+        b = b.t()
     num_flops = 2 * M * N * K
-    if provider == 'torch':
+    if provider == "torch":
         torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
         torch_tflops = num_flops / torch_ms * 1e-9
         return torch_tflops
-    if provider == 'triton':
+    if provider == "triton":
         triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
         triton_tflops = num_flops / triton_ms * 1e-9
         return triton_tflops
-    if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
+    if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
         import subprocess
         import tempfile
         import pandas as pd
+
         # run program specified by CUTLASS_PROFILER env variable
-        layout_a = 'column' if AT else 'row'
-        layout_b = 'column' if BT else 'row'
+        layout_a = "column" if AT else "row"
+        layout_b = "column" if BT else "row"
         # create temporary file name
         fd, fname = tempfile.mkstemp()
         # run program and gets its output
-        cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
-               '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
-               f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
+        cmd = [
+            os.environ["CUTLASS_PROFILER"],
+            f"--m={M}",
+            f"--n={N}",
+            f"--k={K}",
+            f"--A=f16:{layout_a}",
+            f"--B=f16:{layout_b}",
+            "--C=f16:column",
+            "--accum=f32",
+            "--operation=gemm",
+            "--verification-enabled=false",
+            f"--warmup-iterations={warmup}",
+            f"--profiling-iterations={rep}",
+            f"--output={fname}",
+            "--verbose=false",
+        ]
         # run cmd
         subprocess.run(cmd, stdout=subprocess.PIPE)
         # read CSV output
-        df_c = pd.read_csv(f'{fname}.gemm.csv')
-        cutlass_tflops = max(df_c['GFLOPs']) / 1e3
+        df_c = pd.read_csv(f"{fname}.gemm.csv")
+        cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
         return cutlass_tflops
     return None

-if __name__ == '__main__':
+if __name__ == "__main__":
     bench_op.run()

View File

@@ -15,21 +15,21 @@ import distutils.spawn
import torch import torch
def find_llvm(): def find_llvm():
versions = ['-10', '-10.0', ''] versions = ["-10", "-10.0", ""]
supported = ['llvm-config{v}'.format(v=v) for v in versions] supported = ["llvm-config{v}".format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None] paths = [p for p in paths if p is not None]
if paths: if paths:
return paths[0] return paths[0]
config = distutils.spawn.find_executable('llvm-config') config = distutils.spawn.find_executable("llvm-config")
instructions = 'Please install llvm-10-dev' instructions = "Please install llvm-10-dev"
if config is None: if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions) raise RuntimeError("Could not find llvm-config. " + instructions)
version = os.popen('{config} --version'.format(config=config)).read() version = os.popen("{config} --version".format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions) raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
class CMakeExtension(Extension): class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''): def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[]) Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir) self.sourcedir = os.path.abspath(sourcedir)
self.path = path self.path = path
@@ -37,84 +37,84 @@ class CMakeExtension(Extension):
class CMakeBuild(build_ext):
    def run(self):
        try:
            out = subprocess.check_output(["cmake", "--version"])
        except OSError:
            raise RuntimeError("CMake must be installed to build the following extensions: " +
                               ", ".join(e.name for e in self.extensions))
        if platform.system() == "Windows":
            cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
            if cmake_version < "3.1.0":
                raise RuntimeError("CMake >= 3.1.0 is required on Windows")
        for ext in self.extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        # self.debug = True
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        # python directories
        python_include_dirs = distutils.sysconfig.get_python_inc()
        python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
        torch_include_dirs = include_paths(True)
        torch_library_dirs = library_paths(True)
        cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
        cmake_args = [
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
            "-DBUILD_TUTORIALS=OFF",
            "-DBUILD_PYTHON_MODULE=ON",
            #'-DPYTHON_EXECUTABLE=' + sys.executable,
            #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
            "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
            "-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
            "-DTORCH_CXX11_ABI=" + cxx11abi,
            "-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
            "-DLLVM_CONFIG=" + find_llvm(),
        ]
        # configuration
        cfg = "Debug" if self.debug else "Release"
        build_args = ["--config", cfg]
        if platform.system() == "Windows":
            cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
            if sys.maxsize > 2**32:
                cmake_args += ["-A", "x64"]
            build_args += ["--", "/m"]
        else:
            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
            build_args += ["--", "-j4"]
        env = os.environ.copy()
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)
        sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
        subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
setup(
    name="triton",
    version="1.0.0",
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description="",
    packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
    install_requires=["numpy", "torch"],
    package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
    include_package_data=True,
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
    cmdclass={"build_ext": CMakeBuild},
    zip_safe=False,
    # for PyPI
    keywords=["Compiler", "Deep Learning"],
    url="https://github.com/ptillet/triton/",
    download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
    classifiers=[
        "Development Status :: 3 - Alpha",  # Choose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
        "Intended Audience :: Developers",  # Define that your audience are developers
        "Topic :: Software Development :: Build Tools",
        "License :: OSI Approved :: MIT License",  # Again, pick a license
        "Programming Language :: Python :: 3.6",
    ],
)
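For reviewers: build_extension above only assembles argument lists and shells out to cmake twice, once to configure and once to build. A sketch of roughly what the two calls expand to on Linux (hypothetical placeholder paths, illustrative only):

# illustrative only: hypothetical paths standing in for extdir / sourcedir
configure_cmd = ["cmake", "/path/to/checkout/src",
                 "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=/path/to/build/triton/_C",
                 "-DBUILD_TUTORIALS=OFF",
                 "-DBUILD_PYTHON_MODULE=ON",
                 "-DCMAKE_BUILD_TYPE=Release"]
build_cmd = ["cmake", "--build", ".", "--config", "Release", "--", "-j4"]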


@@ -2,29 +2,17 @@ import torch
import triton
import pytest


@pytest.mark.parametrize(
    "MODE, TRANS_A, TRANS_B, BLOCK",
    [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
     for block in [16, 32, 64]],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
    shape = {
        "sdd": (M, N),
        "dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # triton result
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
    rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
    # compare
    assert triton.testing.allclose(rc, tc)


@pytest.mark.parametrize(
    "BLOCK, WIDTH",
    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
    kp_mask[kp_mask == 1.0] = float("-inf")
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # compare
    assert triton.testing.allclose(ry, ty)


def test_attention_fwd_bwd(
    input_scale=1.0,
    tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
    # inputs
    qkv_shape = (batch_size, n_heads, n_ctx, 64)
    qkvs = [
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
    ]
    attn_mask = torch.tril(
        torch.ones(
@@ -129,11 +106,9 @@ def test_attention_fwd_bwd(
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
    # ad hoc loss
    loss = (attn_out**2).mean()
    loss.backward()
    grads = [query.grad, key.grad, value.grad]
@@ -148,17 +123,16 @@ def test_attention_fwd_bwd(
    probs = torch.softmax(scores, dim=-1)
    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
    # ad hoc loss
    torch_loss = (torch_attn_out**2).mean()
    torch_loss.backward()
    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
    # comparison
    # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
    torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
    for g1, g2 in zip(grads, torch_grads):
        torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)


def triton_attention(
    layout,
    block: int,
@@ -168,12 +142,8 @@ def triton_attention(
    value: torch.Tensor,
    scale: float,
):
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
    sparse_softmax = triton.ops.blocksparse.softmax(
        layout,
        block,
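As a reading aid: the three modes name which operand is block-sparse: "sdd" (sparse C = dense A x dense B), "dsd" (sparse A), "dds" (sparse B). A minimal sketch of the same API the tests above exercise (shapes arbitrary; the packed output layout follows _sdd_matmul further down in this commit):

import torch
import triton

H, M, N, K, BLOCK = 2, 128, 128, 128, 32
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))   # 0/1 mask of BLOCK x BLOCK tiles
op = triton.ops.blocksparse.matmul(layout, BLOCK, "sdd", trans_a=False, trans_b=False)
a = torch.randn(1, H, M, K, dtype=torch.float16, device="cuda")
b = torch.randn(1, H, K, N, dtype=torch.float16, device="cuda")
c = op(a, b)   # nonzero blocks of a @ b, packed as (1, total_blocks, BLOCK, BLOCK)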


@@ -4,7 +4,7 @@ import triton
import torch


@pytest.mark.parametrize(
    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
    itertools.chain(*[
        [
            # 1 warp
@@ -17,14 +17,14 @@ import torch
            (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            # 2 warps
            (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            # 4 warps
            (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,24 +40,28 @@ import torch
            (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
            (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
            # variable input
            (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
        ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
    ]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    torch.manual_seed(0)
    triton.ops._matmul._kernels = dict()
    triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
    if M is None:
        M = TM
    if N is None:
        N = TN
    if K is None:
        K = TK * SPLITK
    a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    th_c = torch.matmul(a, b)
    tt_c = triton.ops.matmul(a, b)
    assert triton.testing.allclose(th_c, tt_c)
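The TZ -> SPLITK rename in this hunk is not cosmetic: the parameter splits the K reduction across that many program instances, whose partial products are then summed. A plain-PyTorch sketch of the idea (not the kernel itself):

import torch

M, N, K, SPLITK = 64, 64, 256, 4
a = torch.randn(M, K)
b = torch.randn(K, N)
chunk = K // SPLITK
# one partial matmul per K-chunk, as each split-K program instance would compute
partials = [a[:, i * chunk:(i + 1) * chunk] @ b[i * chunk:(i + 1) * chunk, :]
            for i in range(SPLITK)]
c = sum(partials)  # on the GPU, this accumulation is what the spin-lock epilogue performs
assert torch.allclose(c, a @ b, atol=1e-4)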


@@ -1,198 +1,199 @@
__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
                     TYPE *B __readonly __noalias __aligned(16),
                     TYPE *C __noalias __aligned(16),
                     int lda __multipleof(8),
                     int ldb __multipleof(8),
                     int ldc __multipleof(8),
                     long stride_za __multipleof(8),
                     long stride_zb __multipleof(8),
                     long stride_zc __multipleof(8),
                     long stride_ha __multipleof(8),
                     long stride_hb __multipleof(8),
                     long stride_hc __multipleof(8),
                     int DS0, int DS1,
                     int SDD_K __multipleof(16),
                     int SDD_off_width,
                     int *lut, int *locks, int nlocks) {
  /* ---------------- */
  /*     Prologue     */
  /* ---------------- */
  // program ids
  int pid0 = get_program_id(0);
  int pid1 = get_program_id(1);
  int pidz = get_program_id(2);
#ifdef SDD
  // load LUT header
  pid1 = pid1 + SDD_off_width;
  int blockidm[TM] = (0 ... TM) / BLOCK;
  int blockidn[TN] = (0 ... TN) / BLOCK;
  int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
  int offlutn[TN] = blockidn * 4;
  int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
  int z = *(header + 0);
  int i[TM] = *(header + 1 + offlutm);
  int j[TN] = *(header + 2 + offlutn);
  int AS1 = SDD_K / TZ;
  int lockid = select(TZ > 1, 1, 0);
  int offka = pid0 * AS1;
  int offkb = pid0 * AS1;
  int offmc = 0;
  int offnc = 0;
  int offpa = 0;
  int offpb = 0;
  int maxid = TZ;
  int offhc = 0;
  int offha = z;
  int offhb = z;
  int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
  int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
#else
  // load LUT header
  int *header = lut + pid0 * 6;
  int offset = *(header + 0);
  int AS1 = *(header + 1);
  int column = *(header + 2);
  int depth = *(header + 3);
  int lockid = *(header + 4);
  int maxid = *(header + 5);
  int *pinc = lut + offset;
  int offhc = depth;
#ifdef DSD
  // output offset
  int offnc = pid1 * TN;
  int offmc = column * TM;
  int offpc = 0;
  // dense input offset
  int offnb = pid1 * TN;
  int offkb __multipleof(8) = *pinc;
  int offpb = 0;
  // sparse input offset
  int offma = 0;
  int offka = 0;
  long offpa __multipleof(8) = *(pinc + 1);
  offpa = offpa * BLOCK * BLOCK;
  int offha = 0;
  int offhb = depth;
#endif
#ifdef DDS
  // output offset
  int offmc = pid1 * TM;
  int offnc = column * TN;
  int offpc = 0;
  // dense input offset
  int offma = pid1 * TM;
  int offka __multipleof(8) = *pinc;
  int offpa = 0;
  // sparse input offset
  int offnb = 0;
  int offkb = 0;
  long offpb __multipleof(8) = *(pinc + 1);
  offpb = offpb * BLOCK * BLOCK;
  int offha = depth;
  int offhb = 0;
#endif
  int ram[TM] = offma + 0 ... TM;
  int rbn[TN] = offnb + 0 ... TN;
#endif
  // initialize a, b pointers
  int rka[TK] = offka + 0 ... TK;
  int rkb[TK] = offkb + 0 ... TK;
  TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
  TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
  // pre-fetch
#ifdef DDS
  bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
  bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
  bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
  bool checkbn[TK, TN] = AS1 > 0;
#endif
  TYPE a[TM, TK] = checkam ? *pa : 0;
  TYPE b[TK, TN] = checkbn ? *pb : 0;

  /* ---------------- */
  /*    Inner Loop    */
  /* ---------------- */
  // create result tile
  float acc[TM, TN] = 0;
  int step = TK;
  for (int k = AS1; k > 0; k -= step) {
    acc += a @ b;
    // update pointers
#ifdef SDD
    int inc_a = TK * STRIDE_AK;
    int inc_b = TK * STRIDE_BK;
#else
    pinc += 2;
#ifdef DSD
    int inc_b __multipleof(8) = *pinc;
    int inc_a __multipleof(8) = *(pinc + 1);
    inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
    int inc_a __multipleof(8) = *pinc;
    int inc_b __multipleof(8) = *(pinc + 1);
    inc_a = inc_a * STRIDE_AK;
#endif
#endif
    pa += inc_a;
    pb += inc_b;
    // pre-fetch
    bool checkak[TM, TK] = k > TK;
    bool checkbk[TK, TN] = k > TK;
    bool checka[TM, TK] = checkam && checkak;
    bool checkb[TK, TN] = checkbk && checkbn;
    a = *?(checka)pa;
    b = *?(checkb)pb;
  }
  TYPE c[TM, TN] = acc;

  /* ---------------- */
  /*     Epilogue     */
  /* ---------------- */
  // initialize c pointers
#ifdef SDD
  bool checkc[TM, TN] = 1;
  // rematerialize
  int rr_blockidm[TM] = (0 ... TM) / BLOCK;
  int rr_blockidn[TN] = (0 ... TN) / BLOCK;
  int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
  int rr_offlutn[TN] = rr_blockidn * 4;
  int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
  int bkid[TM, TN] = *(header + off_bkid);
  long offpc[TM, TN] = bkid * BLOCK * BLOCK;
  // range within blocks
  int rcm[TM] = (0 ... TM) % BLOCK;
  int rcn[TN] = (0 ... TN) % BLOCK;
#else
  int rcm[TM] = offmc + 0 ... TM;
  int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
  bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
  bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
  TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn[newaxis, :] * STRIDE_CN;
  // write-back directly
  if (lockid == 0) {
    *?(checkc)pc = c;
  }
  // accumulate partial result using spin-locks
  else {
    int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
    int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
    for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if (count == 0)
      *?(checkc)pc = c;
    else
      *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % maxid);
    atomic_xchg(plock, 0);
  }
}
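The lock protocol in the epilogue above is easy to misread in kernel syntax: each partial-result writer spins on atomic_cas until it owns the lock, the first writer stores its tile, later writers add to it, and the counter cycling modulo maxid resets the slot for reuse. A rough CPU analogue, with threads standing in for program instances (illustrative only):

import threading

lock = threading.Lock()   # stands in for the atomic_cas spin loop on plock
count = 0                 # stands in for *pcount
out = [0.0]               # stands in for the C tile

def write_partial(partial, maxid):
    global count
    with lock:
        if count == 0:
            out[0] = partial           # first arrival stores its tile
        else:
            out[0] += partial          # later arrivals accumulate into it
        count = (count + 1) % maxid    # cycle the counter for the next round

threads = [threading.Thread(target=write_partial, args=(float(i + 1), 4)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(out[0])  # always 10.0, independent of arrival order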


@@ -10,454 +10,416 @@ src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
# MAIN API #
##############
class _matmul(torch.autograd.Function):

    sdd_cache = dict()
    dsd_cache = dict()
    dds_cache = dict()
    locks = dict()

    # Given an array sizes representing reduction size for each
    # column of a block-mode matrix multiplication,
    # performs load-balancing to achieve more smaller reductions
    # between `seg_size` elements
    @staticmethod
    def load_balance(sizes, block):
        # segment size
        # heuristics taken from OpenAI blocksparse code
        # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
        max_size = sizes.max()
        min_size = sizes[sizes != 0].min()
        #if max_size > min_size * 2.0:
        #  seg_max = max(triton.cdiv(max_size, 4), min_size*2)
        #else:
        #  seg_max = max_size
        seg_max = max_size
        seg_min = max(triton.cdiv(seg_max, 4), 4)
        # split reduction into segments
        div = sizes // seg_max
        rem = sizes % seg_max
        packs = div + (sizes < seg_min).long() + (rem >= seg_min).long()
        width = packs.sum()
        segments = torch.empty(width, dtype=sizes.dtype)
        column = torch.empty_like(segments)
        lockid = torch.zeros_like(segments)
        maxid = torch.zeros_like(segments)
        nlocks = 0
        current = 0
        col_idx = 0
        for i in range(len(sizes)):
            d, r = div[i], rem[i]
            isempty = sizes[i] < seg_min
            last = current + d + (r >= seg_min) + isempty
            # column id
            column[current:last] = col_idx
            # lock id
            if d > 1 or (d == 1 and r >= seg_min):
                nlocks += 1
                lockid[current:last] = nlocks
                maxid[current:last] = last - current
            # segment size
            segments[current:current + d] = seg_max
            if r < seg_min and not isempty:
                segments[current + d - 1] += r
            if r >= seg_min or isempty:
                segments[current + d] = r
            current = last
            col_idx += 1
        offsets = torch.zeros_like(segments)
        offsets[1:] = torch.cumsum(segments[:-1], dim=0)
        return segments, column, lockid, maxid, offsets

    @staticmethod
    def get_locks(size, dev):
        if dev not in _matmul.locks or \
                size > _matmul.locks[dev].size(0):
            _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev)
        return _matmul.locks[dev]

    ##########################
    # SPARSE = DENSE x DENSE #
    ##########################

    @staticmethod
    def make_sdd_lut(layout, block, dtype, device):
        start_width = 128 // block
        superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
        luts, widths, packs = [], [], []
        for size, nnz in superblocks:
            width = nnz.shape[0] // (size * size)
            h = nnz[:, 0]
            i = nnz[:, 1]
            j = nnz[:, 2]
            b = nnz[:, 3]
            lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous()
            luts.append(lut.type(torch.int32).to(device))
            widths.append(width)
            packs.append(size)
        # create locks
        return luts, None, widths, packs

    @staticmethod
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        device = a.device
        is_16_multiple = AS3 % 16 == 0
        is_32_multiple = AS3 % 32 == 0
        is_64_multiple = AS3 % 64 == 0
        if not is_16_multiple:
            raise ValueError('Reduction size for SDD must be a multiple of 16')
        # create kernel
        total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
        c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
        for lut, width, pack in zip(luts, widths, packs):
            num_lock = 1
            key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
            if key not in _matmul.sdd_cache:
                defines = {
                    'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block,
                    'TK': 32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
                    'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
                    'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
                }
                _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)

            kernel = _matmul.sdd_cache[key]
            # create output
            locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
            # maximum grid size is 65535
            # so operation might be decomposed into multiple
            # kernel calls
            max_width = 49152
            for off_width in range(0, width, max_width):
                kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
                       b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
                       lut.data_ptr(), locks.data_ptr(), num_lock,
                       grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
        # save for backward pass
        return c

    ##########################
    # DENSE = DENSE x SPARSE #
    # DENSE = SPARSE x DENSE #
    ##########################

    # Given a binary layout of 0s and 1s,
    # Construct look-up table for efficient execution on GPUs
    @staticmethod
    def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
        # load-balancing
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        segments = _empty.clone()
        column = _empty.clone()
        depth = _empty.clone()
        lockid = _empty.clone()
        maxid = _empty.clone()
        offsets = _empty.clone()
        current_offset = 0
        current_maxid = 0
        for z in range(layout.size(0)):
            if trans:
                sizes = torch.sum(layout[z, :, :], 1)
            else:
                sizes = torch.sum(layout[z, :, :], 0)
            z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block)
            z_depth = z * torch.ones_like(z_segments)
            z_lockid[z_lockid > 0] += current_maxid
            current_maxid = z_lockid.max()
            # concatenate depth
            segments = torch.cat((segments, z_segments))
            column = torch.cat((column, z_column))
            depth = torch.cat((depth, z_depth))
            maxid = torch.cat((maxid, z_maxid))
            offsets = torch.cat((offsets, current_offset + z_offsets))
            lockid = torch.cat((lockid, z_lockid))
            current_offset += layout[z, :, :].sum()
        segments *= step
        # pointer increments
        if trans:
            nnz = layout.nonzero(as_tuple=False)
        else:
            nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
        num_blocks = nnz.size(0)
        offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
        idx = transform(nnz[:, 2] * block)
        xincs = idx.clone()
        xincs[1:] -= idx[:-1]
        # divide block into multiple steps
        div = block // step
        xincs = xincs.view(-1, 1).repeat(1, div)
        xincs[:, 1:] = step
        xincs[:, 0] -= (div - 1) * step
        # first increment for each reduction is actually the offset
        xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
        xincs = xincs.view(-1)
        # block-mode input increments
        if trans:
            widx = torch.arange(num_blocks)
        else:
            widx = _empty.clone()
            current_offset = 0
            for z in range(layout.size(0)):
                layoutw = layout[z, :, :].clone()
                msum = layoutw.sum()
                layoutw[layoutw > 0] = 1 + torch.arange(msum)
                widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
                current_offset += msum
        widx = widx
        wincs = widx * block * block
        wincs[1:] -= widx[:-1] * block * block
        wincs = wincs.view(-1, 1).repeat(1, div)
        if trans:
            wincs[:, 1:] = step
            wincs[:, 0] -= (div - 1) * step
        else:
            wincs[:, 1:] = step * block
            wincs[:, 0] -= (div - 1) * step * block
        wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
        wincs = wincs.view(-1)
        # adjust offset and segment size
        offsets *= 2 * div
        segments *= div
        # create header
        width = column.size(0)
        offsets += 6 * width
        header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
        # create lut
        lut = torch.cat((header, incs))
        lut = lut.type(torch.int32).to(device)
        # create locks
        num_locks = max(1, lockid.max())
        return lut, num_locks, width, None

    @staticmethod
    def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = spdims[0]
        BS1 = block * spdims[2 if trans_b else 1]
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dds_cache:
            defines = {
                'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda',
                'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1,
                'STRIDE_BK': 1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc',
                'STRIDE_CN': 'ldc' if trans_c else '1', 'NAME': 'dds_kernel', 'DDS': True
            }
            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
        kernel = _matmul.dds_cache[key]
        # output
        CS0 = AS0
        CS1 = AS1
        CS2 = BS2 if trans_c else AS2
        CS3 = AS2 if trans_c else BS2
        locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
               c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
               num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
        return c

    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
        AS2 = block * spdims[1 if trans_a else 2]
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dsd_cache:
            defines = {
                'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block,
                'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1',
                'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc',
                'STRIDE_CN': 'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True
            }
            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
        kernel = _matmul.dsd_cache[key]
        # output
        CS0 = BS0
        CS1 = BS1
        CS2 = BS3 if trans_c else AS1
        CS3 = AS1 if trans_c else BS3
        locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
               c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
               num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
        return c

    fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}

    @staticmethod
    def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
                da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
        c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_num_locks = da_num_locks
        ctx.da_lut = da_lut
        ctx.da_width = da_width
        ctx.da_packs = da_packs
        ctx.db_lut = db_lut
        ctx.db_num_locks = db_num_locks
        ctx.db_width = db_width
        ctx.db_packs = db_packs
        ctx.mode = mode
        ctx.spdims = spdims
        ctx.block = block
        ctx.trans_a = trans_a
        ctx.trans_b = trans_b
        return c

    @staticmethod
    def backward(ctx, dc):
        # saved for backward
        a, b = ctx.saved_tensors
        mode = ctx.mode
        # gradients w.r.t. a
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
                                     ctx.da_num_locks, ctx.da_width, ctx.da_packs)
        # gradients w.r.t. b
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
                                     ctx.db_num_locks, ctx.db_width, ctx.db_packs)
        return da, db, None, None, None,\
            None, None, None, None,\
            None, None, None, None, None, None,\
            None, None, None, None, None, None,\
            None, None, None, None, None, None


class matmul:
    def make_lut(self, dtype, device):
        key = (dtype, device)
        if key in self.lut_cache:
            return self.lut_cache[key]
        # C look-up table
        layout, block = self.layout, self.block
        step = 8 if dtype == torch.float32 else 16
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b,
                                                                            device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,
                               da_lut, da_num_locks, da_width, da_packs,
                               db_lut, db_num_locks, db_width, db_packs)
        return self.lut_cache[key]

    def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        # look-up table cache
        self.lut_cache = dict()
        # attributes
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.mode = mode
        self.spdims = layout.shape
        self.block = block
        self.layout = layout

    # pad shapes of a tensor to make it
    # compatible with kernel calls
    @staticmethod
    def _pad_shape(x, is_sparse):
        max_dim = 3 if is_sparse else 4
        for i in range(max_dim - x.dim()):
            x = x.unsqueeze(0)
        return x

    def __call__(self, a, b):
        c_lut, c_num_locks, c_width, c_packs,\
            da_lut, da_num_locks, da_width, da_packs,\
            db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
        # pad shapes with ones
        a = matmul._pad_shape(a, self.mode == 'dsd')
        b = matmul._pad_shape(b, self.mode == 'dds')
        # execute
        c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
                          c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks,
                          db_width, db_packs)
        return c
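One detail worth calling out in the wrapper above: _pad_shape only left-pads with singleton dimensions until inputs reach the rank the kernels expect (4D dense, 3D sparse). A quick check, assuming the matmul class above is in scope:

x = torch.randn(128, 256)                            # plain 2D input
print(matmul._pad_shape(x, is_sparse=False).shape)   # torch.Size([1, 1, 128, 256])
print(matmul._pad_shape(x, is_sparse=True).shape)    # torch.Size([1, 128, 256])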


@@ -1,9 +1,9 @@
#define STM 8
#define STN 8
__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
                       TYPE *B __noalias __readonly __aligned(16),
                       TYPE *C __noalias __aligned(16),
                       float alpha,
                       int M,
                       int N,
@@ -11,87 +11,88 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
int lda __multipleof(LDA_POW2_DIV), int lda __multipleof(LDA_POW2_DIV),
int ldb __multipleof(LDB_POW2_DIV), int ldb __multipleof(LDB_POW2_DIV),
int ldc __multipleof(LDC_POW2_DIV), int ldc __multipleof(LDC_POW2_DIV),
int* locks) { int *locks) {
// prologue // prologue
int pid = get_program_id(0); int pid = get_program_id(0);
int pidz = get_program_id(2); int pidz = get_program_id(2);
int gridm = (M + TM - 1) / TM; int gridm = (M + TM - 1) / TM;
int gridn = (N + TN - 1) / TN; int gridn = (N + TN - 1) / TN;
// swizzle for better L2 performance // swizzle for better L2 performance
int width = STM*gridn; int width = STM * gridn;
int stm = pid / width; int stm = pid / width;
int RSTM = min(gridm - stm*STM, STM); int RSTM = min(gridm - stm * STM, STM);
int stn = (pid % width) / (RSTM*STN); int stn = (pid % width) / (RSTM * STN);
int RSTN = min(gridn - stn*STN, STN); int RSTN = min(gridn - stn * STN, STN);
int laneid = pid % (RSTM * RSTN); int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN; int lanem = laneid / RSTN;
int lanen = laneid % RSTN; int lanen = laneid % RSTN;
int pidm = stm*STM + lanem; int pidm = stm * STM + lanem;
int pidn = stn*STN + lanen; int pidn = stn * STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM; int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN; int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parrallelism // split-k for better parrallelism
K = K / TZ; K = K / SPLITK;
int rk[TK] = 0 ... TK; int rk[TK] = 0 ... TK;
// pointers to operands // pointers to operands
int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa; TYPE *pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb; TYPE *pb[TK, TN] = B + offb;
// prefetches operands // prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K; bool checka[TM, TK] = rk [newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K; bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0; TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0; TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK; pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK; pb += TK * STRIDE_BK;
// reduction loop // reduction loop
float acc[TM, TN] = 0; float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){ for (int k = K; k > 0; k -= TK) {
#if (IS_TK_DIV_K==1) #if (IS_TK_DIV_K == 1)
bool checkk[TK] = k > TK; bool checkk[TK] = k > TK;
#else #else
bool checkk[TK] = rk < k - TK; bool checkk[TK] = rk < k - TK;
#endif #endif
bool checka[TM, TK] = checkk[newaxis, :]; bool checka[TM, TK] = checkk [newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis]; bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @ b; acc += a @b;
#if (IS_TK_DIV_K==1) #if (IS_TK_DIV_K == 1)
a = *?(checka)pa; a = *? (checka)pa;
b = *?(checkb)pb; b = *? (checkb)pb;
#else #else
a = checka ? *pa : 0; a = checka ? *pa : 0;
b = checkb ? *pb : 0; b = checkb ? *pb : 0;
#endif #endif
pa += TK * STRIDE_AK; pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK; pb += TK * STRIDE_BK;
} }
acc = acc * alpha; acc = acc * alpha;
TYPE c[TM, TN] = acc; TYPE c[TM, TN] = acc;
// epilogue // epilogue
int rcm[TM] = pidm * TM + 0 ... TM; int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN; int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :]; int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
TYPE* pc[TM, TN] = C + offc; TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N; bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
#if (TZ==1) #if (SPLITK == 1)
*?(checkc) pc = c; *? (checkc)pc = c;
#else #else
// accumulate partial result using spin-locks // accumulate partial result using spin-locks
int *plock = locks + pid; int *plock = locks + pid;
int *pcount = plock + get_num_programs(0); int *pcount = plock + get_num_programs(0);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
int count = *pcount; ;
if(count == 0) int count = *pcount;
*?(checkc) pc = c; if (count == 0)
else *? (checkc)pc = c;
*?(checkc) pc = c + *?(checkc)pc; else
atomic_xchg(pcount, (count + 1) % TZ); *? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(plock, 0); atomic_xchg(pcount, (count + 1) % SPLITK);
atomic_xchg(plock, 0);
#endif #endif
} }
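When SPLITK > 1, each program along grid dimension 2 multiplies a K/SPLITK slice of the operands, and the spin-lock epilogue serializes the partial sums into C. A minimal NumPy sketch of the arithmetic (not of the locking), with made-up sizes:

import numpy as np

# Split-k decomposition: partial products over K-slices are summed,
# mirroring what the pidz programs plus the locked epilogue compute.
def splitk_matmul(A, B, SPLITK=4):
    M, K = A.shape
    Kz = K // SPLITK                 # the kernel's K = K / SPLITK
    C = np.zeros((M, B.shape[1]))
    for pidz in range(SPLITK):       # grid dimension 2
        C += A[:, pidz * Kz:(pidz + 1) * Kz] @ B[pidz * Kz:(pidz + 1) * Kz, :]
    return C

A, B = np.random.rand(64, 128), np.random.rand(128, 64)
assert np.allclose(splitk_matmul(A, B), A @ B)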

View File

@@ -3,29 +3,32 @@ import triton
import os


class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))

    _DEFAULT_CONFIGS = [
        ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "128", "TN": "64", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "64", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "32", "TN": "128", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "128", "TN": "32", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "32", "TK": "64", "SPLITK": "1"}, 2),
        ({"TM": "32", "TN": "64", "TK": "64", "SPLITK": "1"}, 2),
        # ({"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, 4),
        # ({"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, 4),
        # ({"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, 4),
        # ({"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, 4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS

    @staticmethod
    def largest_pow2_divisor(N):
        if N % 8 == 0:
            return 8
        if N % 4 == 0:
            return 4
        if N % 2 == 0:
            return 2
        return 1

    _locks = dict()
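The helper returns the largest power of two (capped at 8) that divides N; its results feed the LDA/LDB/LDC_POW2_DIV defines and the `__multipleof` hints in the kernel. A quick sanity check of the reformatted helper (plain Python, illustrative values):

assert _matmul.largest_pow2_divisor(24) == 8   # 24 = 8 * 3
assert _matmul.largest_pow2_divisor(20) == 4
assert _matmul.largest_pow2_divisor(10) == 2
assert _matmul.largest_pow2_divisor(7) == 1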
@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
        K, N = b.shape
        c = torch.empty((M, N), dtype=dtype, device=device)
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
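The contiguity guard copies an input only when neither stride is 1, i.e. when the tensor is neither row- nor column-major and cannot be addressed with a single leading dimension. A small illustration of a tensor that would trigger the copy (hypothetical values):

import torch

t = torch.arange(36.0).reshape(6, 6)[::2, ::2]  # doubly-strided view
assert t.stride(0) > 1 and t.stride(1) > 1      # strides (12, 2): copy needed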
@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 64 == 0
        key = (
            device,
            dtype,
            is_a_row,
            is_b_row,
            lda_pow2_div,
            ldb_pow2_div,
            ldc_pow2_div,
            is_tk_div_k,
        )
        if key not in _matmul._kernels:
            defines = {
                "TYPE": dtype,
                "STRIDE_AM": "lda" if is_a_row else "1",
                "STRIDE_AK": "1" if is_a_row else "lda",
                "STRIDE_BK": "ldb" if is_b_row else "1",
                "STRIDE_BN": "1" if is_b_row else "ldb",
                "LDA_POW2_DIV": lda_pow2_div,
                "LDB_POW2_DIV": ldb_pow2_div,
                "LDC_POW2_DIV": ldc_pow2_div,
                "IS_TK_DIV_K": int(is_tk_div_k),
            }
            _matmul._kernels[key] = triton.kernel(
                _matmul.src,
                device,
                defines=defines,
                autotune_vals=_matmul._CONFIGS,
                autotune_key=["M", "N", "K"],
            )
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.0
        args = [
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            alpha,
            M,
            N,
            K,
            lda,
            ldb,
            ldc,
            locks.data_ptr(),
        ]
        grid = lambda opt: [
            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
            1,
            opt.SPLITK,
        ]
        kernel(*args, grid=grid)
        return c
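For reference, the launch covers one program per output tile in dimension 0 and one per split-k slice in dimension 2; a minimal recomputation with made-up sizes:

import math

TM, TN, SPLITK = 128, 128, 1                        # as in the first config
M, N = 1024, 1024                                   # hypothetical problem size
grid = [math.ceil(M / TM) * math.ceil(N / TN), 1, SPLITK]
assert grid == [64, 1, 1]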

View File

@@ -1,21 +1,33 @@
import torch


def sparsify_tensor(x, mask, block):
    ret = torch.empty(
        (x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device
    )
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[
            :, h, i * block : (i + 1) * block, j * block : (j + 1) * block
        ]
    return ret


def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
    return ret


def allclose(x, y):
    assert x.dtype == y.dtype
    diff = abs(x - y)
    x_max = torch.max(x)
    y_max = torch.max(y)
    tol = 1e-2
    err = torch.max(diff) / torch.max(x_max, y_max)
    return err < tol


def do_bench(fn, flops=0, warmup=10, rep=50):
    start_event = torch.cuda.Event(enable_timing=True)
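Note that the rewritten `allclose` replaces the per-dtype `rtol`/`atol` pairs with a single scale-relative criterion: the largest elementwise difference divided by the largest value in either tensor must stay under 1e-2. This suits the positive-magnitude benchmark data used here, though it degrades for tensors whose maxima are near zero. A small numeric illustration (made-up values):

import torch

x = torch.tensor([100.0, 0.0])
y = torch.tensor([100.5, 0.4])
# max |x - y| = 0.5, max(x_max, y_max) = 100.5 -> err ~ 0.005 < 1e-2
assert allclose(x, y)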
@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms


class Benchmark:
    def __init__(
        self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args
    ):
        self.x_names = x_names
        self.x_vals = x_vals
        self.y_name = y_name
@@ -44,6 +59,7 @@ class Benchmark:
        self.plot_name = plot_name
        self.args = args


class Mark:
    def __init__(self, fn, benchmarks):
        self.fn = fn
@@ -53,26 +69,31 @@ class Mark:
        import matplotlib.pyplot as plt
        import pandas as pd
        import os

        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row = [
                self.fn(**x_args, **{bench.y_name: y}, **bench.args)
                for y in bench.y_vals
            ]
            df.loc[len(df)] = [x] + row
        if with_plot and bench.plot_name:
            xlabel = " = ".join(bench.x_names)
            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
            plot.set_xlabel(xlabel)
            plot.set_ylabel(bench.ylabel)
            plot.set_title(bench.plot_name)
            plot.set_xscale("log" if bench.loglog else "linear")
            plot.set_yscale("log" if bench.loglog else "linear")
            plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

    def run(self, result_path, with_plot):
        for bench in self.benchmarks:
            self._run(bench, result_path, with_plot)


def perf_report(benchmarks):
    wrapper = lambda fn: Mark(fn, benchmarks)
    return wrapper
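A hypothetical end-to-end use of `perf_report` with the classes above (names and numbers are made up; the benchmark body would normally call `do_bench`):

import random

@perf_report([
    Benchmark(
        x_names=["M"], x_vals=[256, 512, 1024],
        y_name="provider", y_vals=["triton", "torch"],
        y_lines=["Triton", "Torch"], ylabel="TFLOPS",
        loglog=False, plot_name="matmul-perf", args={},
    )
])
def bench_matmul(M, provider):
    return random.random()  # stand-in for a do_bench(...) measurement

# bench_matmul.run("/tmp/results", with_plot=True)  # writes CSV and PNG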

View File

@@ -66,18 +66,19 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
  bool checkb[TK, TN] = rk[:, newaxis] < K;
  TYPE a[TM, TK] = checka ? *pa : 0;
  TYPE b[TK, TN] = checkb ? *pb : 0;
  // reduction loop
  float acc[TM, TN] = 0;
  for(int k = K; k > 0; k -= TK){
    bool checka[TM, TK] = k > TK;
    bool checkb[TK, TN] = k > TK;
    pa += TK * STRIDE_AK;
    pb += TK * STRIDE_BK;
    TYPE anext[TM, TK] = *?(checka)pa;
    TYPE bnext[TK, TN] = *?(checkb)pb;
    acc += a @ b;
    a = anext;
    b = bnext;
    // __debug_barrier();
  }
  acc = acc * alpha;
  TYPE c[TM, TN] = acc;
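The restructured loop is a software-pipelining (double-buffering) change: the next K-tile is loaded into `anext`/`bnext` before the current `a @ b` is issued, so the loads can overlap the tensor-core math. A plain-NumPy sketch of the same schedule (shapes and tile size are made up; NumPy runs sequentially, the overlap only exists on the GPU):

import numpy as np

A, B, TK = np.random.rand(128, 256), np.random.rand(256, 128), 64
acc = np.zeros((128, 128))
a, b = A[:, :TK], B[:TK, :]                      # prologue prefetch
for k in range(TK, 256 + TK, TK):
    if k < 256:                                  # guard, like checka/checkb
        a_next, b_next = A[:, k:k + TK], B[k:k + TK, :]
    acc += a @ b                                 # compute on current tiles
    if k < 256:
        a, b = a_next, b_next                    # rotate buffers
assert np.allclose(acc, A @ B)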
@@ -166,7 +167,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
  opt.defines["TYPE"] = ty;
  opt.defines["TM"] = "128";
  opt.defines["TN"] = "128";
  opt.defines["TK"] = "64";
  opt.defines["TZ"] = "1";
  opt.num_warps = 4;
  // arguments