[CODEGEN] Performance improvement on A100 (#125)

Improved codegen for Ampere GPUs.

    * Make the layout pass recognize the multistage pipelined pattern.
    * The pipeline pass can now automate the multistage pipelining transformation.
    * Remove extra barriers (from the prefetch pass & WAR) on Ampere.
    * Update the code generator (generator.cc) to make Triton generate n-buffered shared memory loads/stores (see the usage sketch below).
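
For reference, a minimal sketch of the knob this commit exposes on the Python side (block sizes and the kernel are hypothetical); num_stages=2 keeps the previous double-buffering behaviour, while larger values enable the new N-buffered pipeline on Ampere:

import triton

# num_stages is now accepted by triton.Config and by kernel launches, and it
# joins num_warps in the compilation cache key.
config = triton.Config(
    meta={'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32},  # hypothetical tile sizes
    num_warps=4,
    num_stages=3,  # 3-buffered shared-memory loads/stores on sm_80+
)
# an autotuner forwards it to the kernel call, e.g.
#   kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **config.meta)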
daadaada
2021-06-21 14:25:13 +08:00
committed by Philippe Tillet
parent 5a51f3e529
commit d8d6b715c8
21 changed files with 855 additions and 174 deletions

View File

@@ -141,10 +141,19 @@ struct double_buffer_info_t {
ir::phi_node* phi;
};
struct N_buffer_info_t {
std::vector<ir::value*> firsts; // not necessarily ordered as input order
ir::value* latch;
ir::phi_node* phi;
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and corresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
public:
shared_layout(data_layout *arg,
@@ -158,6 +167,10 @@ public:
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
int get_num_stages() const;
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
size_t get_per_stage_elements() const;
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
@@ -169,6 +182,7 @@ private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
std::shared_ptr<N_buffer_info_t> N_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;
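
The accessors added above (get_num_stages, get_per_stage_size, get_per_stage_elements) expose the N-buffered bookkeeping: the layout holds firsts.size() + 1 stages, and the per-stage size/elements are the total allocation divided by that count. A small Python model of the arithmetic (hypothetical operand tile; it mirrors the formulas added in layout.cc below, it is not the C++ code itself):

# per-stage bytes for a tile, and the N-buffered total the allocator sees
def smem_footprint(shape, elem_bytes, num_stages):
    per_stage_size = elem_bytes
    for s in shape:
        per_stage_size *= s
    total_size = per_stage_size * num_stages           # size_ *= firsts.size() + 1
    per_stage_elements = per_stage_size // elem_bytes  # get_per_stage_elements()
    return total_size, per_stage_size, per_stage_elements

# e.g. a 128x32 fp16 operand with num_stages=3: 8192 bytes/stage, 24576 bytes total
print(smem_footprint((128, 32), 2, 3))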

View File

@@ -21,7 +21,7 @@ namespace codegen{
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps,
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);

View File

@@ -223,6 +223,10 @@ private:
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;

View File

@@ -32,6 +32,8 @@ class shared_layout;
namespace transform{
class prefetch;
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
@@ -40,6 +42,7 @@ private:
private:
bool intersect(const val_set_t &X, const val_set_t &Y);
bool check_safe_war(ir::instruction* i);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
@@ -47,14 +50,16 @@ private:
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {}
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
transform::prefetch *prefetch, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
transform::prefetch *prefetch_;
target* tgt_;
};

View File

@@ -14,11 +14,13 @@ namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
pipeline(bool has_copy_async, int num_stages)
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
void run(ir::module &module);
private:
bool has_copy_async_;
int num_stages_;
};
} // namespace transform

View File

@@ -1,9 +1,12 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#include <set>
// forward declaration
namespace triton::ir{
class module;
class value;
}
namespace triton::codegen {
@@ -13,9 +16,11 @@ class target;
namespace triton::codegen::transform {
class prefetch {
target* tgt_;
std::set<ir::value*> prefetched_vals_;
public:
prefetch(target *tgt) : tgt_(tgt) {}
void run(ir::module &module);
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}

View File

@@ -832,6 +832,7 @@ public:
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
void set_N(int n) { N_ = n; }
private:
int N_;

View File

@@ -1,5 +1,3 @@
#pragma once
#ifndef _TRITON_IR_PRINT_H_
#define _TRITON_IR_PRINT_H_
@@ -9,8 +7,14 @@ namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
void print(module &mod, std::ostream& os);
void print(function &func, std::ostream& os);
void print(basic_block &bb, std::ostream& os);
void print(instruction &instr, std::ostream& os);
}
}

View File

@@ -7,6 +7,7 @@
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
// #include "triton/ir/type.h"
namespace triton{
namespace codegen{
@@ -273,6 +274,81 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
static bool is_smem(ir::value* v) {
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v))
return true;
else
return false;
}
/// param:
/// value_1: next_value
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
std::vector<ir::value*>& values_0, ir::value*& value_1) {
ir::value* next = phi;
while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
// smem from previous bb & phi/smem from current bb
ir::value* c0 = cphi->get_incoming_value(0);
ir::value* c1 = cphi->get_incoming_value(1);
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
if (is_smem(c0)) {
assert(cbb0 == bb0);
values_0.push_back(c0);
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
next = phi1;
continue;
} else {
if (is_smem(c1)) {
value_1 = c1;
assert(cbb1 == bb1);
return true;
} else {
return false;
}
}
} else
return false;
}
}
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
auto* phi = dynamic_cast<ir::phi_node*>(v);
// if the phi node is nested
if (!phi)
return;
ir::basic_block *bb0 = phi->get_incoming_block(0);
ir::basic_block *bb1 = phi->get_incoming_block(1);
std::vector<ir::value*> values_0;
ir::value* value_1;
if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
return;
// double-buffer is a special case
if (values_0.size() == 1)
return;
// compute original values_0 input order
std::map<ir::value*, int> order;
int idx = 0;
for (ir::instruction* instr : *bb0) {
if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
order[static_cast<ir::value*>(instr)] = idx++;
}
assert(order.size() == values_0.size() && "order size incorrect");
int curr_stages = values_0.size() + 1;
if (curr_stages > prev_stages) {
res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
prev_stages = curr_stages;
}
}
shared_layout::shared_layout(data_layout *arg,
const std::vector<int>& axes,
@@ -284,7 +360,13 @@ shared_layout::shared_layout(data_layout *arg,
size_ = 0;
arg_layout_ = arg;
// N-stage buffering
int prev_stages = 0;
for (ir::value *v : values)
extract_N_bufferable(v, N_buffer_, prev_stages);
// double-buffering
if (!N_buffer_)
for(ir::value *v: values)
extract_double_bufferable(v, double_buffer_);
@@ -311,8 +393,22 @@ shared_layout::shared_layout(data_layout *arg,
size_ *= s;
if(double_buffer_)
size_ *= 2;
if (N_buffer_) {
size_ *= (N_buffer_->firsts.size() + 1);
}
}
int shared_layout::get_num_stages() const {
if (double_buffer_)
return 2;
if (N_buffer_)
return N_buffer_->firsts.size() + 1;
return 1;
}
size_t shared_layout::get_per_stage_elements() const {
return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8);
}
/* -------------------------------- *
* ---- Layouts Inference Pass ---- *
@@ -403,7 +499,6 @@ void layouts::run(ir::module &mod) {
for(const auto& x: values_)
create(x.first, x.second);
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {

View File

@@ -26,7 +26,7 @@ namespace codegen {
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
// generate llvm code
llvm::LLVMContext ctx;
@@ -39,26 +39,27 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
codegen::analysis::align align;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::allocation allocation(&liveness);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get());
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
// codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target.get());
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
// run passes
dce.run(ir);
peephole.run(ir);
dce.run(ir);
// ir::print(ir, std::cout);
pipeline.run(ir);
dce.run(ir);
//ir::print(ir, std::cout);
// ir::print(ir, std::cout);
disassociate.run(ir);
dce.run(ir);
align.run(ir);

View File

@@ -212,18 +212,41 @@ void generator::visit_value(ir::value* v) {
return;
if(v->get_type()->is_block_ty()){
if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
auto double_buffer = layout->get_double_buffer();
analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();
// offset
Value *offset = nullptr;
if(double_buffer && v == double_buffer->phi)
offset = shared_off_[layout];
// base pointer
Value *ptr = shared_ptr_[layout];
if(double_buffer && v == double_buffer->latch)
if (n_buffer) {
// ptr = base (shared_ptr_[layout]) + smem_idx * size
// read_smem_idx
if (v == n_buffer->phi) {
ptr = shared_ptr_[layout];
}
// write_smem_idx
if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v);
int elements = write_smem_idx * layout->get_per_stage_elements();
ptr = gep(shared_pre_ptr_[layout], i32(elements));
} else if (v == n_buffer->latch) {
Value* write_smem_idx = write_smem_idx_[layout];
Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements()));
ptr = gep(shared_pre_ptr_[layout], elements);
}
} else if (double_buffer) {
if(v == double_buffer->phi)
offset = shared_off_[layout];
if(v == double_buffer->latch)
ptr = shared_next_ptr_[layout];
else if(double_buffer && v == double_buffer->first)
else if(v == double_buffer->first)
ptr = shared_pre_ptr_[layout];
} // else do nothing
// what visit_dot & visit_cts & ... see
shmems_[v] = ptr;
// now only latches have offset (PHINode), only used by finalize_shared_layout()
shoffs_[v] = offset;
}
}
@@ -1223,24 +1246,21 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
* \brief Code Generation for `mma.16816` (A100)
*/
//TODO: clean-up
void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
const auto& shapes = dot->get_type()->get_block_shapes();
void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes();
std::map<std::vector<Value*>, std::vector<Value*>> fcs;
for(indices_t idx: idxs_.at(dot)){
for(indices_t idx: idxs_.at(C)){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
fcs[key].push_back(vals_[D][idx]);
};
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(dot->get_operand(0));
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(dot->get_operand(1));
analysis::mma_layout* layout = layouts_->get(C)->to_mma();
analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::string a_trans = is_a_row ? "" : ".trans";
@@ -1264,8 +1284,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
int vec_a = 8;
int vec_b = 8;
Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
@@ -1276,7 +1294,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
std::map<std::pair<unsigned, unsigned>, Value*> hb;
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
builder_->SetInsertPoint(FirstBB->getTerminator());
@@ -1339,42 +1356,14 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
"{$0, $1, $2, $3}, "
"{$4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", false);
"{$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
unsigned num_rep_0 = shapes[0] / layout->spt(0);
unsigned num_rep_1 = shapes[1] / layout->spt(1);
for(unsigned K = 0; K < NK; K += 16)
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
if(ha.find({m, K}) == ha.end()){
Value* ptra = ptrs_a[(is_a_row ? K/16 : m) % num_ptr_a];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", "=r,=r,=r,=r,r", false);
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
ha[{m, K}] = std::make_pair(ha0, ha1);
ha[{m, K+8}] = std::make_pair(ha2, ha3);
}
if(hb.find({n, K})==hb.end()){
Value* ptrb = ptrs_b[(is_b_row ? n : K/16) % num_ptr_b];
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", "=r,=r,=r,=r,r", false);
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
hb[{n, K}] = hb0;
hb[{n+1, K}] = hb2;
hb[{n, K+8}] = hb1;
hb[{n+1, K+8}] = hb3;
}
// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
unsigned cols_per_thread = num_rep_0 * 2;
std::vector<size_t> idx = {
(m*2 + 0) + (n*2 + 0)*cols_per_thread,
@@ -1389,16 +1378,145 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
}
};
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds =
[&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto register_lds2 =
[&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
} else
vals[{m, K}] = val;
};
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
Value* ptra;
if(K == 0 && is_prefetch){
if(inc == 0)
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
else
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
else
ptra = ptrs_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
"=r,=r,=r,=r,r", true);
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
};
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
int offidx = (is_b_row ? n : K/16) % num_ptr_b;
Value* ptrb;
if(K == 0 && is_prefetch){
if(inc == 0)
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
else
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
else
ptrb = ptrs_b[offidx];
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
"=r,=r,=r,=r,r", true);
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
register_lds2(hb, n, K, inc, hb0, is_prefetch);
register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
};
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
for(unsigned m = 0; m < num_rep_0; m++){
ha[{m, 0}].first = phi(fp16x2_ty, 2);
ha[{m, 0}].second = phi(fp16x2_ty, 2);
ha[{m, 8}].first = phi(fp16x2_ty, 2);
ha[{m, 8}].second = phi(fp16x2_ty, 2);
}
for(unsigned n = 0; n < num_rep_1; n+=2){
hb[{n, 0}] = phi(fp16x2_ty, 2);
hb[{n+1, 0}] = phi(fp16x2_ty, 2);
hb[{n, 8}] = phi(fp16x2_ty, 2);
hb[{n+1, 8}] = phi(fp16x2_ty, 2);
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, 0, 0, true);
for(unsigned n = 0; n < num_rep_1; n+=2)
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(CurrBB);
for(unsigned K = 0; K < NK; K += 16){
int NEXTK = (K + 16) % NK;
// prefetch A
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, NEXTK, 1, true);
// prefetch B
for(unsigned n = 0; n < num_rep_1; n+=2)
load_b(n, NEXTK, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
call_mma(m, n, K);
}
}
}
else{
for(unsigned K = 0; K < NK; K += 16)
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
if(ha.find({m, K}) == ha.end())
load_a(m, K, 0, false);
if(hb.find({n, K})==hb.end())
load_b(n, K, 0, false);
call_mma(m, n, K);
}
}
// write back
unsigned i = 0;
for(indices_t idx: idxs_.at(dot)){
for(indices_t idx: idxs_.at(C)){
std::vector<Value*> key(idx.size() - 2);
std::copy(idx.begin() + 2, idx.end(), key.begin());
if(i >= fcs.at(key).size())
i = 0;
vals_[dot][idx] = fcs.at(key)[i++];
vals_[C][idx] = fcs.at(key)[i++];
};
}
@@ -2252,8 +2370,35 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
void generator::visit_layout_shared(analysis::shared_layout* layout) {
Type* ty = cvt(layout->get_type());
PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace());
// double-buffered
if(layout->get_double_buffer()) {
if (layout->get_N_buffer()) {
// create pointers
shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty);
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout->get_N_buffer();
ir::phi_node *phi = info.phi;
BasicBlock *parent = bbs_.at(phi->get_parent());
if(parent->empty())
builder_->SetInsertPoint(parent);
else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) {
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
} else
builder_->SetInsertPoint(parent);
// create smem_idx
read_smem_idx_[layout] = phi(i32_ty, 2);
write_smem_idx_[layout] = phi(i32_ty, 2);
// create pointers
// ptr of the current iteration
shared_ptr_[layout] = phi(ptr_ty, 2);
// ptr of the next iteration
shared_next_ptr_[layout] = phi(ptr_ty, 2);
builder_->SetInsertPoint(current);
} else if(layout->get_double_buffer()) {
BasicBlock *current = builder_->GetInsertBlock();
auto info = *layout->get_double_buffer();
ir::phi_node *phi = info.phi;
@@ -2269,8 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
shared_off_[layout] = phi(i32_ty, 2);
shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr");
builder_->SetInsertPoint(current);
}
else{
} else{
size_t offset = alloc_->offset(layout);
shared_ptr_[layout] = gep(shmem_, i32(offset));
shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty);
@@ -2354,7 +2498,67 @@ void generator::init_idx(ir::value *v) {
}
void generator::finalize_shared_layout(analysis::shared_layout *shared) {
if(shared->get_double_buffer()) {
if (auto n_buffer = shared->get_N_buffer()) {
// if (*_smem_idx == #stages-1) {
// *_smem_idx = 0;
// } else *_smem_idx++;
auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) {
// insert point
Value *idx = smem_idx[shared];
builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
PHINode *_ret = phi(i32_ty, 2);
Instruction *then_term = nullptr;
Instruction *else_term = nullptr;
Instruction *dummy = builder_->CreateRet(nullptr);
llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr);
dummy->removeFromParent();
builder_->SetInsertPoint(then_term);
Value *zero_smem_idx = i32(0);
builder_->SetInsertPoint(else_term);
Value *inc_smem_idx = add(idx, i32(1));
builder_->SetInsertPoint(_ret->getParent());
_ret->addIncoming(zero_smem_idx, then_term->getParent());
_ret->addIncoming(inc_smem_idx, else_term->getParent());
// update ir::bb -> llvm::bb mapping
bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock();
// idx = init_stage;
// loop: ...
if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) {
idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0)));
idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1)));
} else
throw std::runtime_error("Should be PHINode");
};
// read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2
finalize_smem_idx(read_smem_idx_, 2);
finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1);
// finalize pointers
ir::phi_node *pn = n_buffer->phi;
BasicBlock *header = bbs_.at(pn->get_incoming_block(0));
BasicBlock *loop = bbs_.at(pn->get_incoming_block(1));
// %curr_ptr = phi %shared_pre_ptr, %next_ptr
// %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size))
if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
} else
throw std::runtime_error("Should be PHINode");
BasicBlock *current = builder_->GetInsertBlock();
builder_->SetInsertPoint(header->getTerminator());
Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements()));
builder_->SetInsertPoint(current->getTerminator());
assert(isa<PHINode>(shared_next_ptr_[shared]));
static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header);
Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements()));
Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset);
static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop);
} else if(shared->get_double_buffer()) {
auto info = *shared->get_double_buffer();
ir::phi_node *phi = info.phi;
PHINode *ptr = (PHINode*)shmems_[phi];
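
To make the index rotation in finalize_shared_layout concrete, here is a schematic Python model (plain integers, not the emitted LLVM PHI nodes) of how the read/write shared-memory stage indices advance each loop iteration and wrap at num_stages, using the initial values chosen above (read starts at 2, write at num_stages - 1):

def rotate(idx, num_stages):
    # if (*_smem_idx == num_stages - 1) *_smem_idx = 0; else ++*_smem_idx;
    return 0 if idx == num_stages - 1 else idx + 1

num_stages = 4                        # hypothetical stage count
read_idx, write_idx = 2, num_stages - 1
for it in range(6):
    print(it, read_idx, write_idx)    # which stage is read from / written to
    read_idx = rotate(read_idx, num_stages)
    write_idx = rotate(write_idx, num_stages)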

View File

@@ -4,6 +4,7 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -20,9 +21,14 @@ namespace transform{
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
analysis::shared_layout* layout = layouts_->get(v)->to_shared();
analysis::double_buffer_info_t* info = layout->get_double_buffer();
if(info)
if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
return group_of(info->first, async_write);
else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
if (v == info->phi)
return group_of(info->firsts[0], async_write);
else // prefetched value
return group_of(info->firsts[1], async_write);
}
std::vector<int> groups(phi->get_num_operands());
std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
return *std::max_element(groups.begin(), groups.end());
@@ -69,12 +75,31 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
return ret;
}
bool membar::check_safe_war(ir::instruction* i) {
bool is_i_shared_block = i->get_type()->is_block_ty() &&
layouts_->get(i)->to_shared();
bool is_i_double_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_double_buffer();
bool is_i_n_buffered = is_i_shared_block &&
layouts_->get(i)->to_shared()->get_N_buffer();
if (is_i_double_buffered || is_i_n_buffered) {
// with async copy & prefetch_s disabled, WARs are not safe
if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
return false;
else
return true;
}
return false;
}
void membar::transfer(ir::basic_block *block,
val_vec_t& async_write,
val_set_t& sync_write,
val_set_t& sync_read,
std::set<ir::value*>& safe_war,
bool& inserted, ir::builder& builder) {
std::vector<ir::async_wait_inst*> async_waits;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(ir::instruction *i: instructions){
if(dynamic_cast<ir::phi_node*>(i))
@@ -105,18 +130,14 @@ void membar::transfer(ir::basic_block *block,
async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
async_waits.push_back(async_wait);
}
}
// RAW, WAR
bool is_i_double_buffered = i->get_type()->is_block_ty() &&
layouts_->get(i)->to_shared() &&
layouts_->get(i)->to_shared()->get_double_buffer();
bool is_safe_war = check_safe_war(i);
// WAR barrier is not required when data is double-buffered
// TODO: how about other patterns, like WWAR?
if(!intersect_with(read, sync_write).empty() ||
(!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
// force WAR barrier on A100
(!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
(!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
builder.set_insert_point(i);
barrier = (ir::barrier_inst*)builder.create_barrier();
inserted = true;
@@ -132,7 +153,41 @@ void membar::transfer(ir::basic_block *block,
sync_read.clear();
}
sync_read.insert(read.begin(), read.end());
}
// coalesce barriers
// fixme: to support more general cases
if (async_waits.size() == 2) {
// (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
for (int idx=0; idx<async_waits.size()-1; ++idx) {
ir::async_wait_inst *first_async_wait = async_waits[idx];
std::vector<ir::instruction*> to_erase;
ir::basic_block::inst_list_t instructions = block->get_inst_list();
for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
ir::instruction *i = *iter;
if (static_cast<ir::instruction*>(first_async_wait) == i) {
// peek at the next 5 instructions
auto peak_iter = std::next(iter);
if (std::distance(peak_iter, instructions.end()) >= 5) {
auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
int first_n = first_async_wait->get_N();
int second_n = second_async_wait->get_N();
to_erase.push_back(second_async_wait);
to_erase.push_back(second_bar);
first_async_wait->set_N(second_n);
}
} else
break;
for (ir::instruction *i : to_erase)
block->erase(i);
}
}
}
}
}
@@ -144,7 +199,7 @@ void membar::run(ir::module &mod) {
std::set<ir::value*> safe_war;
for(const auto& x: layouts_->get_all()){
analysis::shared_layout* layout = x.second->to_shared();
if(!layout || !layout->get_double_buffer())
if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
continue;
for(ir::value *v: layout->get_values())
if(v != layout->get_double_buffer()->phi){
@@ -153,7 +208,6 @@ void membar::run(ir::module &mod) {
}
for(ir::function *fn: mod.get_function_list()){
// TODO: (dyan) we need DominatorTree here.
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
std::map<ir::basic_block*, val_vec_t> async_writes;
std::map<ir::basic_block*, val_set_t> sync_writes;
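
The coalescing step above can be pictured on a flat instruction stream. A schematic Python sketch of the "aw N; bar; prefetch; aw N-1; bar; prefetch => aw N-1; bar; 2*prefetch" rewrite (token tuples stand in for the IR instruction classes):

def coalesce(insts):
    out = list(insts)
    for i, ins in enumerate(out):
        if ins[0] != "async_wait" or i + 5 >= len(out):
            continue
        window = out[i + 1:i + 6]
        if [w[0] for w in window] == ["barrier", "prefetch", "async_wait", "barrier", "prefetch"]:
            out[i] = ("async_wait", window[2][1])  # first wait adopts the smaller N
            del out[i + 3:i + 5]                   # erase the second async_wait + barrier
            return out
    return out

stream = [("async_wait", 2), ("barrier",), ("prefetch",),
          ("async_wait", 1), ("barrier",), ("prefetch",)]
print(coalesce(stream))
# -> [('async_wait', 1), ('barrier',), ('prefetch',), ('prefetch',)]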

View File

@@ -23,6 +23,60 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
recursive_deps(u, block, ret);
}
/// assumes the incoming block index is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
throw std::runtime_error("Don't have that phi node\n");
return prev_phi_vals.at(phi);
}
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
auto instr = dynamic_cast<ir::instruction*>(cond);
for (auto op : instr->ops()) {
if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
phis.insert(phi_op);
return;
}
if (dynamic_cast<ir::instruction*>(op))
get_induction_vars(op, phis);
}
}
/// Returns phi_val if it sees a phi node
ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi_val;
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_val(builder, op, phi_val));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
@@ -41,6 +95,28 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
return ret;
}
/// moving the prev phi vals to the next iteration
void update_prev_phi_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
for (auto& [phi, val] : prev_phi_vals) {
// TODO: handling nested phis
val = rematerialize_val(builder, phi->get_incoming_value(1), val);
}
}
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
for (auto& [phi, val] : load_ivs) {
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs);
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
// cache next_k (to be used by next_mask)
next_load_ivs[phi] = next_k;
} else
throw std::runtime_error("must be phi");
}
}
void pipeline::run(ir::module &mod) {
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
@@ -60,6 +136,8 @@ void pipeline::run(ir::module &mod) {
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
@@ -70,6 +148,106 @@ void pipeline::run(ir::module &mod) {
assert(block_br);
assert(header_br);
ir::type* ty = load->get_type();
// multi-stage pipe
if (has_copy_async_ && num_stages > 2) {
ir::value* header_cond = header_br->get_cond();
ir::value* block_cond = block_br->get_cond();
// 1. collect induction variables
std::set<ir::phi_node*> induction_vars;
get_induction_vars(block_cond, induction_vars);
std::vector<ir::value*> first_ptrs(num_stages-1);
std::vector<ir::value*> first_loads(num_stages-1);
std::vector<ir::value*> first_masks(num_stages-1);
std::vector<ir::value*> loop_conds(num_stages-1);
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
// initialize prev_phi_vals
// note: we assume that ptr & other values only depend on ptr & iv (phis)
// TODO: can we just add all phis here?
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
for (ir::phi_node* iv : induction_vars)
prev_phi_vals[iv] = iv->get_value_for_block(header);
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
loop_conds[0] = header_cond;
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
ir::value* false_value = nullptr;
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
update_prev_phi_vals(builder, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
}
// create new phis for induction variables
builder.set_insert_point(block->get_first_non_phi());
std::map<ir::phi_node*, ir::value*> load_ivs;
std::map<ir::phi_node*, ir::value*> next_load_ivs;
for (ir::phi_node* iv : induction_vars) {
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
pn->add_incoming(prev_phi_vals[iv], header);
load_ivs[iv] = pn;
}
// add incoming for phis & update next_load_ivs
finalize_iv_vals(builder, load_ivs, next_load_ivs);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
// TODO: false may depend on some other phi nodes
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
// phi node
ptr->set_incoming_value(0, first_ptrs.back());
builder.set_insert_point(block->get_first_non_phi());
// nested phis for load
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
for (auto& pn : new_load_phis)
pn = builder.create_phi(ty, 2);
for (int i=0; i<num_stages-2; ++i) {
new_load_phis[i]->add_incoming(first_loads[i], header);
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
}
new_load_phis.back()->add_incoming(first_loads.back(), header);
new_load_phis.back()->add_incoming(next_load, block);
load->replace_all_uses_with(new_load_phis.front());
new_loads.push_back(new_load_phis.back());
// record first_loads to reorder them
preheader_loads.push_back({new_load_phis.front(), first_loads});
} else {
// pre-fetch first iteration
builder.set_insert_point(header->get_inst_list().back());
ir::value* first_ptr = ptr->get_value_for_block(header);
@@ -103,7 +281,22 @@ void pipeline::run(ir::module &mod) {
load->replace_all_uses_with(new_load);
new_loads.push_back(new_load);
}
}
// try to reorder prefetched values from a0, a1, a2, ..., b0, b1, b2, ... to
// a0, b0, a1, b1, ...
if (!preheader_loads.empty()) {
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
builder.set_insert_point(header->get_inst_list().back());
for (int i=1; i<num_stages-1; ++i) {
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
ir::instruction* moved_load = original_load->clone();
builder.insert(moved_load);
original_load->replace_all_uses_with(moved_load);
}
}
}
// try to move dot_inst after loads
// for better overlap of io and compute
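
A schematic Python model of the loop shape the multistage path above produces (plain lists stand in for shared-memory stages and Triton IR): with num_stages = N, the prologue issues N - 1 loads, and each loop iteration consumes the oldest in-flight stage while issuing the load consumed at iteration i + N - 1:

def pipelined_loop(tiles, num_stages, compute):
    n = len(tiles)
    in_flight = [tiles[i] for i in range(min(num_stages - 1, n))]  # prologue loads
    for i in range(n):
        data = in_flight.pop(0)           # oldest prefetched stage feeds the compute
        nxt = i + num_stages - 1
        if nxt < n:
            in_flight.append(tiles[nxt])  # prefetch a future iteration's tile
        compute(i, data)

pipelined_loop(list("abcdef"), num_stages=3, compute=lambda i, d: print(i, d))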

View File

@@ -25,13 +25,12 @@ static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::in
}
void prefetch::run(ir::module &mod) {
// 1. collect dot that can be prefetched
// 1. collect dots that can be prefetched
std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is fp16 & volta/turing
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID ||
tgt_->as_nvidia()->sm() >= 80)
// Now only do prefetching when dot is fp16
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID)
return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
@@ -56,16 +55,31 @@ void prefetch::run(ir::module &mod) {
// 1. in the loop header (first iteration)
builder.set_insert_point(loop_header->get_inst_list().back());
builder.create_barrier();
assert(a && b);
builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
// 2. at the end of the loop body (next iteration)
builder.set_insert_point(loop_body->get_inst_list().back());
builder.create_barrier();
builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
prefetched_vals_.insert(a->get_incoming_value(0));
prefetched_vals_.insert(b->get_incoming_value(0));
// nested phis
ir::value* next_a = a->get_incoming_value(1);
while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
next_a = next_a_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_a);
ir::value* next_b = b->get_incoming_value(1);
while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
next_b = next_b_phi->get_incoming_value(1);
}
prefetched_vals_.insert(next_b);
}
// move loads to the beginning of the loop

View File

@@ -324,7 +324,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device)
}
catch(exception::cuda::invalid_ptx const &){
//#ifdef TRITON_LOG_PTX_ERROR
std::cout << ptx << std::endl;
// std::cout << ptx << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
// exit(1);
//#endif

View File

@@ -77,6 +77,54 @@ void print(module &mod, std::ostream& os) {
}
}
void print(function &fn, std::ostream &os) {
//
}
void print(basic_block &bb, std::ostream &os) {
auto const &predecessors = bb.get_predecessors();
os << bb.get_name() << ":";
if(!predecessors.empty()){
os << " ";
os << "; preds = ";
auto const &predecessors = bb.get_predecessors();
for(ir::basic_block *pred: predecessors)
os << pred->get_name() << (pred!=predecessors.back()?", ":"");
}
os << std::endl;
for(ir::instruction *inst: bb.get_inst_list()){
print(*inst, os);
}
}
void print(instruction &instr, std::ostream &os) {
instruction *inst = &instr;
os << " ";
if(!inst->get_type()->is_void_ty()){
os << instr.get_name();
os << " = ";
}
ir::type* type = inst->get_type();
os << inst->repr() << " " << type->repr();
ir::instruction::ops_t ops = inst->ops();
size_t num_ops = inst->get_num_operands();
if(num_ops > 0)
os << " ";
for(unsigned i = 0; i < num_ops; i++){
if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
os << x->repr();
else
os << ops[i]->get_name();
os << (i < num_ops - 1?", ":"");
}
os << ";";
// os << " (";
// for(ir::user* usr: inst->get_users())
// os << get_name(usr, cnt++) << ", " ;
// os << " )";
os << std::endl;
}
}
}

View File

@@ -33,7 +33,10 @@ void init_triton_driver(py::module &&m) {
CUdevice handle;
drv::dispatch::cuDeviceGet(&handle, dev_id);
return new drv::cu_device(handle, take_ownership);
}));
}))
.def("max_shared_memory", [](drv::cu_device *self) {
return self->max_shared_memory();
});
// host device
py::class_<drv::host_device, drv::device>(m, "host_device")
.def(py::init<>());
@@ -75,11 +78,11 @@ void init_triton_driver(py::module &&m) {
void init_triton_codegen(py::module &&m) {
m.def(
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) {
"add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
drv::module *mod;
drv::kernel *ker;
size_t shared_mem;
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
std::stringstream ss;
ir::print(ir, ss);
return std::make_tuple(mod, ker, shared_mem, ss.str());
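
The new max_shared_memory binding is what the Python front end compares a kernel's shared-memory requirement against at compile time; a small sketch of querying it directly (the import path and device handling are assumed from the surrounding bindings, not shown in this diff):

import torch
import triton._C.libtriton.triton as _triton  # assumed binding module path

# build a driver device for the current CUDA device and read its limit;
# add_passes_to_emit_bin's shared_mem result is checked against this value
dev = _triton.driver.cu_device(torch.cuda.current_device(), False)
print(dev.max_shared_memory())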

View File

@@ -27,7 +27,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = op(ra, rb)
rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest)
# torch result
ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b

View File

@@ -5,56 +5,69 @@ import torch
@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
itertools.chain(
*[
[
# 1 warp
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
# 2 warp
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
# 4 warp
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
# 8 warp
(128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
# # split-k
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
# # variable input
(128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
# split-k
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
configs = [triton.Config(meta=META, num_warps=NWARP)]
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []
@@ -72,5 +85,5 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.ops.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
assert triton.testing.allclose(th_c, tt_c)

View File

@@ -407,13 +407,14 @@ class CodeGenerator(ast.NodeVisitor):
class Binary:
def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
# cache ir asm
self.ir_asm = ir_asm
self.module = module
self.kernel = kernel
self.shared_mem = shared_mem
self.num_warps = num_warps
self.num_stages = num_stages
self.sass = None
def asm(self, mode):
@@ -447,6 +448,13 @@ class CompilationError(Exception):
self.message += '\n Error: ' + str(err)
super().__init__(self.message)
class OutOfResources(Exception):
def __init__(self, required, limit, name):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
super().__init__(self.message)
class Kernel:
@staticmethod
@@ -513,7 +521,7 @@ class Kernel:
def __init__(self, fn):
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
# explicitly set device
torch.cuda.set_device(device.index)
# create IR module
@@ -535,10 +543,12 @@ class Kernel:
raise CompilationError(self.fn.src, node, e)
tt_device = _triton.driver.cu_device(device.index, False)
# Compile to machine code
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
return Binary(mod, ker, num_warps, shared_mem, ir_asm)
mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
if shared_mem > tt_device.max_shared_memory():
raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
def __call__(self, *wargs, grid, num_warps=4, **meta):
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
@@ -554,12 +564,12 @@ class Kernel:
attr_key = frozenset(attributes.items())
meta_key = frozenset(meta.items())
const_key = frozenset(constants.items())
key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
cache = self.fn.cache
if key not in cache:
# compile and cache configuration if necessary
cache[key] = self._compile(
*wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
*wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -585,7 +595,7 @@ class Launcher:
class Autotuner:
def __init__(self, kernel, arg_names, configs, key):
if not configs:
self.configs = [Config(dict(), num_warps=4)]
self.configs = [Config(dict(), num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
@@ -603,7 +613,7 @@ class Autotuner:
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.meta)
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call)
def __call__(self, *args, **meta):
@@ -616,7 +626,7 @@ class Autotuner:
config = self.cache[key]
else:
config = self.configs[0]
return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
class JITFunction:
@@ -671,9 +681,10 @@ class JITFunction:
class Config:
def __init__(self, meta, num_warps=4):
def __init__(self, meta, num_warps=4, num_stages=2):
self.meta = meta
self.num_warps = num_warps
self.num_stages = num_stages
def autotune(configs, key):
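
Tying this file's changes together, a hedged sketch of an autotuning space that sweeps the new num_stages knob (block sizes are hypothetical); configs whose N-buffered shared-memory footprint exceeds the device limit now surface as OutOfResources rather than an opaque driver error:

import triton

# num_stages joins num_warps in the cache key, so each combination below is
# compiled (and its shared-memory usage checked) separately
configs = [
    triton.Config(meta={'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32},
                  num_warps=w, num_stages=s)
    for w in (4, 8)
    for s in (2, 3, 4)
]
# these would then be handed to the autotune(configs, key) helper shown above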

View File

@@ -1,5 +1,6 @@
import torch
import os
from .code_gen import OutOfResources
try:
import triton._C.libtriton.cutlass as _cutlass
@@ -8,6 +9,15 @@ except ImportError:
_cutlass = None
has_cutlass = False
def catch_oor(kernel, pytest_handle=None):
try:
res = kernel()
except OutOfResources as e:
if pytest_handle:
pytest_handle.skip(str(e))
return None
return res
def sparsify_tensor(x, mask, block):
ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
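
Finally, a short sketch of how catch_oor is intended to be used in tests, mirroring the matmul test changes above (shapes are illustrative):

import pytest
import torch
import triton
import triton.testing


def test_matmul_skips_on_oor():
    a = torch.randn((128, 256), device='cuda', dtype=torch.float16)
    b = torch.randn((256, 128), device='cuda', dtype=torch.float16)
    th_c = torch.matmul(a, b)
    # a config that needs more shared memory than the device offers is
    # turned into a pytest.skip instead of a hard failure
    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
    assert triton.testing.allclose(th_c, tt_c)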