[CODEGEN] Major performance improvements on A100 (#70)

Improved handling of asynchronous copies, scheduling, and synchronization for A100. Now achieving CUTLASS-like performance on large square dense matrix multiplication tasks.
Philippe Tillet
2021-02-21 15:19:39 -08:00
committed by Philippe Tillet
parent 045ab5d62a
commit 5b83259592
31 changed files with 1331 additions and 1115 deletions

View File

@@ -2,6 +2,9 @@
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
namespace triton {
@@ -9,6 +12,7 @@ namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
@@ -29,18 +33,15 @@ namespace transform{
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::vector<interval_t> interval_vec_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
void insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder);
bool intersect(const interval_vec_t &X, interval_t x);
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
void add_reference(ir::value *v, interval_vec_t &res);
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
std::map<triton::ir::instruction *, std::pair<bool, bool> > &insert_loc, std::set<triton::ir::value *> &safe_war, std::vector<triton::ir::instruction *> &to_sync);
bool intersect(const val_set_t &X, const val_set_t &Y);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):

View File

@@ -16,6 +16,10 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
@@ -33,11 +37,12 @@ private:
private:
public:
peephole(target* tgt): tgt_(tgt) {}
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};

View File

@@ -0,0 +1,28 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
void run(ir::module &module);
private:
bool has_copy_async_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@@ -29,7 +29,7 @@ public:
static driver::stream* create(backend_t backend);
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args = NULL, size_t args_size = 0) = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
@@ -44,7 +44,7 @@ class host_stream: public stream {
public:
host_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -55,7 +55,7 @@ public:
cu_stream(CUstream str, bool take_ownership);
cu_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size);
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};

View File

@@ -35,6 +35,7 @@ public:
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
// Types
@@ -149,7 +150,7 @@ public:
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
value *create_async_wait();
value *create_async_wait(int N);
private:
context &ctx_;

View File

@@ -92,6 +92,7 @@ private:
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_value_for_block(basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
@@ -803,14 +804,18 @@ public:
class async_wait_inst: public instruction{
private:
async_wait_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait"; }
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
_TRITON_DEFINE_CLONE(async_wait_inst)
_TRITON_DEFINE_ACCEPT(async_wait_inst)
public:
static async_wait_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
private:
int N_;
};
// On NVIDIA, implementation is such that
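For reference, the N carried by async_wait_inst maps to PTX cp.async.wait_group N, which blocks until at most N of the most recently committed asynchronous copy groups are still in flight (N = 0 waits for all of them). A minimal C++ model of that queue semantics, illustrative only and not part of this diff:

#include <deque>

// Each commit_group turns the copies issued so far into one in-flight
// group; wait_group(n) blocks until at most n groups remain pending.
// Hardware retires the oldest groups first.
struct async_copy_queue {
  std::deque<int> in_flight;  // one entry per committed group
  int pending = 0;            // copies issued since the last commit

  void issue_copy()   { ++pending; }
  void commit_group() { in_flight.push_back(pending); pending = 0; }
  void wait_group(int n) {
    while ((int)in_flight.size() > n)
      in_flight.pop_front();  // stands in for waiting on the oldest group
  }
};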

View File

@@ -98,6 +98,8 @@ private:
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
std::shared_ptr<driver::kernel> ker_;
// shared mem
size_t shared_mem_;
};
class function {

View File

@@ -30,11 +30,8 @@ private:
high_resolution_clock::time_point _start;
};
inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
{
// const driver::device * device = stream->context()->device();
size_t warmup = 10;
size_t repeat = 50;
timer tmr;
std::vector<size_t> times;
double total_time = 0;
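Usage sketch for the new signature (hypothetical kernel launch and stream; note the autotuner below passes warmup=5, repeat=20 instead of the defaults):

// average latency over 200 timed runs after 10 warmup runs
double ms = bench([&]{ kernel(args); }, stream, /*warmup=*/10, /*repeat=*/200);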

View File

@@ -312,7 +312,6 @@ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst*
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
// std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl;
}
return add_to_cache(x, result, max_contiguous_);
}
@@ -527,8 +526,7 @@ void align::run(ir::module &mod) {
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
// ir::for_each_value(mod, [this](ir::value* v) {
// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
// });
}

View File

@@ -118,15 +118,6 @@ data_layout::data_layout(id_t id,
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// std::cout << order_[0] << " " << order_[1] << std::endl;
}
if(is_recoalesce){
if(ptr.size() > 0){
// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl;
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
// if(order_[0] == 0)
// exit(1);
}
}
// std::cout << "---" << std::endl;
}
int data_layout::find_axis(int to_find) const {
@@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps,
ir::value *ptr = nullptr;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *st = dynamic_cast<ir::io_inst*>(usr))
ptr = st->get_pointer_operand();
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
ptr = io->get_pointer_operand();
}
unsigned i = order_[0];
int contiguous = 1;
if(ptr){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
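A worked instance of the two lines above, with hypothetical values: for an fp16 pointer nbits = 16, so one 128-bit access holds at most 128 / 16 = 8 elements, and the per-thread vector width is capped accordingly.

#include <algorithm>

int nbits      = 16;                        // fp16 element width in bits
int contiguous = std::min(8, 128 / nbits);  // align->get(ptr, i) assumed 8 => 8
// nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i]))
// e.g. size = 128*128 elements, num_threads = 256, shape_[i] = 128
// => clamp(64, 1, 8) = 8 elements per thread along the contiguous axis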

View File

@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
}
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
unsigned vector = 1;
ir::value *ptrs = x->get_pointer_operand();
ir::value *msks = x->get_mask_operand();
unsigned in_vec = 1;
ir::value *arg = x->get_pointer_operand();
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
auto out_order = out_layout->get_order();
auto in_order = in_layout->get_order();
// tiles
if(out_order == in_order)
vector = in_layout->nts(in_order[0]);
in_vec = in_layout->nts(in_order[0]);
int out_vec = swizzle_->get_vec(out_layout);
int min_vec = std::min<int>(out_vec, in_vec);
int s = std::max<int>(out_vec / in_vec, 1);
//
int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
Value *max_phase = i32(8 / num_per_phase);
int per_phase = swizzle_->get_per_phase(out_layout);
int max_phase = swizzle_->get_max_phase(out_layout);
//
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
auto shapes = x->get_type()->get_tile_shapes();
//
int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
std::vector<Value*> shared;
for(size_t i = 0; i < n_shared; i++){
indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
// phase
Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
phase = urem(phase, max_phase);
// off
Value* off_0 = idx[in_order[0]];
off_0 = udiv(off_0, i32(vector));
off_0 = xor_(off_0, phase);
off_0 = mul(off_0 , i32(vector));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off = add(off_0, off_1);
//
shared.push_back(gep(shmems_[x], {off}));
}
//
for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
auto idx = idxs_[ptrs][i];
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
std::map<std::pair<int, int>, Value*> tmp;
std::vector<std::pair<Value*, int>> shared;
for(int i = 0; i < idxs_.at(arg).size(); i++){
unsigned id = i / min_vec;
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
Value *in_base = in_gep->getPointerOperand();
size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
Value* out_base = shared[(i / per_thread_ld) % n_shared];
int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
int out_off_1 = i % per_thread_ld;
int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
// asm
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
call(iasm, {vals_[msks][idx], out_base, in_base});
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
int off = (off_1*shapes[in_order[0]] + off_0);
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(tmp.find(key) == tmp.end()){
if(CurrBB != FirstBB)
builder_->SetInsertPoint(FirstBB->getTerminator());
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
phase = urem(phase, i32(max_phase));
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
off_0 = udiv(off_0, i32(min_vec));
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
off_0 = mul(off_0 , i32(min_vec));
Value* off = add(off_0, off_1);
if(CurrBB != FirstBB)
builder_->SetInsertPoint(CurrBB);
tmp[key] = gep(shmems_[x], {off});
}
shared.push_back({tmp[key], off});
}
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
auto idx = idxs_[arg][i];
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
in_base = cst ? in_base : in_gep;
// output ptr info
Value* out_base = shared[i].first;
int out_off = shared[i].second*2;
// asm
FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
}
std::string asm_str = "cp.async.commit_group;";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
}
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_tile_shapes();
// default implementation
// store to shared
Value *current = nullptr;
std::map<std::pair<int, int>, Value*> ptrs;
for(int i = 0; i < idxs_.at(arg).size(); i++){
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
add_barrier();
}
void generator::visit_async_wait_inst(ir::async_wait_inst*) {
std::string asm_str = "cp.async.wait_all;";
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
call(iasm);
add_barrier();
}
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(*ctx_);
Type *int_32_ty = Type::getInt32Ty(*ctx_);
ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
Type *ptr_ty = ptr_ty(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
shmem_ = bit_cast(sh_mem_array, ptr_ty);
}
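The offset arithmetic in visit_masked_load_async_inst above implements the bank-conflict-avoiding swizzle; a scalar sketch of the same computation (hypothetical standalone form, folding the key offset into col):

// phase: row-dependent XOR pattern; s: output/input vector-width ratio
int swizzled_offset(int row, int col, int row_stride,
                    int min_vec, int s, int per_phase, int max_phase) {
  int phase = (row / per_phase) % max_phase;
  int off0  = col / min_vec;                   // vector index within the row
  off0 = ((off0 / s) ^ phase) * s + off0 % s;  // swizzle groups of s vectors
  off0 *= min_vec;                             // back to element units
  return row * row_stride + off0;
}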

View File

@@ -15,114 +15,105 @@ namespace triton {
namespace codegen{
namespace transform{
bool membar::intersect(const interval_vec_t &X, interval_t x) {
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
bool left_intersect = y.first <= x.first && x.first < y.second;
bool right_intersect = y.first <= x.second && x.second < y.second;
return left_intersect || right_intersect;
});
}
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
return intersect(X, y);
});
}
void membar::add_reference(ir::value *v, interval_vec_t &res){
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(!i->get_type()->is_tile_ty())
return;
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
if(!layout)
return;
if(alloc_->has_offset(layout)){
unsigned offset = alloc_->offset(layout);
res.push_back(interval_t(offset, offset + layout->get_size()));
}
}
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
for(ir::value *op: i->ops())
add_reference(op, res);
}
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
add_reference(i, res);
}
void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
std::set<ir::value*> incoming;
for(unsigned n = 0; n < phi->get_num_incoming(); n++){
ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
assert(inc_val);
if(incoming.insert(inc_val).second){
ir::basic_block *block = inc_val->get_parent();
builder.set_insert_point(block->get_inst_list().back());
if(type.first)
builder.create_async_wait();
if(type.second)
builder.create_barrier();
}
}
analysis::double_buffer_info_t* info = layout->get_double_buffer();
if(info)
return group_of(info->first, async_write);
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
}
else{
builder.set_insert_point(instr);
builder.create_barrier();
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
}
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
membar::interval_vec_t result;
for(auto x: intervals)
for(interval_t i: x)
result.push_back(i);
return result;
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
val_set_t ret;
for(ir::value* a: as){
if(!a->get_type()->is_tile_ty())
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
if(!a_layout)
continue;
int a_start = alloc_->offset(a_layout);
int a_end = a_start + a_layout->get_size();
for(ir::value* b: bs){
if(!b->get_type()->is_tile_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
if(!b_layout)
continue;
int b_start = alloc_->offset(b_layout);
int b_end = b_start + b_layout->get_size();
if(a_start < b_end || b_start < a_end)
ret.insert(b);
}
}
return ret;
}
std::pair<membar::interval_vec_t,
membar::interval_vec_t> membar::transfer(ir::basic_block *block,
const interval_vec_t &written_to,
const interval_vec_t &read_from,
std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
std::vector<ir::instruction*>& to_sync) {
bool& inserted, ir::builder& builder) {
ir::basic_block::inst_list_t instructions = block->get_inst_list();
interval_vec_t new_written_to = written_to;
interval_vec_t new_read_from = read_from;
for(ir::instruction *i: instructions){
interval_vec_t read, written;
get_read_intervals(i, read);
get_written_intervals(i, written);
if(written.size())
to_sync.push_back(i);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering
if(safe_war.find(i) != safe_war.end()){
write_after_read = false;
read_after_write = false;
if(dynamic_cast<ir::phi_node*>(i))
continue;
if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
dynamic_cast<ir::masked_load_async_inst*>(i)){
async_write.push_back(i);
}
// record hazards
if(read_after_write || write_after_read) {
auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
insert_loc.insert({i, {copy_async_wait, barrier}});
new_written_to.clear();
new_read_from.clear();
to_sync.clear();
if(dynamic_cast<ir::copy_to_shared_inst*>(i))
sync_write.insert(i);
ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
// Get shared memory reads
std::set<ir::value*> read;
std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
[&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
// RAW (async)
val_set_t tmp;
std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
if(intersect_with(read, tmp).size()){
std::vector<int> groups(read.size());
std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
int N = *std::max_element(groups.begin(), groups.end());
if(N < async_write.size()){
builder.set_insert_point(i);
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
}
return std::make_pair(new_written_to, new_read_from);
// RAW, WAR
if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
}
// update state of asynchronous copies
if(async_wait){
int N = async_write.size() - async_wait->get_N();
async_write.erase(async_write.begin(), async_write.begin() + N);
}
// after a barrier, all preceding copy_to_shared ops and shared-memory reads are synchronized
if(barrier){
sync_write.clear();
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
}
void membar::run(ir::module &mod) {
@@ -143,35 +134,33 @@ void membar::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, interval_vec_t> written_to;
std::map<ir::basic_block*, interval_vec_t> read_from;
std::vector<ir::instruction*> to_sync;
std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
size_t n_inserted_im1 = 0;
bool done = false;
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
std::map<ir::basic_block*, val_set_t> sync_reads;
std::list<ir::value *> pipelined;
bool inserted;
do{
inserted = false;
// find barrier location
for(ir::basic_block *block: rpo){
// written to
std::vector<interval_vec_t> pred_written_to;
for(ir::basic_block* pred: block->get_predecessors())
pred_written_to.push_back(written_to[pred]);
// read from
std::vector<interval_vec_t> pred_read_from;
for(ir::basic_block* pred: block->get_predecessors())
pred_read_from.push_back(read_from[pred]);
// apply transfer function
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
written_to[block] = result.first;
read_from[block] = result.second;
// join inputs
val_vec_t async_write;
val_set_t sync_write;
val_set_t sync_read;
val_set_t tmp;
for(ir::basic_block* pred: block->get_predecessors()){
for(ir::value* v: async_writes[pred])
if(tmp.insert(v).second)
async_write.push_back(v);
sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
}
size_t n_inserted_i = insert_locs.size();
done = (n_inserted_im1 == n_inserted_i);
n_inserted_im1 = n_inserted_i;
}while(!done);
for(auto x: insert_locs){
insert_barrier(x.first, x.second, builder);
transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
async_writes[block] = async_write;
sync_writes[block] = sync_write;
sync_reads[block] = sync_read;
}
}while(inserted);
}
}
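The wait count emitted above follows directly from the position of the youngest async copy a value depends on: if that copy sits at index group in the oldest-first list of outstanding copies, everything up to and including it must land, so at most size - 1 - group groups may stay in flight. A small sketch of that computation:

#include <algorithm>
#include <vector>

// groups_read: for each shared value read, the index (oldest first) of
// the async copy that produced it; assumes at least one dependence.
// Returns the argument for cp.async.wait_group.
int wait_count(const std::vector<int>& groups_read, int num_outstanding) {
  int newest = *std::max_element(groups_read.begin(), groups_read.end());
  return num_outstanding - 1 - newest;
}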

View File

@@ -1,7 +1,9 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"
namespace triton {
namespace codegen{
@@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
ir::value *ptr = ld->get_pointer_operand();
ir::value *msk = ld->get_mask_operand();
ir::value *val = ld->get_false_value_operand();
analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
return false;
// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
// std::cout << layout->nts(layout->get_order(0)) << std::endl;
// return true;
}
@@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) {
bool was_modified = false;
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
was_modified = was_modified || rewrite_trans_phi(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// if(tgt_->as_nvidia()->sm() >= 80)
// was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);
}
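The nts*dtsize >= 4 guard in rewrite_load_to_shared reflects a hardware constraint: cp.async transfers 4, 8, or 16 bytes per instruction, so each thread's contiguous chunk must be at least 4 bytes. As a sketch:

// e.g. fp16 (dtsize = 2) qualifies once nts >= 2 contiguous elements
bool can_use_cp_async(int nts /*elements per thread*/, int dtsize /*bytes*/) {
  return nts * dtsize >= 4;
}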

View File

@@ -0,0 +1,116 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{
namespace transform{
void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i || i->get_parent() != block)
return;
if(i->get_id()==ir::INST_PHI)
return;
ret.push_back(i);
for(ir::user* u: i->get_users())
recursive_deps(u, block, ret);
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
to_pipeline.push_back({load, ptr});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
false_value = masked_load->get_false_value_operand();
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
new_load->add_incoming(first_load, header);
new_load->add_incoming(next_load, block);
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
// try to move dot_inst after loads
// for better overlap of io and compute
struct move_config_t{
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::map<ir::basic_block*, move_config_t> to_move;
if(has_copy_async_){
for(ir::function* fn: mod.get_function_list())
for(ir::basic_block* bb: fn->blocks())
for(ir::instruction* inst: bb->get_inst_list()){
if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
recursive_deps(i, bb, to_move[bb].insts);
if(auto* i = dynamic_cast<ir::load_inst*>(inst))
to_move[bb].dst = i;
}
for(auto& x: to_move){
builder.set_insert_point_after(x.second.dst);
for(ir::instruction* i: x.second.insts){
x.first->erase(i);
builder.insert(i);
}
}
}
}
}
}
}
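The shape of the transformation, as a straight-line pseudo-C sketch (hypothetical loop; the pass itself rewrites Triton-IR phi nodes as above):

// before: the load feeding the dot sits inside the loop
//   for (...) { x = load(ptr); acc += dot(x, w); ptr += step; }
//
// after: first load hoisted into the header, the next iteration's load
// issued inside the loop, and a phi (here: plain assignment) selecting
//   x = load(ptr);                  // pre-fetch first iteration
//   for (...) {
//     x_next = load(ptr + step);    // pre-fetch next iteration
//     acc   += dot(x, w);           // compute overlaps the copy
//     x = x_next; ptr += step;
//   }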

View File

@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
inline bool is_cst(ir::value *x) {
if(dynamic_cast<ir::constant*>(x))
return true;
if(dynamic_cast<ir::make_range*>(x))
return true;
if(auto *v = dynamic_cast<ir::retile_inst*>(x))
return is_cst(v->get_operand(0));
return false;

View File

@@ -70,7 +70,21 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
// dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
// properties
int shared_total, shared_optin, shared_static;
int n_spills, n_reg;
CUdevice dev;
dispatch::cuCtxGetDevice(&dev);
dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
if (shared_optin > 49152){
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
}
}
}
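For context, the opt-in above is needed because kernels get at most 48KB of dynamic shared memory by default; anything beyond that must be requested per function. A minimal driver-API sketch of the same steps (error checking omitted):

#include <cuda.h>

void opt_in_shared(CUfunction fn, CUdevice dev) {
  int optin = 0, static_smem = 0;
  cuDeviceGetAttribute(&optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
  cuFuncGetAttribute(&static_smem, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fn);
  if (optin > 49152)  // device allows more than the 48KB default
    cuFuncSetAttribute(fn, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                       optin - static_smem);
}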

View File

@@ -282,6 +282,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
void cu_module::init_from_ptx(const std::string& ptx) {
// JIT compile source-code
// std::cout << ptx << std::endl;
try{
// // compile ptx with ptxas

View File

@@ -76,7 +76,7 @@ void host_stream::synchronize() {
hst_->args.clear();
}
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t) {
auto hst = kernel->module()->hst();
hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
char* params = new char[args_size];
@@ -113,13 +113,13 @@ void cu_stream::synchronize() {
dispatch::cuStreamSynchronize(*cu_);
}
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size) {
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem) {
void *config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END
};
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config);
dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config);
}
void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {

View File

@@ -45,6 +45,9 @@ void builder::set_insert_point(basic_block *block){
// convenience functions
//===----------------------------------------------------------------------===//
value *builder::get_int1(bool val)
{ return constant_int::get(type::get_int1_ty(ctx_), val); }
value *builder::get_int32(int32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
@@ -372,8 +375,8 @@ value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name));
}
value *builder::create_async_wait() {
return insert(async_wait_inst::create(ctx_));
value *builder::create_async_wait(int N) {
return insert(async_wait_inst::create(ctx_, N));
}
}

View File

@@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins
blocks_.reserve(num_reserved);
}
value* phi_node::get_value_for_block(basic_block * block) {
auto it = std::find(blocks_.begin(), blocks_.end(), block);
size_t n = std::distance(blocks_.begin(), it);
return get_incoming_value(n);
}
// Set incoming value
void phi_node::set_incoming_value(unsigned i, value *v){
assert(v && "PHI node got a null value!");
@@ -818,12 +824,11 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
return new barrier_inst(ctx, name, next);
}
async_wait_inst::async_wait_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { }
async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) {
return new async_wait_inst(ctx, name, next);
async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
return new async_wait_inst(ctx, N, name, next);
}

View File

@@ -15,10 +15,10 @@
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/runtime/function.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
@@ -149,6 +149,7 @@ void kernel::init_ker(){
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
@@ -156,19 +157,24 @@ void kernel::init_ker(){
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get());
codegen::transform::peephole peephole(target.get(), &layouts);
codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps);
// run passes
dce.run(*ir_);
pipeline.run(*ir_);
dce.run(*ir_);
disassociate.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
peephole.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
if(target->is_gpu())
cts.run(*ir_);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
coalesce.run(*ir_);
@@ -179,6 +185,11 @@ void kernel::init_ker(){
reassociate.run(*ir_);
cts.run(*ir_);
}
dce.run(*ir_);
// ir::print(*ir_, std::cout);
align.run(*ir_);
axes.run(*ir_);
layouts.run(*ir_);
peephole.run(*ir_);
dce.run(*ir_);
align.run(*ir_);
@@ -187,8 +198,9 @@ void kernel::init_ker(){
swizzle.run(*ir_);
liveness.run(*ir_);
allocation.run(*ir_);
if(allocation.allocated_size() > dev_->max_shared_memory())
throw exception::out_of_shared_memory();
shared_mem_ = allocation.allocated_size();
// if(allocation.allocated_size() > dev_->max_shared_memory())
// throw exception::out_of_shared_memory();
barriers.run(*ir_);
isel.visit(*ir_, *llvm);
//if(res->spilled() > 256)
@@ -224,7 +236,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co
for(size_t i = 0; i < 3; i++)
grid[i] = (i < _grid.size()) ? _grid[i] : 1;
// enqueue
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size);
stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_);
}
std::string kernel::get_asm(asm_mode_t mode) {
@@ -348,7 +360,7 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
while(grid.size() < 3)
grid.push_back(1);
double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); },
stream, true);
stream, 5, 20);
ret = (ts < best_ts) ? current : ret;
best_ts = std::min(ts, best_ts);
}

View File

@@ -2,58 +2,74 @@ import triton
import torch
# square benchmarks
nt = {False: 'n', True: 't'}
nt = {False: "n", True: "t"}
square_confs = [
triton.testing.Benchmark(
x_names = ['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144],
y_name = 'provider',
y_vals = ['torch', 'triton', 'cutlass'],
y_lines = ['Torch', 'Triton', 'CUTLASS'],
ylabel = 'TFLOPS',
x_names=["M", "N", "K"],
x_vals=[512 * i for i in range(1, 16)],
y_name="provider",
y_vals=["torch", "triton", "cutlass"],
y_lines=["Torch", "Triton", "CUTLASS"],
ylabel="TFLOPS",
loglog=False,
plot_name = f'matmul-square-{nt[AT]}{nt[BT]}',
args = {'AT': False, 'BT': False, 'dtype': torch.float16}
)\
for AT in [False, True] for BT in [False, True]
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
args={"AT": AT, "BT": BT, "dtype": torch.float16},
) for AT in [False, True] for BT in [False, True]
]
@triton.testing.perf_report(square_confs)
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5):
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20):
import os
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
if AT: a = a.t()
if BT: b = b.t()
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT:
a = a.t()
if BT:
b = b.t()
num_flops = 2 * M * N * K
if provider == 'torch':
if provider == "torch":
torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
torch_tflops = num_flops / torch_ms * 1e-9
return torch_tflops
if provider == 'triton':
if provider == "triton":
triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
triton_tflops = num_flops / triton_ms * 1e-9
return triton_tflops
if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ:
if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
import subprocess
import tempfile
import pandas as pd
# run program specified by CUTLASS_PROFILER env variable
layout_a = 'column' if AT else 'row'
layout_b = 'column' if BT else 'row'
layout_a = "column" if AT else "row"
layout_b = "column" if BT else "row"
# create temporary file name
fd, fname = tempfile.mkstemp()
# run program and gets its output
cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \
'--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \
f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false']
cmd = [
os.environ["CUTLASS_PROFILER"],
f"--m={M}",
f"--n={N}",
f"--k={K}",
f"--A=f16:{layout_a}",
f"--B=f16:{layout_b}",
"--C=f16:column",
"--accum=f32",
"--operation=gemm",
"--verification-enabled=false",
f"--warmup-iterations={warmup}",
f"--profiling-iterations={rep}",
f"--output={fname}",
"--verbose=false",
]
# run cmd
subprocess.run(cmd, stdout=subprocess.PIPE)
# read CSV output
df_c = pd.read_csv(f'{fname}.gemm.csv')
cutlass_tflops = max(df_c['GFLOPs']) / 1e3
df_c = pd.read_csv(f"{fname}.gemm.csv")
cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
return cutlass_tflops
return None
if __name__ == '__main__':
if __name__ == "__main__":
bench_op.run()

View File

@@ -15,21 +15,21 @@ import distutils.spawn
import torch
def find_llvm():
versions = ['-10', '-10.0', '']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
versions = ["-10", "-10.0", ""]
supported = ["llvm-config{v}".format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return paths[0]
config = distutils.spawn.find_executable('llvm-config')
instructions = 'Please install llvm-10-dev'
config = distutils.spawn.find_executable("llvm-config")
instructions = "Please install llvm-10-dev"
if config is None:
raise RuntimeError('Could not find llvm-config. ' + instructions)
version = os.popen('{config} --version'.format(config=config)).read()
raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions)
raise RuntimeError("Could not find llvm-config. " + instructions)
version = os.popen("{config} --version".format(config=config)).read()
raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions)
class CMakeExtension(Extension):
def __init__(self, name, path, sourcedir=''):
def __init__(self, name, path, sourcedir=""):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
self.path = path
@@ -37,14 +37,14 @@ class CMakeExtension(Extension):
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
out = subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
if cmake_version < '3.1.0':
cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1))
if cmake_version < "3.1.0":
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
for ext in self.extensions:
@@ -52,69 +52,69 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
# self.debug = True
self.debug = False
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# python directories
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR")
torch_include_dirs = include_paths(True)
torch_library_dirs = library_paths(True)
cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI))
cmake_args = [
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TUTORIALS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON",
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
'-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)),
'-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)),
'-DTORCH_CXX11_ABI=' + cxx11abi,
'-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton',
'-DLLVM_CONFIG=' + find_llvm()
"-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)),
"-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)),
"-DTORCH_CXX11_ABI=" + cxx11abi,
"-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton",
"-DLLVM_CONFIG=" + find_llvm(),
]
# configuration
cfg = 'Debug' if self.debug else 'Release'
cfg = 'Release'
build_args = ['--config', cfg]
cfg = "Debug" if self.debug else "Release"
build_args = ["--config", cfg]
if platform.system() == "Windows":
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)]
if sys.maxsize > 2**32:
cmake_args += ['-A', 'x64']
build_args += ['--', '/m']
cmake_args += ["-A", "x64"]
build_args += ["--", "/m"]
else:
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
build_args += ['--', '-j4']
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
build_args += ["--", "-j4"]
env = os.environ.copy()
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))
subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
setup(
name='triton',
version='1.0.0',
author='Philippe Tillet',
author_email='phil@openai.com',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'],
install_requires=['numpy', 'torch'],
package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']},
name="triton",
version="1.0.0",
author="Philippe Tillet",
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"],
install_requires=["numpy", "torch"],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
include_package_data=True,
ext_modules=[CMakeExtension('triton', 'triton/_C/')],
cmdclass={'build_ext': CMakeBuild},
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,
# for PyPI
keywords=['Compiler', 'Deep Learning'],
url='https://github.com/ptillet/triton/',
download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz',
keywords=["Compiler", "Deep Learning"],
url="https://github.com/ptillet/triton/",
download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz",
classifiers=[
'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Intended Audience :: Developers', # Define that your audience are developers
'Topic :: Software Development :: Build Tools',
'License :: OSI Approved :: MIT License', # Again, pick a license
'Programming Language :: Python :: 3.6',
"Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
"Intended Audience :: Developers", # Define that your audience are developers
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License", # Again, pick a license
"Programming Language :: Python :: 3.6",
],
)

View File

@@ -2,29 +2,17 @@ import torch
import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK",
[
(mode, at, bt, block)
for mode in ["sdd", "dsd", "dds"]
for at in [False, True]
for bt in [False, True]
for block in [16, 32, 64]
],
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
for block in [16, 32, 64]],
)
def test_matmul(
MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384
):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn(
(Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
)
b = torch.randn(
(Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
)
a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda")
b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda")
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
@@ -32,9 +20,7 @@ def test_matmul(
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = triton.ops.blocksparse.matmul(
layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B
)
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = op(ra, rb)
@@ -49,7 +35,6 @@ def test_matmul(
# compare
assert triton.testing.allclose(rc, tc)
@pytest.mark.parametrize(
"BLOCK, WIDTH",
[(block, width) for block in [32] for width in [256, 576, 1024, 1792]],
@@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
at_mask = torch.randint(
low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda"
)
kp_mask = torch.randint(
low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda"
)
at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda")
kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda")
kp_mask[kp_mask == 1.0] = float("-inf")
# triton result
op = triton.ops.blocksparse.softmax(layout, BLOCK)
@@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16):
# compare
assert triton.testing.allclose(ry, ty)
def test_attention_fwd_bwd(
input_scale=1.0,
tol=2e-2,
@@ -108,10 +88,7 @@ def test_attention_fwd_bwd(
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
.to(dtype)
.cuda()
for _ in range(3)
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3)
]
attn_mask = torch.tril(
torch.ones(
@@ -129,9 +106,7 @@ def test_attention_fwd_bwd(
query.retain_grad()
key.retain_grad()
value.retain_grad()
attn_out = triton_attention(
layout, block, attn_mask, query=query, key=key, value=value, scale=scale
)
attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale)
# ad hoc loss
loss = (attn_out**2).mean()
loss.backward()
@@ -153,12 +128,11 @@ def test_attention_fwd_bwd(
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
# print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...")
torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol)
for g1, g2 in zip(grads, torch_grads):
torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol)
def triton_attention(
layout,
block: int,
@@ -168,12 +142,8 @@ def triton_attention(
value: torch.Tensor,
scale: float,
):
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
layout, block, "sdd", trans_a=False, trans_b=True
)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
layout, block, "dsd", trans_a=False, trans_b=False
)
sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True)
sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False)
sparse_softmax = triton.ops.blocksparse.softmax(
layout,
block,

View File

@@ -4,7 +4,7 @@ import triton
import torch
@pytest.mark.parametrize(
"TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
"TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
itertools.chain(*[
[
# 1 warp
@@ -17,14 +17,14 @@ import torch
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
# 2 warp
# # 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
# # 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
@@ -40,22 +40,26 @@ import torch
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
]))
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
torch.manual_seed(0)
triton.ops._matmul._kernels = dict()
triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
if M is None: M = TM
if N is None: N = TN
if K is None: K = TK * TZ
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)]
if M is None:
M = TM
if N is None:
N = TN
if K is None:
K = TK * SPLITK
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
th_c = torch.matmul(a, b)

View File

@@ -186,7 +186,8 @@
else {
int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;

View File

@@ -98,8 +98,7 @@ class _matmul(torch.autograd.Function):
return luts, None, widths, packs
@staticmethod
def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, luts, num_locks, widths, packs):
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs):
if trans_c:
a, b = b, a
@@ -126,17 +125,12 @@ class _matmul(torch.autograd.Function):
num_lock = 1
key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
if key not in _matmul.sdd_cache:
defines = {'TM': block*pack, 'TN': block*pack,
'TMN': block*block*pack*pack,
'BLOCK': block,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
defines = {
'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK':
32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1',
'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc',
'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
}
_matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
kernel = _matmul.sdd_cache[key]
@@ -147,11 +141,9 @@ class _matmul(torch.autograd.Function):
# kernel calls
max_width = 49152
for off_width in range(0, width, max_width):
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), b.stride(2), block,
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(0),
AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock,
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0),
b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width,
lut.data_ptr(), locks.data_ptr(), num_lock,
grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
# save for backward pass
return c
@@ -252,8 +244,7 @@ class _matmul(torch.autograd.Function):
return lut, num_locks, width, None
@staticmethod
def _dds_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
@@ -266,19 +257,12 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dds_cache:
defines = {'TM': 128,
'TN': block,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1,
'STRIDE_BN': block if trans_b else 1,
'STRIDE_BK': 1 if trans_b else block,
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel',
'DDS': True}
defines = {
'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK':
1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dds_kernel', 'DDS': True
}
_matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dds_cache[key]
# output
@@ -288,17 +272,13 @@ class _matmul(torch.autograd.Function):
CS3 = AS2 if trans_c else BS2
locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
a.stride(2), block, c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0),
c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(),
num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
return c
@staticmethod
def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
spdims, block, lut, num_locks, width, packs):
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
# shapes / dtypes
AS0 = spdims[0]
AS1 = block * spdims[2 if trans_a else 1]
@@ -311,19 +291,12 @@ class _matmul(torch.autograd.Function):
# kernel
key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _matmul.dsd_cache:
defines = {'TM': block,
'TN': 128,
'TK': 16,
'BLOCK': block,
'TYPE': dtype,
'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1,
'STRIDE_BN': 'ldb' if trans_b else '1',
'STRIDE_BK': '1' if trans_b else 'ldb',
'STRIDE_CM': '1' if trans_c else 'ldc',
'STRIDE_CN': 'ldc' if trans_c else '1',
'NAME': 'dsd_kernel',
'DSD': True}
defines = {
'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block,
'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK':
'1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN':
'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True
}
_matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
kernel = _matmul.dsd_cache[key]
# output
@@ -333,26 +306,17 @@ class _matmul(torch.autograd.Function):
CS3 = AS1 if trans_c else BS3
locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
block, b.stride(2), c.stride(2),
a.stride(0), b.stride(0), c.stride(0),
a.stride(1), b.stride(1), c.stride(1),
BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
return c
fn = {'sdd': _sdd_matmul.__get__(object),
'dsd': _dsd_matmul.__get__(object),
'dds': _dds_matmul.__get__(object)}
fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}
@staticmethod
def forward(ctx, a, b, trans_a, trans_b, trans_c,
mode, spdims, block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block,
c_lut, c_num_locks, c_width, c_packs)
def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut,
da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_num_locks = da_num_locks
@@ -378,13 +342,13 @@ class _matmul(torch.autograd.Function):
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs)
da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut,
ctx.da_num_locks, ctx.da_width, ctx.da_packs)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs)
db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut,
ctx.db_num_locks, ctx.db_width, ctx.db_packs)
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, None,\
@@ -392,7 +356,6 @@ class _matmul(torch.autograd.Function):
None, None, None, None, None, None
class matmul:
def make_lut(self, dtype, device):
key = (dtype, device)
if key in self.lut_cache:
@@ -412,7 +375,8 @@ class matmul:
elif self.mode == 'dsd':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device)
elif self.mode == 'dds':
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device)
da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b,
device)
# DB look-up table
if self.mode == 'sdd':
db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device)
@@ -455,9 +419,7 @@ class matmul:
a = matmul._pad_shape(a, self.mode == 'dsd')
b = matmul._pad_shape(b, self.mode == 'dds')
# execute
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False,
self.mode, self.spdims, self.block,
c_lut, c_num_locks, c_width, c_packs,
da_lut, da_num_locks, da_width, da_packs,
db_lut, db_num_locks, db_width, db_packs)
c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks,
db_width, db_packs)
return c

View File

@@ -33,7 +33,7 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
int rn[TN] = pidn * TN + 0 ... TN;
// split-k for better parallelism
K = K / TZ;
K = K / SPLITK;
int rk[TK] = 0 ... TK;
// pointers to operands
int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
@@ -79,19 +79,20 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16),
int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :];
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N;
#if (TZ==1)
#if (SPLITK == 1)
*? (checkc)pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1))
;
int count = *pcount;
if (count == 0)
*? (checkc)pc = c;
else
*? (checkc)pc = c + *? (checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(pcount, (count + 1) % SPLITK);
atomic_xchg(plock, 0);
#endif
}
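
The epilogue above merges split-k partial results with a spin lock: the first program to acquire the lock stores its tile, the rest accumulate into it, and a counter tracks how many have arrived. A minimal Python sketch of that protocol follows (illustrative only, not part of the commit; threading and a Lock stand in for the kernel's atomic_cas/atomic_xchg loop):

import threading
import numpy as np

def splitk_accumulate(partials, SPLITK):
    # Sketch of the plock/pcount protocol: SPLITK programs each own one
    # partial C tile; the first to take the lock stores, the rest add.
    assert len(partials) == SPLITK
    c = np.zeros_like(partials[0])
    plock = threading.Lock()  # the kernel spins with atomic_cas instead
    count = 0                 # plays the role of *pcount

    def program(tile):
        nonlocal count
        with plock:
            if count == 0:
                c[:] = tile                   # count == 0: plain store
            else:
                c[:] += tile                  # otherwise: accumulate
            count = (count + 1) % SPLITK      # atomic_xchg(pcount, ...)

    threads = [threading.Thread(target=program, args=(p,)) for p in partials]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return c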

View File

@@ -3,29 +3,32 @@ import triton
import os
class _matmul(torch.autograd.Function):
src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
_DEFAULT_CONFIGS = [
({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2),
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4),
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4),
({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4),
({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4),
({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2),
({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
# ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4),
]
_CONFIGS = _DEFAULT_CONFIGS
@staticmethod
def largest_pow2_divisor(N):
if N % 8 == 0: return 8
if N % 4 == 0: return 4
if N % 2 == 0: return 2
if N % 8 == 0:
return 8
if N % 4 == 0:
return 4
if N % 2 == 0:
return 2
return 1
_locks = dict()
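
The largest_pow2_divisor helper above feeds the specialization key used below (LDA_POW2_DIV and friends), letting the compiler assume stride alignment and emit wider loads. A standalone restatement with example values (illustrative, not part of the commit):

def largest_pow2_divisor(N):
    # Same logic as the staticmethod above, written as a loop.
    for p in (8, 4, 2):
        if N % p == 0:
            return p
    return 1

# A leading dimension of 384 is 8-aligned, so vectorized loads are safe;
# 106 only guarantees 2-element alignment.
assert largest_pow2_divisor(384) == 8
assert largest_pow2_divisor(106) == 2
assert largest_pow2_divisor(7) == 1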
@@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function):
K, N = b.shape
c = torch.empty((M, N), dtype=dtype, device=device)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# kernel hash
is_a_row = a.stride(1) == 1
is_b_row = b.stride(1) == 1
@@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function):
ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
is_tk_div_k = K % 64 == 0
key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k)
key = (
device,
dtype,
is_a_row,
is_b_row,
lda_pow2_div,
ldb_pow2_div,
ldc_pow2_div,
is_tk_div_k,
)
if key not in _matmul._kernels:
defines = {
'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda',
'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV':
lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K':
int(is_tk_div_k)
"TYPE": dtype,
"STRIDE_AM": "lda" if is_a_row else "1",
"STRIDE_AK": "1" if is_a_row else "lda",
"STRIDE_BK": "ldb" if is_b_row else "1",
"STRIDE_BN": "1" if is_b_row else "ldb",
"LDA_POW2_DIV": lda_pow2_div,
"LDB_POW2_DIV": ldb_pow2_div,
"LDC_POW2_DIV": ldc_pow2_div,
"IS_TK_DIV_K": int(is_tk_div_k),
}
_matmul._kernels[key] = triton.kernel(_matmul.src,
_matmul._kernels[key] = triton.kernel(
_matmul.src,
device,
defines=defines,
autotune_vals=_matmul._CONFIGS,
autotune_key=['M', 'N', 'K'])
autotune_key=["M", "N", "K"],
)
kernel = _matmul._kernels[key]
# locks for split-k
if device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# enqueue
alpha = 1.
args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()]
grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ]
alpha = 1.0
args = [
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
alpha,
M,
N,
K,
lda,
ldb,
ldc,
locks.data_ptr(),
]
grid = lambda opt: [
triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
1,
opt.SPLITK,
]
kernel(*args, grid=grid)
return c
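
With opt.SPLITK as the third grid dimension, each C tile is now computed by SPLITK programs, each reducing over a K / SPLITK slice before the spin-lock epilogue merges the partials. A sketch of the launch arithmetic (launch_grid is a hypothetical helper, not part of the commit):

import math

def launch_grid(M, N, K, TM, TN, SPLITK):
    # One program per (TM x TN) tile of C along axis 0, and SPLITK
    # programs along axis 2, each owning one K-slice.
    tiles = math.ceil(M / TM) * math.ceil(N / TN)
    return [tiles, 1, SPLITK]

assert launch_grid(1024, 1024, 4096, 128, 128, 4) == [64, 1, 4]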

View File

@@ -1,21 +1,33 @@
import torch
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
ret = torch.empty(
(x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device
)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
ret[:, idx, :, :] = x[
:, h, i * block : (i + 1) * block, j * block : (j + 1) * block
]
return ret
def mask_tensor(x, mask, block, value=0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
return ret
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
return torch.allclose(x, y, atol=atol, rtol=rtol)
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err < tol
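
The rewritten allclose replaces the elementwise torch.allclose check with a single global relative-error measure, which is less sensitive to isolated small-magnitude entries. The same metric in isolation (illustrative values, not from the test suite):

import torch

def relative_error(x, y):
    # Largest absolute difference, normalized by the largest value of
    # either input -- the err computed by the new allclose above.
    diff = abs(x - y)
    return (torch.max(diff) / torch.max(torch.max(x), torch.max(y))).item()

x = torch.tensor([1.0, 2.0, 3.00])
y = torch.tensor([1.0, 2.0, 3.02])
assert relative_error(x, y) < 1e-2  # within the 1% tolerance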
def do_bench(fn, flops=0, warmup=10, rep=50):
start_event = torch.cuda.Event(enable_timing=True)
@@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50):
time_ms = start_event.elapsed_time(end_event) / rep
return time_ms
class Benchmark:
def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args):
def __init__(
self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args
):
self.x_names = x_names
self.x_vals = x_vals
self.y_name = y_name
@@ -44,6 +59,7 @@ class Benchmark:
self.plot_name = plot_name
self.args = args
class Mark:
def __init__(self, fn, benchmarks):
self.fn = fn
@@ -53,26 +69,31 @@ class Mark:
import matplotlib.pyplot as plt
import pandas as pd
import os
df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
for x in bench.x_vals:
x_args = {x_name: x for x_name in bench.x_names}
row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
row = [
self.fn(**x_args, **{bench.y_name: y}, **bench.args)
for y in bench.y_vals
]
df.loc[len(df)] = [x] + row
if with_plot and bench.plot_name:
xlabel = ' = '.join(bench.x_names)
xlabel = " = ".join(bench.x_names)
plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
plot.set_xlabel(xlabel)
plot.set_ylabel(bench.ylabel)
plot.set_title(bench.plot_name)
plot.set_xscale('log' if bench.loglog else 'linear')
plot.set_yscale('log' if bench.loglog else 'linear')
plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png'))
df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv'))
plot.set_xscale("log" if bench.loglog else "linear")
plot.set_yscale("log" if bench.loglog else "linear")
plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
def run(self, result_path, with_plot):
for bench in self.benchmarks:
self._run(bench, result_path, with_plot)
def perf_report(benchmarks):
wrapper = lambda fn: Mark(fn, benchmarks)
return wrapper

View File

@@ -66,18 +66,19 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
bool checkb[TK, TN] = rk[:, newaxis] < K;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
TYPE anext[TM, TK] = *?(checka)pa;
TYPE bnext[TK, TN] = *?(checkb)pb;
acc += a @ b;
a = anext;
b = bnext;
// __debug_barrier();
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
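
The reordered loop body above is a software-pipelining change: the tiles for iteration k+1 are loaded before the a @ b for iteration k is issued, so asynchronous copies (cp.async on A100) can overlap with the math. A NumPy sketch of the same schedule (illustrative only; assumes K is a multiple of TK):

import numpy as np

def pipelined_matmul(A, B, TK):
    # Prologue loads the first tiles; each iteration prefetches the
    # next tiles, then multiplies the current ones, then rotates.
    M, K = A.shape
    _, N = B.shape
    assert K % TK == 0
    acc = np.zeros((M, N), dtype=np.float32)
    a, b = A[:, :TK], B[:TK, :]
    for k in range(K, 0, -TK):
        off = K - k                        # start of the current tile
        if k > TK:                         # the checka / checkb guards
            a_next = A[:, off + TK:off + 2 * TK]
            b_next = B[off + TK:off + 2 * TK, :]
        acc += a @ b                       # math on the current tiles
        if k > TK:
            a, b = a_next, b_next          # double-buffer rotation
    return acc

A = np.random.rand(4, 8).astype(np.float32)
B = np.random.rand(8, 3).astype(np.float32)
assert np.allclose(pipelined_matmul(A, B, 2), A @ B, atol=1e-5)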
@@ -166,7 +167,7 @@ float triton_dot(drv::context* context, drv::stream* stream,
opt.defines["TYPE"] = ty;
opt.defines["TM"] = "128";
opt.defines["TN"] = "128";
opt.defines["TK"] = "32" ;
opt.defines["TK"] = "64" ;
opt.defines["TZ"] = "1";
opt.num_warps = 4;
// arguments