[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copies, scheduling, and synchronization on A100. Triton now achieves CUTLASS-like performance on large, square, dense matrix multiplication tasks.
Philippe Tillet
2021-02-21 15:19:39 -08:00
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions
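
For context on the synchronization pattern this commit targets: on A100, cp.async copies global memory into shared memory asynchronously, cp.async.commit_group closes the preceding copies into a group, and cp.async.wait_group N blocks until at most N groups are still in flight (see the generator and membar diffs below). A minimal host-side model of those group semantics, as a hedged sketch — none of these names are Triton APIs:

#include <deque>
#include <iostream>

struct async_copy_model {
  std::deque<int> in_flight; // committed, possibly still-running copy groups
  int next_id = 0;
  void commit_group() { in_flight.push_back(next_id++); }
  void wait_group(int N) {
    // retire the oldest groups until at most N remain pending
    while ((int)in_flight.size() > N)
      in_flight.pop_front();
  }
};

int main() {
  async_copy_model m;
  m.commit_group();  // prefetch tile 0
  m.commit_group();  // prefetch tile 1
  m.wait_group(1);   // tile 0 is now usable; tile 1 may still be copying
  std::cout << m.in_flight.size() << " group(s) still pending\n";
}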

View File

@@ -2,6 +2,9 @@
 #define TDL_INCLUDE_CODEGEN_BARRIERS_H
 #include <vector>
+#include <map>
+#include <list>
+#include <set>
 namespace triton {
@@ -9,6 +12,7 @@ namespace ir {
 class module;
 class basic_block;
 class instruction;
+class masked_load_async_inst;
 class value;
 class builder;
 }
@@ -29,18 +33,15 @@ namespace transform{
 class membar {
 private:
   typedef std::pair<unsigned, unsigned> interval_t;
-  typedef std::vector<interval_t> interval_vec_t;
+  typedef std::set<ir::value*> val_set_t;
+  typedef std::vector<ir::value*> val_vec_t;
 private:
-  interval_vec_t join(const std::vector<interval_vec_t>& intervals);
-  void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
-  bool intersect(const interval_vec_t &X, interval_t x);
-  bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
-  void add_reference(ir::value *v, interval_vec_t &res);
-  void get_read_intervals(ir::instruction *i, interval_vec_t &res);
-  void get_written_intervals(ir::instruction *i, interval_vec_t &res);
-  std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
-                                                     std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
+  bool intersect(const val_set_t &X, const val_set_t &Y);
+  int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
+  val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
+  void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
+                std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
 public:
   membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -16,6 +16,10 @@ namespace ir {
 }
 namespace codegen{
+namespace analysis{
+class layouts;
+}
 namespace transform{
 class peephole {
@@ -33,11 +37,12 @@ private:
 private:
 public:
-  peephole(target* tgt): tgt_(tgt) {}
+  peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
   void run(ir::module &mod);
 private:
   target* tgt_;
+  analysis::layouts* layouts_;
 };

View File

@@ -0,0 +1,28 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
void run(ir::module &module);
private:
bool has_copy_async_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@@ -29,7 +29,7 @@ public:
   static driver::stream* create(backend_t backend);
   // methods
   virtual void synchronize() = 0;
-  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
+  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
   virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
   virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
   // template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
 public:
   host_stream();
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
@@ -55,7 +55,7 @@ public:
   cu_stream(CUstream str, bool take_ownership);
   cu_stream();
   void synchronize();
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
   void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };

View File

@@ -35,6 +35,7 @@ public:
   basic_block* get_insert_block() { return block_; }
   iterator get_insert_point() { return insert_point_;}
   // Constants
+  value *get_int1(bool val);
   value *get_int32(int32_t val);
   value *get_int64(int64_t val);
   // Types
@@ -149,7 +150,7 @@ public:
   value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
   value *create_copy_from_shared(value *arg, const std::string &name = "");
   value *create_barrier(const std::string &name = "");
-  value *create_async_wait();
+  value *create_async_wait(int N);
 private:
   context &ctx_;

View File

@@ -92,6 +92,7 @@ private:
 public:
   void set_incoming_value(unsigned i, value *v);
   void set_incoming_block(unsigned i, basic_block *block);
+  value *get_value_for_block(basic_block *block);
   value *get_incoming_value(unsigned i) { return get_operand(i); }
   basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
   unsigned get_num_incoming() { return get_num_operands(); }
@@ -803,14 +804,18 @@ public:
 class async_wait_inst: public instruction{
 private:
-  async_wait_inst(context &ctx, const std::string &name, instruction *next);
-  std::string repr_impl() const { return "async_wait"; }
+  async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
+  std::string repr_impl() const { return "async_wait_group " + std::to_string(N_); }
   _TRITON_DEFINE_CLONE(async_wait_inst)
   _TRITON_DEFINE_ACCEPT(async_wait_inst)
 public:
-  static async_wait_inst* create(context &ctx, const std::string &name = "",
-                                 instruction *next = nullptr);
+  static async_wait_inst* create(context &ctx, int N,
+                                 const std::string &name = "", instruction *next = nullptr);
+  int get_N() { return N_; }
+private:
+  int N_;
 };
 // On NVIDIA, implementation is such that
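
A hedged sketch of what the new N parameter buys, mirroring visit_async_wait_inst in the generator diff further down: the old instruction always drained every outstanding copy, while the group form lets the N most recent groups stay in flight.

#include <iostream>
#include <string>

// The PTX string built for async_wait_inst, as in the generator below;
// the pre-change code always emitted "cp.async.wait_all;" instead.
std::string async_wait_asm(int N) {
  return "cp.async.wait_group " + std::to_string(N) + ";";
}

int main() {
  std::cout << async_wait_asm(0) << "\n"; // drain everything
  std::cout << async_wait_asm(2) << "\n"; // the 2 newest groups may keep copying
}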

View File

@@ -98,6 +98,8 @@ private:
   std::shared_ptr<ir::module> ir_;
   std::shared_ptr<driver::module> mod_;
   std::shared_ptr<driver::kernel> ker_;
+  // shared mem
+  size_t shared_mem_;
 };
 class function {

View File

@@ -30,11 +30,8 @@ private:
   high_resolution_clock::time_point _start;
 };
-inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
+inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
 {
-  // const driver::device * device = stream->context()->device();
-  size_t warmup = 10;
-  size_t repeat = 50;
   timer tmr;
   std::vector<size_t> times;
   double total_time = 0;

View File

@@ -312,7 +312,6 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
     if(rhs_cst_info[d].num_cst)
       rvalue = lhs_max_contiguous[d];
     result[d] = std::max(lvalue, rvalue);
-    // std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl;
   }
   return add_to_cache(x, result, max_contiguous_);
 }
@@ -527,8 +526,7 @@ void align::run(ir::module &mod) {
   ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
   // ir::for_each_value(mod, [this](ir::value* v) {
   //   if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
-  //     std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
-  //               << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
+  //     std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
   // });
 }

View File

@@ -118,15 +118,6 @@ data_layout::data_layout(id_t id,
     // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
     // std::cout << order_[0] << " " << order_[1] << std::endl;
   }
-  if(is_recoalesce){
-    if(ptr.size() > 0){
-      // std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
-      // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
-      // if(order_[0] == 0)
-      //   exit(1);
-    }
-  }
-  // std::cout << "---" << std::endl;
 }
 int data_layout::find_axis(int to_find) const {
@@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps,
   ir::value *ptr = nullptr;
   for(ir::value *v: values)
   for(ir::user *usr: v->get_users())
-  if(auto *st = dynamic_cast<ir::io_inst*>(usr))
-    ptr = st->get_pointer_operand();
+  if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
+    if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
+      ptr = io->get_pointer_operand();
+  }
   unsigned i = order_[0];
   int contiguous = 1;
   if(ptr){
     int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
-    contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
+    contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
   }
   nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));

View File

@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
 }
 void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
-  unsigned vector = 1;
-  ir::value *ptrs = x->get_pointer_operand();
-  ir::value *msks = x->get_mask_operand();
+  unsigned in_vec = 1;
+  ir::value *arg = x->get_pointer_operand();
   analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
-  analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
+  analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
   auto out_order = out_layout->get_order();
   auto in_order = in_layout->get_order();
   // tiles
   if(out_order == in_order)
-    vector = in_layout->nts(in_order[0]);
+    in_vec = in_layout->nts(in_order[0]);
+  int out_vec = swizzle_->get_vec(out_layout);
+  int min_vec = std::min<int>(out_vec, in_vec);
+  int s = std::max<int>(out_vec / in_vec, 1);
   //
-  int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
-  int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
-  Value *max_phase = i32(8 / num_per_phase);
+  int per_phase = swizzle_->get_per_phase(out_layout);
+  int max_phase = swizzle_->get_max_phase(out_layout);
   //
+  int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
+  int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
+  int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
   auto shapes = x->get_type()->get_tile_shapes();
-  //
-  int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
-  int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
-  std::vector<Value*> shared;
-  for(size_t i = 0; i < n_shared; i++){
-    indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
-    // phase
-    Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
-    phase = urem(phase, max_phase);
-    // off
-    Value* off_0 = idx[in_order[0]];
-    off_0 = udiv(off_0, i32(vector));
-    off_0 = xor_(off_0, phase);
-    off_0 = mul(off_0 , i32(vector));
-    Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
-    Value* off = add(off_0, off_1);
-    //
-    shared.push_back(gep(shmems_[x], {off}));
-  }
-  //
-  for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
-    auto idx = idxs_[ptrs][i];
+  BasicBlock* CurrBB = builder_->GetInsertBlock();
+  BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
+  std::map<std::pair<int, int>, Value*> tmp;
+  std::vector<std::pair<Value*, int>> shared;
+  for(int i = 0; i < idxs_.at(arg).size(); i++){
+    unsigned id = i / min_vec;
     // input ptr info
-    GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
-    Value *in_base = in_gep->getPointerOperand();
-    size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
-    Value* out_base = shared[(i / per_thread_ld) % n_shared];
-    int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
-    int out_off_1 = i % per_thread_ld;
-    int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
-    // asm
-    FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
-    std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
-    std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
-    InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
-    call(iasm, {vals_[msks][idx], out_base, in_base});
+    int id_0 = id % (in_ld/min_vec);
+    int id_1 = id / (in_ld/min_vec);
+    int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
+    int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
+    int off = (off_1*shapes[in_order[0]] + off_0);
+    std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
+    if(tmp.find(key) == tmp.end()){
+      if(CurrBB != FirstBB)
+        builder_->SetInsertPoint(FirstBB->getTerminator());
+      indices_t idx = idxs_.at(arg).at(key.first*in_ld);
+      Value* phase = udiv(idx[in_order[1]], i32(per_phase));
+      phase = urem(phase, i32(max_phase));
+      Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
+      Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
+      off_0 = udiv(off_0, i32(min_vec));
+      off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
+      off_0 = mul(off_0 , i32(min_vec));
+      Value* off = add(off_0, off_1);
+      if(CurrBB != FirstBB)
+        builder_->SetInsertPoint(CurrBB);
+      tmp[key] = gep(shmems_[x], {off});
+    }
+    shared.push_back({tmp[key], off});
+  }
+  for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
+    auto idx = idxs_[arg][i];
+    // input ptr info
+    GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
+    Value *in_base = in_gep->getPointerOperand();
+    ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
+    size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
+    in_base = cst ? in_base : in_gep;
+    // output ptr info
+    Value* out_base = shared[i].first;
+    int out_off = shared[i].second*2;
+    // asm
+    FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
+    std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
+    std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
+    InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
+    call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
   }
+  std::string asm_str = "cp.async.commit_group;";
+  InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
+  call(iasm);
 }
 void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
   auto shapes = cts->get_type()->get_tile_shapes();
-  // default implementation
+  // store to shared
   Value *current = nullptr;
   std::map<std::pair<int, int>, Value*> ptrs;
   for(int i = 0; i < idxs_.at(arg).size(); i++){
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
   add_barrier();
 }
-void generator::visit_async_wait_inst(ir::async_wait_inst*) {
-  std::string asm_str = "cp.async.wait_all;";
+void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
+  std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
   InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
   call(iasm);
-  add_barrier();
 }
 void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
   if(unsigned alloc_size = alloc_->allocated_size()){
     Type *int_8_ty = Type::getInt8Ty(*ctx_);
     Type *int_32_ty = Type::getInt32Ty(*ctx_);
-    ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
+    ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
     Type *ptr_ty = ptr_ty(int_8_ty, 3);
     GlobalVariable *sh_mem_array =
-      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
+      new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
                          nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
     shmem_ = bit_cast(sh_mem_array, ptr_ty);
   }
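
A standalone restatement of the XOR-swizzled shared-memory offset computed above, with assumed row/col/ld naming — an illustration of the arithmetic, not code from the commit:

#include <cstdio>

// s, min_vec, per_phase, max_phase as in the generator; row and col are the
// logical coordinates along in_order[1] and in_order[0], ld the leading dim.
int swizzled_offset(int row, int col, int ld,
                    int s, int min_vec, int per_phase, int max_phase) {
  int phase = (row / per_phase) % max_phase; // phase cycles every per_phase rows
  int v = col / min_vec;                     // vector index within the row
  v = ((v / s) ^ phase) * s + (v % s);       // XOR-swizzle blocks of s vectors
  return row * ld + v * min_vec;             // back to an element offset
}

int main() {
  // two rows of a 64-wide fp16 tile, 8-element vectors, 8 phases
  printf("%d\n", swizzled_offset(0, 16, 64, /*s=*/1, 8, /*per_phase=*/1, 8));
  printf("%d\n", swizzled_offset(1, 16, 64, /*s=*/1, 8, /*per_phase=*/1, 8));
}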

View File

@@ -15,114 +15,105 @@ namespace triton {
 namespace codegen{
 namespace transform{
-bool membar::intersect(const interval_vec_t &X, interval_t x) {
-  return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
-    bool left_intersect = y.first <= x.first && x.first < y.second;
-    bool right_intersect = y.first <= x.second && x.second < y.second;
-    return left_intersect || right_intersect;
-  });
-}
-bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
-  return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
-    return intersect(X, y);
-  });
-}
-void membar::add_reference(ir::value *v, interval_vec_t &res){
-  auto *i = dynamic_cast<ir::instruction*>(v);
-  if(!i)
-    return;
-  if(!i->get_type()->is_tile_ty())
-    return;
-  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
-  if(!layout)
-    return;
-  if(alloc_->has_offset(layout)){
-    unsigned offset = alloc_->offset(layout);
-    res.push_back(interval_t(offset, offset + layout->get_size()));
-  }
-}
-void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
-  for(ir::value *op: i->ops())
-    add_reference(op, res);
-}
-void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
-  if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
-    add_reference(i, res);
-}
-void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
-  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
-    std::set<ir::value*> incoming;
-    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
-      ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
-      assert(inc_val);
-      if(incoming.insert(inc_val).second){
-        ir::basic_block *block = inc_val->get_parent();
-        builder.set_insert_point(block->get_inst_list().back());
-        if(type.first)
-          builder.create_async_wait();
-        if(type.second)
-          builder.create_barrier();
-      }
-    }
-  }
-  else{
-    builder.set_insert_point(instr);
-    builder.create_barrier();
-  }
-}
-membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
-  membar::interval_vec_t result;
-  for(auto x: intervals)
-    for(interval_t i: x)
-      result.push_back(i);
-  return result;
-}
+int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
+    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
+    analysis::double_buffer_info_t* info = layout->get_double_buffer();
+    if(info)
+      return group_of(info->first, async_write);
+    std::vector<int> groups(phi->get_num_operands());
+    std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+    return *std::max_element(groups.begin(), groups.end());
+  }
+  else{
+    auto it = std::find(async_write.begin(), async_write.end(), v);
+    return std::distance(async_write.begin(), it);
+  }
+}
+membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
+  val_set_t ret;
+  for(ir::value* a: as){
+    if(!a->get_type()->is_tile_ty())
+      continue;
+    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
+    if(!a_layout)
+      continue;
+    int a_start = alloc_->offset(a_layout);
+    int a_end = a_start + a_layout->get_size();
+    for(ir::value* b: bs){
+      if(!b->get_type()->is_tile_ty())
+        continue;
+      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
+      if(!b_layout)
+        continue;
+      int b_start = alloc_->offset(b_layout);
+      int b_end = b_start + b_layout->get_size();
+      if(a_start < b_end || b_start < a_end)
+        ret.insert(b);
+    }
+  }
+  return ret;
+}
-std::pair<membar::interval_vec_t,
-          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
-                                                   const interval_vec_t &written_to,
-                                                   const interval_vec_t &read_from,
-                                                   std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
-                                                   std::set<ir::value*>& safe_war,
-                                                   std::vector<ir::instruction*>& to_sync) {
+void membar::transfer(ir::basic_block *block,
+                      val_vec_t& async_write,
+                      val_set_t& sync_write,
+                      val_set_t& sync_read,
+                      std::set<ir::value*>& safe_war,
+                      bool& inserted, ir::builder& builder) {
   ir::basic_block::inst_list_t instructions = block->get_inst_list();
-  interval_vec_t new_written_to = written_to;
-  interval_vec_t new_read_from = read_from;
   for(ir::instruction *i: instructions){
-    interval_vec_t read, written;
-    get_read_intervals(i, read);
-    get_written_intervals(i, written);
-    if(written.size())
-      to_sync.push_back(i);
-    bool read_after_write = intersect(new_written_to, read);
-    bool write_after_read = intersect(new_read_from, written);
-    // double buffering
-    if(safe_war.find(i) != safe_war.end()){
-      write_after_read = false;
-      read_after_write = false;
-    }
-    // record hazards
-    if(read_after_write || write_after_read) {
-      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
-      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
-      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
-      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
-      insert_loc.insert({i, {copy_async_wait, barrier}});
-      new_written_to.clear();
-      new_read_from.clear();
-      to_sync.clear();
-    }
-    std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
-    std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
-  }
-  return std::make_pair(new_written_to, new_read_from);
-}
+    if(dynamic_cast<ir::phi_node*>(i))
+      continue;
+    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
+       dynamic_cast<ir::masked_load_async_inst*>(i)){
+      async_write.push_back(i);
+    }
+    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
+      sync_write.insert(i);
+    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
+    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
+    // Get shared memory reads
+    std::set<ir::value*> read;
+    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
+                 [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
+    // RAW (async)
+    val_set_t tmp;
+    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
+    if(intersect_with(read, tmp).size()){
+      std::vector<int> groups(read.size());
+      std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+      int N = *std::max_element(groups.begin(), groups.end());
+      if(N < async_write.size()){
+        builder.set_insert_point(i);
+        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
+        barrier = (ir::barrier_inst*)builder.create_barrier();
+        inserted = true;
+      }
+    }
+    // RAW, WAR
+    if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
+      builder.set_insert_point(i);
+      barrier = (ir::barrier_inst*)builder.create_barrier();
+      inserted = true;
+    }
+    // update state of asynchronous copies
+    if(async_wait){
+      int N = async_write.size() - async_wait->get_N();
+      async_write.erase(async_write.begin(), async_write.begin() + N);
+    }
+    // all the copy_to_shared and read from shared are synchronized after barrier
+    if(barrier){
+      sync_write.clear();
+      sync_read.clear();
+    }
+    sync_read.insert(read.begin(), read.end());
+  }
+}
 void membar::run(ir::module &mod) {
@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {
   for(ir::function *fn: mod.get_function_list()){
     std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
-    std::map<ir::basic_block*, interval_vec_t> written_to;
-    std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::vector<ir::instruction*> to_sync;
-    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
-    size_t n_inserted_im1 = 0;
-    bool done = false;
+    std::map<ir::basic_block*, val_vec_t> async_writes;
+    std::map<ir::basic_block*, val_set_t> sync_writes;
+    std::map<ir::basic_block*, val_set_t> sync_reads;
+    std::list<ir::value *> pipelined;
+    bool inserted;
     do{
+      inserted = false;
       // find barrier location
      for(ir::basic_block *block: rpo){
-        // written to
-        std::vector<interval_vec_t> pred_written_to;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_written_to.push_back(written_to[pred]);
-        // read from
-        std::vector<interval_vec_t> pred_read_from;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_read_from.push_back(read_from[pred]);
-        // apply transfer function
-        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
-        written_to[block] = result.first;
-        read_from[block] = result.second;
+        // join inputs
+        val_vec_t async_write;
+        val_set_t sync_write;
+        val_set_t sync_read;
+        val_set_t tmp;
+        for(ir::basic_block* pred: block->get_predecessors()){
+          for(ir::value* v: async_writes[pred])
+            if(tmp.insert(v).second)
+              async_write.push_back(v);
+          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
+          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
+        }
+        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
+        async_writes[block] = async_write;
+        sync_writes[block] = sync_write;
+        sync_reads[block] = sync_read;
      }
-      size_t n_inserted_i = insert_locs.size();
-      done = (n_inserted_im1 == n_inserted_i);
-      n_inserted_im1 = n_inserted_i;
-    }while(!done);
-    for(auto x: insert_locs){
-      insert_barrier(x.first, x.second, builder);
-    }
+    }while(inserted);
   }
 }
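
The key line above is create_async_wait(async_write.size() - 1 - N): if a read depends on the N-th oldest pending copy group, every group younger than it may keep running. A toy restatement with made-up tile names, not code from the pass:

#include <cassert>
#include <string>
#include <vector>

// Given the ordered list of pending async writes and the value a read
// depends on, return the argument passed to cp.async.wait_group.
int wait_arg(const std::vector<std::string>& async_write,
             const std::string& read_from) {
  int N = 0;
  for (int i = 0; i < (int)async_write.size(); i++)
    if (async_write[i] == read_from)
      N = i; // group index of the dependency
  return (int)async_write.size() - 1 - N;
}

int main() {
  std::vector<std::string> pending = {"A0", "B0", "A1"};
  assert(wait_arg(pending, "A0") == 2); // 2 younger groups may stay in flight
  assert(wait_arg(pending, "A1") == 0); // everything must drain
  return 0;
}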

View File

@@ -1,7 +1,9 @@
 #include <algorithm>
+#include <iostream>
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/codegen/transform/peephole.h"
+#include "triton/codegen/analysis/layout.h"
 namespace triton {
 namespace codegen{
@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
   ir::value *ptr = ld->get_pointer_operand();
   ir::value *msk = ld->get_mask_operand();
   ir::value *val = ld->get_false_value_operand();
-  ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
-  copy_to_shared->replace_all_uses_with(new_load);
-  return true;
-  // analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
-  // std::cout << layout->nts(layout->get_order(0)) << std::endl;
-  // return true;
+  analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+  int nts = layout->nts(layout->get_order()[0]);
+  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+  if(nts*dtsize >= 4){
+    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+    copy_to_shared->replace_all_uses_with(new_load);
+    return true;
+  }
+  return false;
 }
@@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) {
     bool was_modified = false;
     was_modified = was_modified || rewrite_mult(i, builder);
     // was_modified = was_modified || rewrite_cts_cfs(i, builder);
-    was_modified = was_modified || rewrite_trans_phi(i, builder);
+    // was_modified = was_modified || rewrite_trans_phi(i, builder);
     was_modified = was_modified || rewrite_unit_red(i, builder);
     was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
-    // if(tgt_->as_nvidia()->sm() >= 80)
-    //   was_modified = was_modified || rewrite_load_to_shared(i, builder);
+    if(tgt_->as_nvidia()->sm() >= 80)
+      was_modified = was_modified || rewrite_load_to_shared(i, builder);
     if(was_modified)
       seen.insert(i);
   }
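
The new nts*dtsize >= 4 gate matches the transfer sizes cp.async supports (4, 8, or 16 bytes per operation). A hypothetical helper restating it, not code from the pass:

#include <cstdio>

// Restates the gate in rewrite_load_to_shared above: nts contiguous
// elements of dtsize bytes must form at least a 4-byte transaction.
bool can_use_cp_async(int nts, int dtsize) {
  return nts * dtsize >= 4;
}

int main() {
  printf("%d\n", can_use_cp_async(8, 2)); // 8 x fp16 = 16 bytes -> rewrite
  printf("%d\n", can_use_cp_async(1, 2)); // 1 x fp16 = 2 bytes  -> keep copy_to_shared
}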

View File

@@ -0,0 +1,116 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return;
if(i->get_id()==ir::INST_PHI)
return;
ret.push_back(i);
for(ir::user* u: i->get_users())
recursive_deps(u, block, ret);
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
to_pipeline.push_back({load, ptr});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
false_value = masked_load->get_false_value_operand();
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
// try to move dot_inst after loads
// for better overlap of io and compute
struct move_config_t{
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::map<ir::basic_block*, move_config_t> to_move;
if(has_copy_async_){
for(ir::function* fn: mod.get_function_list())
for(ir::basic_block* bb: fn->blocks())
for(ir::instruction* inst: bb->get_inst_list()){
if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
recursive_deps(i, bb, to_move[bb].insts);
if(auto* i = dynamic_cast<ir::load_inst*>(inst))
to_move[bb].dst = i;
}
for(auto& x: to_move){
builder.set_insert_point_after(x.second.dst);
for(ir::instruction* i: x.second.insts){
x.first->erase(i);
builder.insert(i);
}
}
}
}
}
}
}
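
For intuition, here is the same transformation applied to an ordinary C++ reduction loop — a rough analogy only, since the pass rewrites Triton-IR loads and phi nodes: the first load is hoisted into the header, the next iteration's load is issued before the current compute, and a rotating variable plays the role of the phi node.

#include <cstdio>
#include <vector>

int dot_pipelined(const std::vector<int>& a, const std::vector<int>& b) {
  int acc = 0;
  size_t n = a.size();
  if (n == 0) return acc;
  int a_cur = a[0], b_cur = b[0];            // "header" pre-fetch of iteration 0
  for (size_t i = 0; i < n; i++) {
    int a_nxt = 0, b_nxt = 0;                // masked load of iteration i+1
    if (i + 1 < n) { a_nxt = a[i + 1]; b_nxt = b[i + 1]; }
    acc += a_cur * b_cur;                    // compute overlaps the next load
    a_cur = a_nxt; b_cur = b_nxt;            // the rotating "phi" values
  }
  return acc;
}

int main() {
  std::vector<int> a = {1, 2, 3}, b = {4, 5, 6};
  printf("%d\n", dot_pipelined(a, b)); // 1*4 + 2*5 + 3*6 = 32
}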

View File

@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
 inline bool is_cst(ir::value *x) {
   if(dynamic_cast<ir::constant*>(x))
     return true;
+  if(dynamic_cast<ir::make_range*>(x))
+    return true;
   if(auto *v = dynamic_cast<ir::retile_inst*>(x))
     return is_cst(v->get_operand(0));
   return false;

View File

@@ -70,7 +70,21 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
 cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
   dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
-  // dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  // properties
+  int shared_total, shared_optin, shared_static;
+  int n_spills, n_reg;
+  CUdevice dev;
+  dispatch::cuCtxGetDevice(&dev);
+  dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
+  dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
+  dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
+  if(shared_optin > 49152){
+    // std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
+    dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
+  }
 }
 }
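
The 49152 constant is the default 48 KB per-block shared-memory limit; anything beyond it must be opted into via CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, as done above. A sketch of the byte budget, with an assumed A100-like opt-in maximum:

#include <cstdio>

int main() {
  int shared_optin  = 166912; // assumed per-block opt-in maximum (A100: 163 KB)
  int shared_static = 2048;   // example static allocation of the kernel
  if (shared_optin > 49152)   // 48 KB default limit
    printf("opt in to %d bytes of dynamic shared memory\n",
           shared_optin - shared_static);
}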

View File

@@ -282,6 +282,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 void cu_module::init_from_ptx(const std::string& ptx) {
   // JIT compile source-code
+  // std::cout << ptx << std::endl;
   try{
     // // compile ptx with ptxas

View File

@@ -76,7 +76,7 @@ void host_stream::synchronize() {
   hst_->args.clear();
 }
-void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
+void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t) {
   auto hst = kernel->module()->hst();
   hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
   char* params = new char[args_size];
@@ -113,13 +113,13 @@ void cu_stream::synchronize() {
   dispatch::cuStreamSynchronize(*cu_);
 }
-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem) {
   void *config[] = {
     CU_LAUNCH_PARAM_BUFFER_POINTER, args,
     CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
     CU_LAUNCH_PARAM_END
   };
-  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
+  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config);
 }
 void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {

View File

@@ -45,6 +45,9 @@ void builder::set_insert_point(basic_block *block){
 // convenience functions
 //===----------------------------------------------------------------------===//
+value *builder::get_int1(bool val)
+{ return constant_int::get(type::get_int1_ty(ctx_), val); }
 value *builder::get_int32(int32_t val)
 { return constant_int::get(type::get_int32_ty(ctx_), val);}
@@ -372,8 +375,8 @@ value *builder::create_barrier(const std::string &name) {
   return insert(barrier_inst::create(ctx_, name));
 }
-value *builder::create_async_wait() {
-  return insert(async_wait_inst::create(ctx_));
+value *builder::create_async_wait(int N) {
+  return insert(async_wait_inst::create(ctx_, N));
 }
} }

View File

@@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins
   blocks_.reserve(num_reserved);
 }
+value* phi_node::get_value_for_block(basic_block * block) {
+  auto it = std::find(blocks_.begin(), blocks_.end(), block);
+  size_t n = std::distance(blocks_.begin(), it);
+  return get_incoming_value(n);
+}
 // Set incoming value
 void phi_node::set_incoming_value(unsigned i, value *v){
   assert(v && "PHI node got a null value!");
@@ -818,12 +824,11 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
   return new barrier_inst(ctx, name, next);
 }
-async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
-                                 instruction *next)
-  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
+async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
+  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
-async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
-  return new async_wait_inst(ctx, name, next);
+async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
+  return new async_wait_inst(ctx, N, name, next);
 }

View File

@@ -15,10 +15,10 @@
 #include "triton/codegen/transform/peephole.h"
 #include "triton/codegen/transform/membar.h"
 #include "triton/codegen/transform/reassociate.h"
-#include "triton/codegen/transform/reorder.h"
 #include "triton/codegen/transform/cts.h"
 #include "triton/codegen/transform/disassociate.h"
 #include "triton/codegen/selection/generator.h"
+#include "triton/codegen/transform/pipeline.h"
 #include "triton/runtime/function.h"
 #include "triton/lang/cpp.h"
 #include "triton/lang/parser.h"
@@ -149,6 +149,7 @@ void kernel::init_ker(){
   codegen::analysis::align align;
   codegen::analysis::axes axes;
   codegen::transform::cts cts(cts_use_async);
+  codegen::transform::pipeline pipeline(cts_use_async);
   codegen::transform::disassociate disassociate;
   codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
   codegen::analysis::liveness liveness(&layouts);
@@ -156,19 +157,24 @@ void kernel::init_ker(){
   codegen::analysis::allocation allocation(&liveness);
   codegen::transform::membar barriers(&liveness, &layouts, &allocation);
   codegen::transform::dce dce;
-  codegen::transform::peephole peephole(target.get());
+  codegen::transform::peephole peephole(target.get(), &layouts);
   codegen::transform::reassociate reassociate;
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
   // run passes
   dce.run(*ir_);
+  pipeline.run(*ir_);
+  dce.run(*ir_);
   disassociate.run(*ir_);
   dce.run(*ir_);
+  align.run(*ir_);
+  axes.run(*ir_);
+  layouts.run(*ir_);
   peephole.run(*ir_);
   dce.run(*ir_);
-  align.run(*ir_);
   if(target->is_gpu())
     cts.run(*ir_);
+  align.run(*ir_);
   axes.run(*ir_);
   layouts.run(*ir_);
   coalesce.run(*ir_);
@@ -179,6 +185,11 @@ void kernel::init_ker(){
     reassociate.run(*ir_);
     cts.run(*ir_);
   }
+  dce.run(*ir_);
+  // ir::print(*ir_, std::cout);
+  align.run(*ir_);
+  axes.run(*ir_);
+  layouts.run(*ir_);
   peephole.run(*ir_);
   dce.run(*ir_);
   align.run(*ir_);
@@ -187,8 +198,9 @@ void kernel::init_ker(){
   swizzle.run(*ir_);
   liveness.run(*ir_);
   allocation.run(*ir_);
-  if(allocation.allocated_size() > dev_->max_shared_memory())
-    throw exception::out_of_shared_memory();
+  shared_mem_ = allocation.allocated_size();
+  // if(allocation.allocated_size() > dev_->max_shared_memory())
+  //   throw exception::out_of_shared_memory();
   barriers.run(*ir_);
   isel.visit(*ir_, *llvm);
   //if(res->spilled() > 256)
@@ -224,7 +236,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
   for(size_t i = 0; i < 3; i++)
     grid[i] = (i < _grid.size()) ? _grid[i] : 1;
   // enqueue
-  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
+  stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
 }
 std::string kernel::get_asm(asm_mode_t mode) {
@@ -348,7 +360,7 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
   while(grid.size() < 3)
     grid.push_back(1);
   double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
-                           stream, true);
+                           stream, 5, 20);
   ret = (ts < best_ts) ? current : ret;
   best_ts = std::min(ts, best_ts);
 }

View File

@@ -2,58 +2,74 @@ import triton
 import torch
 # square benchmarks
-nt = {False: 'n', True: 't'}
+nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
-        x_names = ['M', 'N', 'K'],
-        x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
-        y_name = 'provider',
-        y_vals = ['torch', 'triton', 'cutlass'],
-        y_lines = ['Torch', 'Triton', 'CUTLASS'],
-        ylabel = 'TFLOPS',
+        x_names=["M", "N", "K"],
+        x_vals=[512 * i for i in range(1, 16)],
+        y_name="provider",
+        y_vals=["torch", "triton", "cutlass"],
+        y_lines=["Torch", "Triton", "CUTLASS"],
+        ylabel="TFLOPS",
         loglog=False,
-        plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
-        args = {'AT': False, 'BT': False, 'dtype': torch.float16}
-    )\
-    for AT in [False, True] for BT in [False, True]
+        plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
+        args={"AT": AT, "BT": BT, "dtype": torch.float16},
+    ) for AT in [False, True] for BT in [False, True]
 ]
 @triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
     import os
-    a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
-    b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
-    if AT: a = a.t()
-    if BT: b = b.t()
+    a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
+    b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
+    if AT:
+        a = a.t()
+    if BT:
+        b = b.t()
     num_flops = 2 * M * N * K
-    if provider == 'torch':
+    if provider == "torch":
         torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
         torch_tflops = num_flops / torch_ms * 1e-9
         return torch_tflops
-    if provider == 'triton':
+    if provider == "triton":
         triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
         triton_tflops = num_flops / triton_ms * 1e-9
         return triton_tflops
-    if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
+    if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
         import subprocess
         import tempfile
         import pandas as pd
         # run program specified by CUTLASS_PROFILER env variable
-        layout_a = 'column' if AT else 'row'
-        layout_b = 'column' if BT else 'row'
+        layout_a = "column" if AT else "row"
+        layout_b = "column" if BT else "row"
         # create temporary file name
         fd, fname = tempfile.mkstemp()
         # run program and gets its output
-        cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
-               '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
-               f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
+        cmd = [
+            os.environ["CUTLASS_PROFILER"],
+            f"--m={M}",
+            f"--n={N}",
+            f"--k={K}",
+            f"--A=f16:{layout_a}",
+            f"--B=f16:{layout_b}",
+            "--C=f16:column",
+            "--accum=f32",
+            "--operation=gemm",
+            "--verification-enabled=false",
+            f"--warmup-iterations={warmup}",
+            f"--profiling-iterations={rep}",
+            f"--output={fname}",
+            "--verbose=false",
+        ]
         # run cmd
         subprocess.run(cmd, stdout=subprocess.PIPE)
         # read CSV output
-        df_c = pd.read_csv(f'{fname}.gemm.csv')
-        cutlass_tflops = max(df_c['GFLOPs']) / 1e3
+        df_c = pd.read_csv(f"{fname}.gemm.csv")
+        cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
         return cutlass_tflops
     return None
-if __name__ == '__main__':
+if __name__ == "__main__":
     bench_op.run()

View File

@@ -15,21 +15,21 @@ import distutils.spawn
import torch import torch
def find_llvm(): def find_llvm():
versions = ['-10', '-10.0', ''] versions = ["-10", "-10.0", ""]
supported = ['llvm-config{v}'.format(v=v) for v in versions] supported = ["llvm-config{v}".format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None] paths = [p for p in paths if p is not None]
if paths: if paths:
return paths[0] return paths[0]
config = distutils.spawn.find_executable('llvm-config') config = distutils.spawn.find_executable("llvm-config")
instructions = 'Please install llvm-10-dev' instructions = "Please install llvm-10-dev"
if config is None: if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions) raise RuntimeError("Could not find llvm-config. " + instructions)
version = os.popen('{config} --version'.format(config=config)).read() version = os.popen("{config} --version".format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions) raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
class CMakeExtension(Extension): class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''): def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[]) Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir) self.sourcedir = os.path.abspath(sourcedir)
self.path = path self.path = path
@@ -37,14 +37,14 @@ class CMakeExtension(Extension):
class CMakeBuild(build_ext): class CMakeBuild(build_ext):
def run(self): def run(self):
try: try:
out = subprocess.check_output(['cmake', '--version']) out = subprocess.check_output(["cmake", "--version"])
except OSError: except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " + raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions)) ", ".join(e.name for e in self.extensions))
if platform.system() == "Windows": if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1)) cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
if cmake_version < '3.1.0': if cmake_version < "3.1.0":
raise RuntimeError("CMake >= 3.1.0 is required on Windows") raise RuntimeError("CMake >= 3.1.0 is required on Windows")
for ext in self.extensions: for ext in self.extensions:
@@ -52,69 +52,69 @@ class CMakeBuild(build_ext):
    def build_extension(self, ext):
        # self.debug = True
        self.debug = False
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        # python directories
        python_include_dirs = distutils.sysconfig.get_python_inc()
        python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
        torch_include_dirs = include_paths(True)
        torch_library_dirs = library_paths(True)
        cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
        cmake_args = [
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
            "-DBUILD_TUTORIALS=OFF",
            "-DBUILD_PYTHON_MODULE=ON",
            # '-DPYTHON_EXECUTABLE=' + sys.executable,
            # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
            "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
            "-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
            "-DTORCH_CXX11_ABI=" + cxx11abi,
            "-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
            "-DLLVM_CONFIG=" + find_llvm(),
        ]
        # configuration
        cfg = "Debug" if self.debug else "Release"
        build_args = ["--config", cfg]
        if platform.system() == "Windows":
            cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
            if sys.maxsize > 2**32:
                cmake_args += ["-A", "x64"]
            build_args += ["--", "/m"]
        else:
            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
            build_args += ["--", "-j4"]
        env = os.environ.copy()
        if not os.path.exists(self.build_temp):
            os.makedirs(self.build_temp)
        sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
        subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
setup(
    name="triton",
    version="1.0.0",
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description="",
    packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
    install_requires=["numpy", "torch"],
    package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
    include_package_data=True,
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
    cmdclass={"build_ext": CMakeBuild},
    zip_safe=False,
    # for PyPI
    keywords=["Compiler", "Deep Learning"],
    url="https://github.com/ptillet/triton/",
    download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
    classifiers=[
        "Development Status :: 3 - Alpha",  # choose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
        "Intended Audience :: Developers",  # define that your audience is developers
        "Topic :: Software Development :: Build Tools",
        "License :: OSI Approved :: MIT License",  # again, pick a license
        "Programming Language :: Python :: 3.6",
    ],
)

View File

@@ -2,29 +2,17 @@ import torch
import triton
import pytest


@pytest.mark.parametrize(
    "MODE, TRANS_A, TRANS_B, BLOCK",
    [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
     for block in [16, 32, 64]],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
    # set seed
    torch.random.manual_seed(0)
    # create inputs
    a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
    b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
    shape = {
        "sdd": (M, N),
        "dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
    # triton result
    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
    ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
    rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
    # compare
    assert triton.testing.allclose(rc, tc)


@pytest.mark.parametrize(
    "BLOCK, WIDTH",
    [(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
    at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
    kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
    kp_mask[kp_mask == 1.0] = float("-inf")
    # triton result
    op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
    # compare
    assert triton.testing.allclose(ry, ty)


def test_attention_fwd_bwd(
    input_scale=1.0,
    tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
    # inputs
    qkv_shape = (batch_size, n_heads, n_ctx, 64)
    qkvs = [
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
    ]
    attn_mask = torch.tril(
        torch.ones(
@@ -129,9 +106,7 @@ def test_attention_fwd_bwd(
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
    attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
    # ad hoc loss
    loss = (attn_out**2).mean()
    loss.backward()
@@ -153,12 +128,11 @@ def test_attention_fwd_bwd(
    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
    # comparison
    # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
    torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
    for g1, g2 in zip(grads, torch_grads):
        torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)


def triton_attention(
    layout,
    block: int,
@@ -168,12 +142,8 @@ def triton_attention(
    value: torch.Tensor,
    scale: float,
):
    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
    sparse_softmax = triton.ops.blocksparse.softmax(
        layout,
        block,

View File

@@ -4,7 +4,7 @@ import triton
import torch


@pytest.mark.parametrize(
    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
    itertools.chain(*[
        [
            # 1 warp
@@ -17,14 +17,14 @@ import torch
            (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
            # 2 warp
            (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
            (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
            # 4 warp
            (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
            (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,22 +40,26 @@ import torch
            (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
            (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
            # variable input
            (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
            (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
        ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
    ]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    torch.manual_seed(0)
    triton.ops._matmul._kernels = dict()
    triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
    if M is None:
        M = TM
    if N is None:
        N = TN
    if K is None:
        K = TK * SPLITK
    a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
    b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
    a = a.t() if AT else a
    b = b.t() if BT else b
    th_c = torch.matmul(a, b)

View File

@@ -186,7 +186,8 @@
    else {
        int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
        int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
        for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
            ;
        int count = *pcount;
        if (count == 0)
            *? (checkc)pc = c;

View File

@@ -98,8 +98,7 @@ class _matmul(torch.autograd.Function):
        return luts, None, widths, packs

    @staticmethod
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
        if trans_c:
            a, b = b, a
@@ -126,17 +125,12 @@ class _matmul(torch.autograd.Function):
        num_lock = 1
        key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
        if key not in _matmul.sdd_cache:
            defines = {
                'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack,
                'BLOCK': block, 'TK': 32, 'TYPE': dtype,
                'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
                'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb',
                'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
            }
            _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
        kernel = _matmul.sdd_cache[key]
@@ -147,11 +141,9 @@ class _matmul(torch.autograd.Function):
        # kernel calls
        max_width = 49152
        for off_width in range(0, width, max_width):
            kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
                   b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
                   lut.data_ptr(), locks.data_ptr(), num_lock,
                   grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
        # save for backward pass
        return c
@@ -252,8 +244,7 @@ class _matmul(torch.autograd.Function):
        return lut, num_locks, width, None

    @staticmethod
    def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = a.size(0)
        AS1 = a.size(1)
@@ -266,19 +257,12 @@ class _matmul(torch.autograd.Function):
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dds_cache:
            defines = {
                'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype,
                'STRIDE_AM': 1 if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else 1,
                'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK': 1 if trans_b else block,
                'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
                'NAME': 'dds_kernel', 'DDS': True
            }
            _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
        kernel = _matmul.dds_cache[key]
        # output
@@ -288,17 +272,13 @@ class _matmul(torch.autograd.Function):
        CS3 = AS2 if trans_c else BS2
        locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
               c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
               num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
        return c
    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
@@ -311,19 +291,12 @@ class _matmul(torch.autograd.Function):
        # kernel
        key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
        if key not in _matmul.dsd_cache:
            defines = {
                'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype,
                'STRIDE_AM': 1 if trans_a else block, 'STRIDE_AK': block if trans_a else 1,
                'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb',
                'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
                'NAME': 'dsd_kernel', 'DSD': True
            }
            _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
        kernel = _matmul.dsd_cache[key]
        # output
@@ -333,26 +306,17 @@ class _matmul(torch.autograd.Function):
        CS3 = AS1 if trans_c else BS3
        locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
               c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
               num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
        return c
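    # note: the first letter of a mode names the sparse operand of C = A @ B
    # ('sdd' = sparse C with dense A/B, 'dsd' = sparse A, 'dds' = sparse B);
    # the backward pass below reuses the same kernels with the letters permuted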
    fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}

    @staticmethod
    def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
                da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
        c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_num_locks = da_num_locks
@@ -378,13 +342,13 @@ class _matmul(torch.autograd.Function):
        # gradients w.r.t. a
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
                                     ctx.da_num_locks, ctx.da_width, ctx.da_packs)
        # gradients w.r.t. b
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
                                     ctx.db_num_locks, ctx.db_width, ctx.db_packs)
        return da, db, None, None, None,\
            None, None, None, None,\
            None, None, None, None, None, None,\
@@ -392,7 +356,6 @@ class _matmul(torch.autograd.Function):
            None, None, None, None, None, None


class matmul:
    def make_lut(self, dtype, device):
        key = (dtype, device)
        if key in self.lut_cache:
@@ -412,7 +375,8 @@ class matmul:
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b,
                                                                            device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
@@ -455,9 +419,7 @@ class matmul:
        a = matmul._pad_shape(a, self.mode == 'dsd')
        b = matmul._pad_shape(b, self.mode == 'dds')
        # execute
        c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
                          c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut,
                          db_num_locks, db_width, db_packs)
        return c

View File

@@ -33,7 +33,7 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
    int rn[TN] = pidn * TN + 0 ... TN;
    // split-k for better parallelism
    K = K / SPLITK;
    int rk[TK] = 0 ... TK;
    // pointers to operands
    int offa[TM, TK] = (pidz * K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
@@ -79,19 +79,20 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
    int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
    TYPE *pc[TM, TN] = C + offc;
    bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
#if (SPLITK == 1)
    *? (checkc)pc = c;
#else
    // accumulate partial result using spin-locks
    int *plock = locks + pid;
    int *pcount = plock + get_num_programs(0);
    for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
        ;
    int count = *pcount;
    if (count == 0)
        *? (checkc)pc = c;
    else
        *? (checkc)pc = c + *? (checkc)pc;
    atomic_xchg(pcount, (count + 1) % SPLITK);
    atomic_xchg(plock, 0);
#endif
}
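A note on the epilogue above: with SPLITK > 1, each program along the third grid axis reduces a slice of K, and the atomic_cas spin-lock plus the pcount counter serialize the read-modify-write of the shared output tile; atomic_xchg(pcount, (count + 1) % SPLITK) resets the counter once all partials have landed. A minimal sketch of the same reduction in plain PyTorch (a sequential stand-in for the concurrent programs; names are illustrative):

import torch

def splitk_matmul(a, b, splitk=4):
    # each iteration plays the role of one pidz program in the kernel above;
    # on the GPU these run concurrently and the spin-lock orders their updates
    M, K = a.shape
    _, N = b.shape
    assert K % splitk == 0
    chunk = K // splitk
    c = torch.zeros(M, N, dtype=torch.float32)
    for pidz in range(splitk):
        lo, hi = pidz * chunk, (pidz + 1) * chunk
        c += a[:, lo:hi].float() @ b[lo:hi, :].float()
    return c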

View File

@@ -3,29 +3,32 @@ import triton
import os


class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))

    _DEFAULT_CONFIGS = [
        ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "128", "TN": "64", "TK": "32", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "64", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "32", "TN": "128", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "128", "TN": "32", "TK": "64", "SPLITK": "1"}, 4),
        ({"TM": "64", "TN": "32", "TK": "64", "SPLITK": "1"}, 2),
        ({"TM": "32", "TN": "64", "TK": "64", "SPLITK": "1"}, 2),
        # ({"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, 4),
        # ({"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, 4),
        # ({"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, 4),
        # ({"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, 4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS

    @staticmethod
    def largest_pow2_divisor(N):
        if N % 8 == 0:
            return 8
        if N % 4 == 0:
            return 4
        if N % 2 == 0:
            return 2
        return 1

    _locks = dict()
@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
        K, N = b.shape
        c = torch.empty((M, N), dtype=dtype, device=device)
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 64 == 0
        key = (
            device,
            dtype,
            is_a_row,
            is_b_row,
            lda_pow2_div,
            ldb_pow2_div,
            ldc_pow2_div,
            is_tk_div_k,
        )
        if key not in _matmul._kernels:
            defines = {
                "TYPE": dtype,
                "STRIDE_AM": "lda" if is_a_row else "1",
                "STRIDE_AK": "1" if is_a_row else "lda",
                "STRIDE_BK": "ldb" if is_b_row else "1",
                "STRIDE_BN": "1" if is_b_row else "ldb",
                "LDA_POW2_DIV": lda_pow2_div,
                "LDB_POW2_DIV": ldb_pow2_div,
                "LDC_POW2_DIV": ldc_pow2_div,
                "IS_TK_DIV_K": int(is_tk_div_k),
            }
            _matmul._kernels[key] = triton.kernel(
                _matmul.src,
                device,
                defines=defines,
                autotune_vals=_matmul._CONFIGS,
                autotune_key=["M", "N", "K"],
            )
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.0
        args = [
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            alpha,
            M,
            N,
            K,
            lda,
            ldb,
            ldc,
            locks.data_ptr(),
        ]
        grid = lambda opt: [
            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
            1,
            opt.SPLITK,
        ]
        kernel(*args, grid=grid)
        return c
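For context, a usage sketch of the wrapper above (hedged: it assumes triton.ops.matmul is exported as the user-facing entry point over _matmul, as the tests in this commit suggest). The autotuner benchmarks _matmul._CONFIGS the first time a given (M, N, K) signature is seen and caches the winning kernel under the key built above:

import torch
import triton

a = torch.randn(512, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 512, device="cuda", dtype=torch.float16)
# first call autotunes over _CONFIGS (keyed on M, N, K); later calls hit the cache.
# a config with SPLITK > 1 launches opt.SPLITK programs on the third grid axis
c = triton.ops.matmul(a, b)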

View File

@@ -1,21 +1,33 @@
import torch


def sparsify_tensor(x, mask, block):
    ret = torch.empty(
        (x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device
    )
    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        ret[:, idx, :, :] = x[
            :, h, i * block : (i + 1) * block, j * block : (j + 1) * block
        ]
    return ret


def mask_tensor(x, mask, block, value=0):
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
    return ret


def allclose(x, y):
    assert x.dtype == y.dtype
    diff = abs(x - y)
    x_max = torch.max(x)
    y_max = torch.max(y)
    tol = 1e-2
    err = torch.max(diff) / torch.max(x_max, y_max)
    return err < tol
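The rewritten allclose above drops the per-dtype (rtol, atol) table in favor of a single criterion: the maximum absolute difference, relative to the larger of the two maxima, must stay under 1%. A quick worked check (values illustrative; note the metric implicitly assumes the inputs' maxima are positive):

import torch

x = torch.tensor([100.0, 200.0])
y = torch.tensor([100.5, 200.0])
# max|x - y| = 0.5, max(x.max(), y.max()) = 200.0 -> err = 0.0025 < 1e-2
assert allclose(x, y)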
def do_bench(fn, flops=0, warmup=10, rep=50):
    start_event = torch.cuda.Event(enable_timing=True)
@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
    time_ms = start_event.elapsed_time(end_event) / rep
    return time_ms


class Benchmark:
    def __init__(
        self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args
    ):
        self.x_names = x_names
        self.x_vals = x_vals
        self.y_name = y_name
@@ -44,6 +59,7 @@ class Benchmark:
        self.plot_name = plot_name
        self.args = args


class Mark:
    def __init__(self, fn, benchmarks):
        self.fn = fn
@@ -53,26 +69,31 @@ class Mark:
        import matplotlib.pyplot as plt
        import pandas as pd
        import os

        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
        for x in bench.x_vals:
            x_args = {x_name: x for x_name in bench.x_names}
            row = [
                self.fn(**x_args, **{bench.y_name: y}, **bench.args)
                for y in bench.y_vals
            ]
            df.loc[len(df)] = [x] + row
        if with_plot and bench.plot_name:
            xlabel = " = ".join(bench.x_names)
            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
            plot.set_xlabel(xlabel)
            plot.set_ylabel(bench.ylabel)
            plot.set_title(bench.plot_name)
            plot.set_xscale("log" if bench.loglog else "linear")
            plot.set_yscale("log" if bench.loglog else "linear")
            plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))

    def run(self, result_path, with_plot):
        for bench in self.benchmarks:
            self._run(bench, result_path, with_plot)


def perf_report(benchmarks):
    wrapper = lambda fn: Mark(fn, benchmarks)
    return wrapper
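A hedged usage sketch of the perf_report / Benchmark API above (benchmark name, sizes, and the results directory are illustrative; the directory must already exist):

import torch
import triton
from triton.testing import Benchmark, perf_report, do_bench

@perf_report([
    Benchmark(
        x_names=["M"], x_vals=[256, 512, 1024, 2048],
        y_name="provider", y_vals=["torch", "triton"], y_lines=["torch", "triton"],
        ylabel="time (ms)", loglog=True, plot_name="matmul_square", args={},
    )
])
def bench_matmul(M, provider):
    a = torch.randn(M, M, device="cuda", dtype=torch.float16)
    b = torch.randn(M, M, device="cuda", dtype=torch.float16)
    # one line per provider; do_bench averages the kernel time over `rep` runs
    fn = (lambda: torch.matmul(a, b)) if provider == "torch" else (lambda: triton.ops.matmul(a, b))
    return do_bench(fn)

bench_matmul.run("./results", with_plot=True)  # writes matmul_square.png / .csv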

View File

@@ -66,18 +66,19 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
    bool checkb[TK, TN] = rk[:, newaxis] < K;
    TYPE a[TM, TK] = checka ? *pa : 0;
    TYPE b[TK, TN] = checkb ? *pb : 0;
    // reduction loop
    float acc[TM, TN] = 0;
    for (int k = K; k > 0; k -= TK) {
        bool checka[TM, TK] = k > TK;
        bool checkb[TK, TN] = k > TK;
        pa += TK * STRIDE_AK;
        pb += TK * STRIDE_BK;
        TYPE anext[TM, TK] = *?(checka)pa;
        TYPE bnext[TK, TN] = *?(checkb)pb;
        acc += a @ b;
        a = anext;
        b = bnext;
        // __debug_barrier();
    }
    acc = acc * alpha;
    TYPE c[TM, TN] = acc;
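The reordered loop above is the classic software-pipelining pattern: the loads that feed iteration k+1 are issued before the `a @ b` of iteration k, so the copies (now asynchronous on A100) overlap the tensor-core math instead of stalling it. The same idea in a minimal Python sketch (load_a/load_b stand in for the masked tile loads; illustrative only):

def pipelined_dot(load_a, load_b, num_tiles):
    acc = 0
    a, b = load_a(0), load_b(0)      # prologue: fetch the first tiles
    for k in range(num_tiles):
        prefetch = k + 1 < num_tiles
        if prefetch:                 # issue the next loads first...
            a_next, b_next = load_a(k + 1), load_b(k + 1)
        acc += a * b                 # ...so they overlap the compute
        if prefetch:
            a, b = a_next, b_next
    return acc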
@@ -166,7 +167,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
opt.defines["TYPE"] = ty; opt.defines["TYPE"] = ty;
opt.defines["TM"] = "128"; opt.defines["TM"] = "128";
opt.defines["TN"] = "128"; opt.defines["TN"] = "128";
opt.defines["TK"] = "32" ; opt.defines["TK"] = "64" ;
opt.defines["TZ"] = "1"; opt.defines["TZ"] = "1";
opt.num_warps = 4; opt.num_warps = 4;
// arguments // arguments