[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copy, scheduling and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks
This commit is contained in:
Philippe Tillet
2021-02-21 15:19:39 -08:00
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions

View File

@@ -2,6 +2,9 @@
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
namespace triton {
@@ -9,6 +12,7 @@ namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
@@ -29,18 +33,15 @@ namespace transform{
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
bool intersect(const interval_vec_t &X, interval_t x);
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
bool intersect(const val_set_t &X, const val_set_t &Y);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -16,6 +16,10 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
@@ -33,11 +37,12 @@ private:
private:
public:
peephole(target* tgt): tgt_(tgt) {}
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};

View File

@@ -0,0 +1,28 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
void run(ir::module &module);
private:
bool has_copy_async_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@@ -29,7 +29,7 @@ public:
static driver::stream* create(backend_t backend);
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
public:
host_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -55,7 +55,7 @@ public:
cu_stream(CUstream str, bool take_ownership);
cu_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};

View File

@@ -35,6 +35,7 @@ public:
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
// Types
@@ -149,7 +150,7 @@ public:
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
value *create_async_wait();
value *create_async_wait(int N);
private:
context &ctx_;

View File

@@ -92,6 +92,7 @@ private:
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_value_for_block(basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
@@ -803,14 +804,18 @@ public:
class async_wait_inst: public instruction{
private:
async_wait_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait"; }
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
_TRITON_DEFINE_CLONE(async_wait_inst)
_TRITON_DEFINE_ACCEPT(async_wait_inst)
public:
static async_wait_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
private:
int N_;
};
// On NVIDIA, implementation is such that

View File

@@ -98,6 +98,8 @@ private:
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
std::shared_ptr<driver::kernel> ker_;
// shared mem
size_t shared_mem_;
};
class function {

View File

@@ -30,11 +30,8 @@ private:
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
{
// const driver::device * device = stream->context()->device();
size_t warmup = 10;
size_t repeat = 50;
timer tmr;
std::vector<size_t> times;
double total_time = 0;

View File

@@ -312,7 +312,6 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
// std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl;
}
return add_to_cache(x, result, max_contiguous_);
}
@@ -527,8 +526,7 @@ void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
}

View File

@@ -118,15 +118,6 @@ data_layout::data_layout(id_t id,
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// std::cout << order_[0] << " " << order_[1] << std::endl;
}
if(is_recoalesce){
if(ptr.size() > 0){
// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// if(order_[0] == 0)
// exit(1);
}
}
// std::cout << "---" << std::endl;
}
int data_layout::find_axis(int to_find) const {
@@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps,
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::io_inst*>(usr))
ptr = st->get_pointer_operand();
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
ptr = io->get_pointer_operand();
}
unsigned i = order_[0];
int contiguous = 1;
if(ptr){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));

View File

@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
}
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
unsigned vector = 1;
ir::value *ptrs = x->get_pointer_operand();
ir::value *msks = x->get_mask_operand();
unsigned in_vec = 1;
ir::value *arg = x->get_pointer_operand();
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
vector = in_layout->nts(in_order[0]);
in_vec = in_layout->nts(in_order[0]);
int out_vec = swizzle_->get_vec(out_layout);
int min_vec = std::min<int>(out_vec, in_vec);
int s = std::max<int>(out_vec / in_vec, 1);
//
int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
Value *max_phase = i32(8 / num_per_phase);
int per_phase = swizzle_->get_per_phase(out_layout);
int max_phase = swizzle_->get_max_phase(out_layout);
//
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
auto shapes = x->get_type()->get_tile_shapes();
//
int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
std::vector<Value*> shared;
for(size_t i = 0; i < n_shared; i++){
indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
// phase
Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
phase = urem(phase, max_phase);
// off
Value* off_0 = idx[in_order[0]];
off_0 = udiv(off_0, i32(vector));
off_0 = xor_(off_0, phase);
off_0 = mul(off_0 , i32(vector));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off = add(off_0, off_1);
//
shared.push_back(gep(shmems_[x], {off}));
}
//
for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
auto idx = idxs_[ptrs][i];
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
std::map<std::pair<int, int>, Value*> tmp;
std::vector<std::pair<Value*, int>> shared;
for(int i = 0; i < idxs_.at(arg).size(); i++){
unsigned id = i / min_vec;
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
Value *in_base = in_gep->getPointerOperand();
size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
Value* out_base = shared[(i / per_thread_ld) % n_shared];
int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
int out_off_1 = i % per_thread_ld;
int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
// asm
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
call(iasm, {vals_[msks][idx], out_base, in_base});
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
int off = (off_1*shapes[in_order[0]] + off_0);
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(tmp.find(key) == tmp.end()){
if(CurrBB != FirstBB)
builder_->SetInsertPoint(FirstBB->getTerminator());
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
phase = urem(phase, i32(max_phase));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
off_0 = udiv(off_0, i32(min_vec));
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
off_0 = mul(off_0 , i32(min_vec));
Value* off = add(off_0, off_1);
if(CurrBB != FirstBB)
builder_->SetInsertPoint(CurrBB);
tmp[key] = gep(shmems_[x], {off});
}
shared.push_back({tmp[key], off});
}
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i];
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
int out_off = shared[i].second*2;
// asm
FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
}
std::string asm_str = "cp.async.commit_group;";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
}
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_tile_shapes();
// default implementation
// store to shared
Value *current = nullptr;
std::map<std::pair<int, int>, Value*> ptrs;
for(int i = 0; i < idxs_.at(arg).size(); i++){
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
void generator::visit_async_wait_inst(ir::async_wait_inst*) {
std::string asm_str = "cp.async.wait_all;";
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
add_barrier();
}
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(*ctx_);
Type *int_32_ty = Type::getInt32Ty(*ctx_);
ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
Type *ptr_ty = ptr_ty(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
shmem_ = bit_cast(sh_mem_array, ptr_ty);
}

View File

@@ -15,114 +15,105 @@ namespace triton {
namespace codegen{
namespace transform{
bool membar::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
bool left_intersect = y.first <= x.first && x.first < y.second;
bool right_intersect = y.first <= x.second && x.second < y.second;
return left_intersect || right_intersect;
});
}
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
return intersect(X, y);
});
}
void membar::add_reference(ir::value *v, interval_vec_t &res){
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(!i->get_type()->is_tile_ty())
return;
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if(!layout)
return;
if(alloc_->has_offset(layout)){
unsigned offset = alloc_->offset(layout);
res.push_back(interval_t(offset, offset + layout->get_size()));
analysis::double_buffer_info_t* info = layout->get_double_buffer();
if(info)
return group_of(info->first, async_write);
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
}
else{
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
}
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
for(ir::value *op: i->ops())
add_reference(op, res);
}
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
add_reference(i, res);
}
void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
assert(inc_val);
if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back());
if(type.first)
builder.create_async_wait();
if(type.second)
builder.create_barrier();
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_tile_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
if(!a_layout)
continue;
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
for(ir::value* b: bs){
if(!b->get_type()->is_tile_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
if(!b_layout)
continue;
int b_start = alloc_->offset(b_layout);
int b_end = b_start + b_layout->get_size();
if(a_start < b_end || b_start < a_end)
ret.insert(b);
}
}
}
else {
builder.set_insert_point(instr);
builder.create_barrier();
}
return ret;
}
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
membar::interval_vec_t result;
for(auto x: intervals)
for(interval_t i: x)
result.push_back(i);
return result;
}
std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
std::vector<ir::instruction*>& to_sync) {
bool& inserted, ir::builder& builder) {
ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from;
for(ir::instruction *i: instructions){
interval_vec_t read, written;
get_read_intervals(i, read);
get_written_intervals(i, written);
if(written.size())
to_sync.push_back(i);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering
if(safe_war.find(i) != safe_war.end()){
write_after_read = false;
read_after_write = false;
if(dynamic_cast<ir::phi_node*>(i))
continue;
if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
dynamic_cast<ir::masked_load_async_inst*>(i)){
async_write.push_back(i);
}
// record hazards
if(read_after_write || write_after_read) {
auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
insert_loc.insert({i, {copy_async_wait, barrier}});
new_written_to.clear();
new_read_from.clear();
to_sync.clear();
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
sync_write.insert(i);
ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
if(intersect_with(read, tmp).size()){
std::vector<int> groups(read.size());
std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
int N = *std::max_element(groups.begin(), groups.end());
if(N < async_write.size()){
builder.set_insert_point(i);
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
}
return std::make_pair(new_written_to, new_read_from);
// RAW, WAR
if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
// update state of asynchronous copies
if(async_wait){
int N = async_write.size() - async_wait->get_N();
async_write.erase(async_write.begin(), async_write.begin() + N);
}
// all the copy_to_shared and read from shared are synchronized after barrier
if(barrier){
sync_write.clear();
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
}
void membar::run(ir::module &mod) {
@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to;
std::map<ir::basic_block*, interval_vec_t> read_from;
std::vector<ir::instruction*> to_sync;
std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
size_t n_inserted_im1 = 0;
bool done = false;
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
std::map<ir::basic_block*, val_set_t> sync_reads;
std::list<ir::value *> pipelined;
bool inserted;
do{
inserted = false;
// find barrier location
for(ir::basic_block *block: rpo){
// written to
std::vector<interval_vec_t> pred_written_to;
for(ir::basic_block* pred: block->get_predecessors())
pred_written_to.push_back(written_to[pred]);
// read from
std::vector<interval_vec_t> pred_read_from;
for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]);
// apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
written_to[block] = result.first;
read_from[block] = result.second;
// join inputs
val_vec_t async_write;
val_set_t sync_write;
val_set_t sync_read;
val_set_t tmp;
for(ir::basic_block* pred: block->get_predecessors()){
for(ir::value* v: async_writes[pred])
if(tmp.insert(v).second)
async_write.push_back(v);
sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
}
size_t n_inserted_i = insert_locs.size();
done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i;
}while(!done);
for(auto x: insert_locs){
insert_barrier(x.first, x.second, builder);
transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
async_writes[block] = async_write;
sync_writes[block] = sync_write;
sync_reads[block] = sync_read;
}
}while(inserted);
}
}

View File

@@ -1,7 +1,9 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
return false;
// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
// std::cout << layout->nts(layout->get_order(0)) << std::endl;
// return true;
}
@@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) {
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// if(tgt_->as_nvidia()->sm() >= 80)
// was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);
}

View File

@@ -0,0 +1,116 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return;
if(i->get_id()==ir::INST_PHI)
return;
ret.push_back(i);
for(ir::user* u: i->get_users())
recursive_deps(u, block, ret);
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
to_pipeline.push_back({load, ptr});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
false_value = masked_load->get_false_value_operand();
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
// try to move dot_inst after loads
// for better overlap of io and compute
struct move_config_t{
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::map<ir::basic_block*, move_config_t> to_move;
if(has_copy_async_){
for(ir::function* fn: mod.get_function_list())
for(ir::basic_block* bb: fn->blocks())
for(ir::instruction* inst: bb->get_inst_list()){
if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
recursive_deps(i, bb, to_move[bb].insts);
if(auto* i = dynamic_cast<ir::load_inst*>(inst))
to_move[bb].dst = i;
}
for(auto& x: to_move){
builder.set_insert_point_after(x.second.dst);
for(ir::instruction* i: x.second.insts){
x.first->erase(i);
builder.insert(i);
}
}
}
}
}
}
}

View File

@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
inline bool is_cst(ir::value *x) {
if(dynamic_cast<ir::constant*>(x))
return true;
if(dynamic_cast<ir::make_range*>(x))
return true;
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
return is_cst(v->get_operand(0));
return false;

View File

@@ -70,7 +70,21 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
// dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
// properties
int shared_total, shared_optin, shared_static;
int n_spills, n_reg;
CUdevice dev;
dispatch::cuCtxGetDevice(&dev);
dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
if (shared_optin > 49152){
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
}
}
}

View File

@@ -282,6 +282,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
void cu_module::init_from_ptx(const std::string& ptx) {
// JIT compile source-code
// std::cout << ptx << std::endl;
try{
// // compile ptx with ptxas

View File

@@ -76,7 +76,7 @@ void host_stream::synchronize() {
hst_->args.clear();
}
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t) {
auto hst = kernel->module()->hst();
hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
char* params = new char[args_size];
@@ -113,13 +113,13 @@ void cu_stream::synchronize() {
dispatch::cuStreamSynchronize(*cu_);
}
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem) {
void *config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END
};
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config);
}
void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {

View File

@@ -45,6 +45,9 @@ void builder::set_insert_point(basic_block *block){
// convenience functions
//===----------------------------------------------------------------------===//
value *builder::get_int1(bool val)
{ return constant_int::get(type::get_int1_ty(ctx_), val); }
value *builder::get_int32(int32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
@@ -372,8 +375,8 @@ value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name));
}
value *builder::create_async_wait() {
return insert(async_wait_inst::create(ctx_));
value *builder::create_async_wait(int N) {
return insert(async_wait_inst::create(ctx_, N));
}
}

View File

@@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins
blocks_.reserve(num_reserved);
}
value* phi_node::get_value_for_block(basic_block * block) {
auto it = std::find(blocks_.begin(), blocks_.end(), block);
size_t n = std::distance(blocks_.begin(), it);
return get_incoming_value(n);
}
// Set incoming value
void phi_node::set_incoming_value(unsigned i, value *v){
assert(v && "PHI node got a null value!");
@@ -818,12 +824,11 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
return new barrier_inst(ctx, name, next);
}
async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
return new async_wait_inst(ctx, name, next);
async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
return new async_wait_inst(ctx, N, name, next);
}

View File

@@ -15,10 +15,10 @@
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/runtime/function.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
@@ -149,6 +149,7 @@ void kernel::init_ker(){
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
@@ -156,19 +157,24 @@ void kernel::init_ker(){
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get());
codegen::transform::peephole peephole(target.get(), &layouts);
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
// run passes
dce.run(*ir_);
pipeline.run(*ir_);
dce.run(*ir_);
disassociate.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
peephole.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
if(target->is_gpu())
cts.run(*ir_);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
coalesce.run(*ir_);
@@ -179,6 +185,11 @@ void kernel::init_ker(){
reassociate.run(*ir_);
cts.run(*ir_);
}
dce.run(*ir_);
// ir::print(*ir_, std::cout);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
peephole.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
@@ -187,8 +198,9 @@ void kernel::init_ker(){
swizzle.run(*ir_);
liveness.run(*ir_);
allocation.run(*ir_);
if(allocation.allocated_size() > dev_->max_shared_memory())
throw exception::out_of_shared_memory();
shared_mem_ = allocation.allocated_size();
// if(allocation.allocated_size() > dev_->max_shared_memory())
// throw exception::out_of_shared_memory();
barriers.run(*ir_);
isel.visit(*ir_, *llvm);
//if(res->spilled() > 256)
@@ -224,7 +236,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
}
std::string kernel::get_asm(asm_mode_t mode) {
@@ -348,7 +360,7 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
while(grid.size() < 3)
grid.push_back(1);
double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
stream, true);
stream, 5, 20);
ret = (ts < best_ts) ? current : ret;
best_ts = std::min(ts, best_ts);
}

View File

@@ -2,58 +2,74 @@ import triton
import torch
# square benchmarks
nt = {False: 'n', True: 't'}
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'provider',
y_vals = ['torch', 'triton', 'cutlass'],
y_lines = ['Torch', 'Triton', 'CUTLASS'],
ylabel = 'TFLOPS',
loglog = False,
plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
args = {'AT': False, 'BT': False, 'dtype': torch.float16}
)\
for AT in [False, True] for BT in [False, True]
x_names=["M", "N", "K"],
x_vals=[512 * i for i in range(1, 16)],
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
ylabel="TFLOPS",
loglog=False,
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16},
) for AT in [False, True] for BT in [False, True]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import os
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
if provider == 'torch':
if provider == "torch":
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
torch_tflops = num_flops / torch_ms * 1e-9
return torch_tflops
if provider == 'triton':
if provider == "triton":
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
triton_tflops = num_flops / triton_ms * 1e-9
return triton_tflops
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = 'column' if AT else 'row'
layout_b = 'column' if BT else 'row'
layout_a = "column" if AT else "row"
layout_b = "column" if BT else "row"
# create temporary file name
fd, fname = tempfile.mkstemp()
# run program and gets its output
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
cmd = [
os.environ["CUTLASS_PROFILER"],
f"--m={M}",
f"--n={N}",
f"--k={K}",
f"--A=f16:{layout_a}",
f"--B=f16:{layout_b}",
"--C=f16:column",
"--accum=f32",
"--operation=gemm",
"--verification-enabled=false",
f"--warmup-iterations={warmup}",
f"--profiling-iterations={rep}",
f"--output={fname}",
"--verbose=false",
]
# run cmd
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f'{fname}.gemm.csv')
cutlass_tflops = max(df_c['GFLOPs']) / 1e3
df_c = pd.read_csv(f"{fname}.gemm.csv")
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
return cutlass_tflops
return None
if __name__ == '__main__':
if __name__ == "__main__":
bench_op.run()

View File

@@ -15,21 +15,21 @@ import distutils.spawn
import torch
def find_llvm():
versions = ['-10', '-10.0', '']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
versions = ["-10", "-10.0", ""]
supported = ["llvm-config{v}".format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return paths[0]
config = distutils.spawn.find_executable('llvm-config')
instructions = 'Please install llvm-10-dev'
config = distutils.spawn.find_executable("llvm-config")
instructions = "Please install llvm-10-dev"
if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions)
version = os.popen('{config} --version'.format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
raise RuntimeError("Could not find llvm-config. " + instructions)
version = os.popen("{config} --version".format(config=config)).read()
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''):
def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
@@ -37,84 +37,84 @@ class CMakeExtension(Extension):
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
if cmake_version < '3.1.0':
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
if cmake_version < "3.1.0":
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
for ext in self.extensions:
self.build_extension(ext)
def build_extension(self, ext):
#self.debug = True
# self.debug = True
self.debug = False
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# python directories
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
torch_include_dirs = include_paths(True)
torch_library_dirs = library_paths(True)
cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
cmake_args = [
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TUTORIALS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON",
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
'-DTORCH_CXX11_ABI=' + cxx11abi,
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
'-DLLVM_CONFIG=' + find_llvm()
"-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
"-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
"-DTORCH_CXX11_ABI=" + cxx11abi,
"-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
"-DLLVM_CONFIG=" + find_llvm(),
]
# configuration
cfg = 'Debug' if self.debug else 'Release'
cfg = 'Release'
build_args = ['--config', cfg]
cfg = "Debug" if self.debug else "Release"
build_args = ["--config", cfg]
if platform.system() == "Windows":
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
if sys.maxsize > 2**32:
cmake_args += ['-A', 'x64']
build_args += ['--', '/m']
cmake_args += ["-A", "x64"]
build_args += ["--", "/m"]
else:
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
build_args += ['--', '-j4']
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
build_args += ["--", "-j4"]
env = os.environ.copy()
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
setup(
name='triton',
version='1.0.0',
author='Philippe Tillet',
author_email='phil@openai.com',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
install_requires=['numpy', 'torch'],
package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
name="triton",
version="1.0.0",
author="Philippe Tillet",
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
install_requires=["numpy", "torch"],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
include_package_data=True,
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass={'build_ext': CMakeBuild},
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,
# for PyPI
keywords=['Compiler', 'Deep Learning'],
url='https://github.com/ptillet/triton/',
download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
keywords=["Compiler", "Deep Learning"],
url="https://github.com/ptillet/triton/",
download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
classifiers=[
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Intended Audience :: Developers', # Define that your audience are developers
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: MIT License', # Again, pick a license
'Programming Language :: Python :: 3.6',
"Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
"Intended Audience :: Developers", # Define that your audience are developers
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License", # Again, pick a license
"Programming Language :: Python :: 3.6",
],
)

View File

@@ -2,29 +2,17 @@ import torch
import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block)
for mode in ["sdd", "dsd", "dds"]
for at in [False, True]
for bt in [False, True]
for block in [16, 32, 64]
],
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
for block in [16, 32, 64]],
)
def test_matmul(
MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn(
(Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
)
b = torch.randn(
(Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
)
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = triton.ops.blocksparse.matmul(
layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
)
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
# compare
assert triton.testing.allclose(rc, tc)
@pytest.mark.parametrize(
"BLOCK, WIDTH",
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
at_mask = torch.randint(
low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
)
kp_mask = torch.randint(
low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
)
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
kp_mask[kp_mask == 1.0] = float("-inf")
# triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# compare
assert triton.testing.allclose(ry, ty)
def test_attention_fwd_bwd(
input_scale=1.0,
tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
.to(dtype)
.cuda()
for _ in range(3)
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
attn_mask = torch.tril(
torch.ones(
@@ -129,11 +106,9 @@ def test_attention_fwd_bwd(
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(
layout, block, attn_mask, query=query, key=key, value=value, scale=scale
)
attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out ** 2).mean()
loss = (attn_out**2).mean()
loss.backward()
grads = [query.grad, key.grad, value.grad]
@@ -148,17 +123,16 @@ def test_attention_fwd_bwd(
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = (torch_attn_out ** 2).mean()
torch_loss = (torch_attn_out**2).mean()
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
for g1, g2 in zip(grads, torch_grads):
torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
def triton_attention(
layout,
block: int,
@@ -168,12 +142,8 @@ def triton_attention(
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
layout, block, "sdd", trans_a=False, trans_b=True
)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
layout, block, "dsd", trans_a=False, trans_b=False
)
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
sparse_softmax = triton.ops.blocksparse.softmax(
layout,
block,

View File

@@ -4,7 +4,7 @@ import triton
import torch
@pytest.mark.parametrize(
"TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
"TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(*[
[
# 1 warp
@@ -17,14 +17,14 @@ import torch
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
# # 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
# # 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,22 +40,26 @@ import torch
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
torch.manual_seed(0)
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK * TZ
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
if M is None:
M = TM
if N is None:
N = TN
if K is None:
K = TK * SPLITK
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
th_c = torch.matmul(a, b)

View File

@@ -1,6 +1,6 @@
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
TYPE* B __readonly __noalias __aligned(16),
TYPE* C __noalias __aligned(16),
__global__ void NAME(TYPE *A __readonly __noalias __aligned(16),
TYPE *B __readonly __noalias __aligned(16),
TYPE *C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
@@ -13,7 +13,7 @@
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int* lut, int* locks, int nlocks) {
int *lut, int *locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
@@ -26,9 +26,9 @@
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
int offlutn[TN] = blockidn*4;
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
int offlutm[TM] = blockidm * (TN / BLOCK) * 4;
int offlutn[TN] = blockidn * 4;
int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
@@ -44,8 +44,8 @@
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
@@ -97,8 +97,8 @@
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK;
TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
@@ -106,7 +106,7 @@
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
bool checkbn[TK, TN] = rbn [newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
@@ -119,8 +119,8 @@
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for(int k = AS1; k > 0; k -= step) {
acc += a @ b;
for (int k = AS1; k > 0; k -= step) {
acc += a @b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
@@ -145,8 +145,8 @@
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *?(checka)pa;
b = *?(checkb)pb;
a = *? (checka)pa;
b = *? (checkb)pb;
}
TYPE c[TM, TN] = acc;
@@ -159,9 +159,9 @@
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
int rr_offlutn[TN] = rr_blockidn*4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4;
int rr_offlutn[TN] = rr_blockidn * 4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
@@ -171,28 +171,29 @@
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
bool checkc[TM, TN] = rcn [newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN;
// write-back directly
if(lockid == 0) {
*?(checkc) pc = c;
if (lockid == 0) {
*? (checkc)pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
if (count == 0)
*? (checkc)pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}
}

View File

@@ -57,11 +57,11 @@ class _matmul(torch.autograd.Function):
lockid[current:last] = nlocks
maxid[current:last] = last - current
# segment size
segments[current:current+d] = seg_max
segments[current:current + d] = seg_max
if r < seg_min and not isempty:
segments[current+d-1] += r
segments[current + d - 1] += r
if r >= seg_min or isempty:
segments[current+d] = r
segments[current + d] = r
current = last
col_idx += 1
offsets = torch.zeros_like(segments)
@@ -85,7 +85,7 @@ class _matmul(torch.autograd.Function):
superblocks = libtriton.superblock(layout.type(torch.int32), start_width)
luts, widths, packs = [], [], []
for size, nnz in superblocks:
width = nnz.shape[0] // (size*size)
width = nnz.shape[0] // (size * size)
h = nnz[:, 0]
i = nnz[:, 1]
j = nnz[:, 2]
@@ -98,8 +98,7 @@ class _matmul(torch.autograd.Function):
return luts, None, widths, packs
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, luts, num_locks, widths, packs):
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
if trans_c:
a, b = b, a
@@ -120,39 +119,32 @@ class _matmul(torch.autograd.Function):
if not is_16_multiple:
raise ValueError('Reduction size for SDD must be a multiple of 16')
# create kernel
total_width = sum([width*pack*pack for width,pack in zip(widths, packs)])
total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
for lut, width, pack in zip(luts, widths, packs):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
defines = {'TM': block*pack, 'TN': block*pack,
'TMN': block*block*pack*pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
defines = {
'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK':
32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
kernel = _matmul.sdd_cache[key]
# create output
locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
# maximum grid size is 65535
# so operation might be decomposed into multiple
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), b.stride(2), block,
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(0),
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
lut.data_ptr(), locks.data_ptr(), num_lock,
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
@@ -164,7 +156,7 @@ class _matmul(torch.autograd.Function):
# Given a binary layout of 0s and 1s,
# Construct look-up table for efficient execution on GPUs
@staticmethod
def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx):
def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx):
# load-balancing
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
segments = _empty.clone()
@@ -199,17 +191,17 @@ class _matmul(torch.autograd.Function):
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets))
idx = transform(nnz[:, 2]*block)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
idx = transform(nnz[:, 2] * block)
xincs = idx.clone()
xincs[1:] -= idx[:-1]
# divide block into multiple steps
div = block // step
xincs = xincs.view(-1, 1).repeat(1, div)
xincs[:, 1:] = step
xincs[:, 0 ] -= (div-1)*step
xincs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
xincs[offsets[segments>0], 0] = idx[offsets[segments>0]]
xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]]
xincs = xincs.view(-1)
# block-mode input increments
if trans:
@@ -224,23 +216,23 @@ class _matmul(torch.autograd.Function):
widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
widx = widx
wincs = widx*block*block
wincs[1:] -= widx[:-1]*block*block
wincs = widx * block * block
wincs[1:] -= widx[:-1] * block * block
wincs = wincs.view(-1, 1).repeat(1, div)
if trans:
wincs[:, 1:] = step
wincs[:, 0] -= (div-1)*step
wincs[:, 0] -= (div - 1) * step
else:
wincs[:, 1:] = step*block
wincs[:, 0] -= (div - 1)*step*block
wincs[offsets[segments>0], 0] = widx[offsets[segments>0]]
wincs[:, 1:] = step * block
wincs[:, 0] -= (div - 1) * step * block
wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]]
wincs = wincs.view(-1)
# adjust offset and segment size
offsets *= 2*div
offsets *= 2 * div
segments *= div
# create header
width = column.size(0)
offsets += 6*width
offsets += 6 * width
header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
@@ -252,8 +244,7 @@ class _matmul(torch.autograd.Function):
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
@@ -266,19 +257,12 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
defines = {'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
defines = {
'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK':
1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel', 'DDS': True
}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
# output
@@ -286,19 +270,15 @@ class _matmul(torch.autograd.Function):
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device)
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), block, c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
@@ -311,19 +291,12 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
defines = {'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
defines = {
'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK':
'1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN':
'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True
}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
# output
@@ -331,28 +304,19 @@ class _matmul(torch.autograd.Function):
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
block, b.stride(2), c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
fn = {'sdd': _sdd_matmul.__get__(object),
'dsd': _dsd_matmul.__get__(object),
'dds': _dds_matmul.__get__(object)}
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c,
mode, spdims, block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
c_lut, c_num_locks, c_width, c_packs)
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
@@ -378,13 +342,13 @@ class _matmul(torch.autograd.Function):
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
@@ -392,7 +356,6 @@ class _matmul(torch.autograd.Function):
None, None, None, None, None, None
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
@@ -412,7 +375,8 @@ class matmul:
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b,
device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
@@ -425,7 +389,7 @@ class matmul:
db_lut, db_num_locks, db_width, db_packs)
return self.lut_cache[key]
def __init__(self, layout, block, mode, trans_a = False, trans_b = False):
def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
# look-up table cache
@@ -455,9 +419,7 @@ class matmul:
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
self.mode, self.spdims, self.block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks,
db_width, db_packs)
return c

View File

@@ -1,9 +1,9 @@
#define STM 8
#define STN 8
__global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
__global__ void matmul(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
int M,
int N,
@@ -11,7 +11,7 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
int lda __multipleof(LDA_POW2_DIV),
int ldb __multipleof(LDB_POW2_DIV),
int ldc __multipleof(LDC_POW2_DIV),
int* locks) {
int *locks) {
// prologue
int pid = get_program_id(0);
int pidz = get_program_id(2);
@@ -19,30 +19,30 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
int gridn = (N + TN - 1) / TN;
// swizzle for better L2 performance
int width = STM*gridn;
int width = STM * gridn;
int stm = pid / width;
int RSTM = min(gridm - stm*STM, STM);
int stn = (pid % width) / (RSTM*STN);
int RSTN = min(gridn - stn*STN, STN);
int RSTM = min(gridm - stm * STM, STM);
int stn = (pid % width) / (RSTM * STN);
int RSTN = min(gridn - stn * STN, STN);
int laneid = pid % (RSTM * RSTN);
int lanem = laneid / RSTN;
int lanen = laneid % RSTN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int pidm = stm * STM + lanem;
int pidn = stn * STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parrallelism
K = K / TZ;
K = K / SPLITK;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
TYPE* pa[TM, TK] = A + offa;
TYPE* pb[TK, TN] = B + offb;
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN;
TYPE *pa[TM, TK] = A + offa;
TYPE *pb[TK, TN] = B + offb;
// prefetches operands
bool checka[TM, TK] = rk[newaxis, :] < K;
bool checka[TM, TK] = rk [newaxis, :] < K;
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
@@ -51,18 +51,18 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
#if (IS_TK_DIV_K==1)
for (int k = K; k > 0; k -= TK) {
#if (IS_TK_DIV_K == 1)
bool checkk[TK] = k > TK;
#else
bool checkk[TK] = rk < k - TK;
#endif
bool checka[TM, TK] = checkk[newaxis, :];
bool checka[TM, TK] = checkk [newaxis, :];
bool checkb[TK, TN] = checkk[:, newaxis];
acc += a @ b;
#if (IS_TK_DIV_K==1)
a = *?(checka)pa;
b = *?(checkb)pb;
acc += a @b;
#if (IS_TK_DIV_K == 1)
a = *? (checka)pa;
b = *? (checkb)pb;
#else
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
@@ -76,22 +76,23 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
// epilogue
int rcm[TM] = pidm * TM + 0 ... TM;
int rcn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
#if (SPLITK == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
if (count == 0)
*? (checkc)pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % SPLITK);
atomic_xchg(plock, 0);
#endif
}

View File

@@ -3,29 +3,32 @@ import triton
import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
_DEFAULT_CONFIGS = [
({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
_locks = dict()
@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 64 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
key = (
device,
dtype,
is_a_row,
is_b_row,
lda_pow2_div,
ldb_pow2_div,
ldc_pow2_div,
is_tk_div_k,
)
if key not in _matmul._kernels:
defines = {
'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
int(is_tk_div_k)
"TYPE": dtype,
"STRIDE_AM": "lda" if is_a_row else "1",
"STRIDE_AK": "1" if is_a_row else "lda",
"STRIDE_BK": "ldb" if is_b_row else "1",
"STRIDE_BN": "1" if is_b_row else "ldb",
"LDA_POW2_DIV": lda_pow2_div,
"LDB_POW2_DIV": ldb_pow2_div,
"LDC_POW2_DIV": ldc_pow2_div,
"IS_TK_DIV_K": int(is_tk_div_k),
}
_matmul._kernels[key] = triton.kernel(_matmul.src,
_matmul._kernels[key] = triton.kernel(
_matmul.src,
device,
defines=defines,
autotune_vals=_matmul._CONFIGS,
autotune_key=['M', 'N', 'K'])
autotune_key=["M", "N", "K"],
)
kernel = _matmul._kernels[key]
# # locks for split-k
if device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
alpha = 1.0
args = [
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
alpha,
M,
N,
K,
lda,
ldb,
ldc,
locks.data_ptr(),
]
grid = lambda opt: [
triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
1,
opt.SPLITK,
]
kernel(*args, grid=grid)
return c

View File

@@ -1,21 +1,33 @@
import torch
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
ret = torch.empty(
(x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device
)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
ret[:, idx, :, :] = x[
:, h, i * block : (i + 1) * block, j * block : (j + 1) * block
]
return ret
def mask_tensor(x, mask, block, value=0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value
ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
return ret
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
return torch.allclose(x, y, atol=atol, rtol=rtol)
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err < tol
def do_bench(fn, flops=0, warmup=10, rep=50):
start_event = torch.cuda.Event(enable_timing=True)
@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
class Benchmark:
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
def __init__(
self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args
):
self.x_names = x_names
self.x_vals = x_vals
self.y_name = y_name
@@ -44,6 +59,7 @@ class Benchmark:
self.plot_name = plot_name
self.args = args
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
@@ -53,26 +69,31 @@ class Mark:
import matplotlib.pyplot as plt
import pandas as pd
import os
df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names}
row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
row = [
self.fn(**x_args, **{bench.y_name: y}, **bench.args)
for y in bench.y_vals
]
df.loc[len(df)] = [x] + row
if with_plot and bench.plot_name:
xlabel = ' = '.join(bench.x_names)
xlabel = " = ".join(bench.x_names)
plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
plot.set_xlabel(xlabel)
plot.set_ylabel(bench.ylabel)
plot.set_title(bench.plot_name)
plot.set_xscale('log' if bench.loglog else 'linear')
plot.set_yscale('log' if bench.loglog else 'linear')
plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png'))
df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv'))
plot.set_xscale("log" if bench.loglog else "linear")
plot.set_yscale("log" if bench.loglog else "linear")
plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
def run(self, result_path, with_plot):
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper

View File

@@ -66,18 +66,19 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
TYPE anext[TM, TK] = *?(checka)pa;
TYPE bnext[TK, TN] = *?(checkb)pb;
acc += a @ b;
a = anext;
b = bnext;
// __debug_barrier();
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
@@ -166,7 +167,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
opt.defines["TYPE"] = ty;
opt.defines["TM"] = "128";
opt.defines["TN"] = "128";
opt.defines["TK"] = "32" ;
opt.defines["TK"] = "64" ;
opt.defines["TZ"] = "1";
opt.num_warps = 4;
// arguments