[CODEGEN] Add a pass to prefetch operands of dot if applicable. (#105)
* update membar pass when data is double buffered * Add instruction prefetch_s * prefetch tests pass (except the 1 warp case) * Fix the 1-warp bug * Add back prefetch files * Disable prefetch on a100 * Always add war barrier on sm>=80
This commit is contained in:
committed by
Philippe Tillet
parent
147675923e
commit
967e629c0c
@@ -11,6 +11,7 @@
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class PHINode;
|
||||
class BasicBlock;
|
||||
class Attribute;
|
||||
class Instruction;
|
||||
@@ -169,6 +170,7 @@ public:
|
||||
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
@@ -209,16 +211,19 @@ private:
|
||||
std::map<analysis::data_layout*, Value*> offset_b_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_n_;
|
||||
|
||||
/// layout -> base ptr
|
||||
std::map<analysis::data_layout*, Value*> shared_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
|
||||
/// offset for double-buffered layout
|
||||
std::map<analysis::data_layout*, Value*> shared_off_;
|
||||
|
||||
|
||||
/// Base shmem pointer of ir value
|
||||
std::map<ir::value*, Value*> shmems_;
|
||||
std::map<ir::value*, Value*> shoffs_;
|
||||
std::map<ir::value*, std::vector<indices_t>> idxs_;
|
||||
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
|
||||
/// triton bb -> llvm bb
|
||||
std::map<ir::value*, BasicBlock *> bbs_;
|
||||
std::map<ir::value*, std::vector<int>> ords_;
|
||||
|
||||
@@ -227,6 +232,11 @@ private:
|
||||
multiplier mul;
|
||||
geper gep;
|
||||
|
||||
/// PHI nodes
|
||||
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
|
||||
|
||||
/// Record prefetch instructions that need to be moved
|
||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
@@ -44,14 +45,16 @@ private:
|
||||
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
|
||||
|
||||
public:
|
||||
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
|
||||
liveness_(liveness), layouts_(layouts), alloc_(alloc) {}
|
||||
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt):
|
||||
liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::allocation *alloc_;
|
||||
|
||||
target* tgt_;
|
||||
};
|
||||
|
||||
|
||||
|
22
include/triton/codegen/transform/prefetch.h
Normal file
22
include/triton/codegen/transform/prefetch.h
Normal file
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
|
||||
// Forward declarations: this header only needs the names, not the definitions.
namespace triton::ir{
class module;
}

namespace triton::codegen {
class target;
}

namespace triton::codegen::transform {

/// Transformation pass applied to an ir::module through run().
/// The pass stores the raw target pointer it was constructed with.
class prefetch {
  target* tgt_;  ///< target handed to the constructor, stored as a raw pointer
public:
  /// @param tgt code generation target queried by run().
  /// Declared explicit so that a bare target* is never converted into a pass
  /// implicitly; direct initialization `prefetch p(tgt);` keeps working.
  explicit prefetch(target *tgt) : tgt_(tgt) {}

  /// Apply the pass to `module` (defined out of line).
  void run(ir::module &module);
};
}
|
||||
|
||||
#endif
|
@@ -151,6 +151,7 @@ public:
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
value *create_prefetch_s(value *arg, int inc);
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
|
@@ -143,7 +143,8 @@ enum value_id_t: unsigned {
|
||||
INST_ASYNC_WAIT,
|
||||
INST_MAKE_RANGE_DYN,
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE
|
||||
INST_MAKE_RANGE,
|
||||
INST_PREFETCH_S,
|
||||
};
|
||||
|
||||
|
||||
|
@@ -666,6 +666,11 @@ private:
|
||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "dot"; }
|
||||
|
||||
bool is_prefetched_ = false;
|
||||
public:
|
||||
bool is_prefetched() const { return is_prefetched_; }
|
||||
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
|
||||
|
||||
public:
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -821,6 +826,23 @@ private:
|
||||
int N_;
|
||||
};
|
||||
|
||||
class prefetch_s_inst : public instruction {
|
||||
std::string repr_impl() const { return "prefetch_s"; }
|
||||
_TRITON_DEFINE_CLONE(prefetch_s_inst)
|
||||
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
|
||||
|
||||
/// inc_: 0->first, 1->latch
|
||||
int inc_ = 0;
|
||||
public:
|
||||
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
|
||||
set_operand(0, arg);
|
||||
}
|
||||
int get_inc() const { return inc_; }
|
||||
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
|
||||
instruction *next=nullptr);
|
||||
};
|
||||
|
||||
//// On NVIDIA, implementation is such that
|
||||
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||
//// so as to enable re-association on nv_static_program_idx which is constant
|
||||
|
@@ -70,6 +70,7 @@ class barrier_inst;
|
||||
class async_wait_inst;
|
||||
class make_range_dyn;
|
||||
class make_range;
|
||||
class prefetch_s_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
@@ -146,6 +147,7 @@ public:
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
|
@@ -12,6 +12,7 @@
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/pipeline.h"
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/driver/module.h"
|
||||
@@ -44,11 +45,12 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
codegen::analysis::liveness liveness(&layouts);
|
||||
codegen::analysis::swizzle swizzle(&layouts, target.get());
|
||||
codegen::analysis::allocation allocation(&liveness);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation);
|
||||
codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get());
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole(target.get(), &layouts);
|
||||
// codegen::transform::reassociate reassociate;
|
||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||
codegen::transform::prefetch prefetch_s(target.get());
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
|
||||
// run passes
|
||||
dce.run(ir);
|
||||
@@ -90,8 +92,9 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
|
||||
swizzle.run(ir);
|
||||
liveness.run(ir);
|
||||
allocation.run(ir);
|
||||
barriers.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
prefetch_s.run(ir);
|
||||
barriers.run(ir);
|
||||
// ir::print(ir, std::cout);
|
||||
isel.visit(ir, *llvm);
|
||||
mod = driver::module::create(dev, std::move(llvm));
|
||||
ker = driver::kernel::create(&*mod, name.c_str());
|
||||
|
@@ -1048,65 +1048,208 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
|
||||
for(indices_t idx: idxs_.at(C))
|
||||
acc.push_back(vals_[D][idx]);
|
||||
|
||||
// update accumulators
|
||||
unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0);
|
||||
unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1);
|
||||
for(unsigned K = 0; K < NK; K += 4)
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++) {
|
||||
if(has.find({m, K}) == has.end()){
|
||||
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
|
||||
|
||||
// update accumulators
|
||||
if (C->is_prefetched()) {
|
||||
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
|
||||
/// Cache lds value. If values are prefetched, create phi node
|
||||
auto register_lds =
|
||||
[&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, ir::value *v) {
|
||||
if (K == 0) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
|
||||
} else
|
||||
vals[{m, K}] = {val0, val1};
|
||||
};
|
||||
|
||||
auto load_a = [&](int m, int K, int inc) ->void {
|
||||
int offidx = (is_a_row ? K/4 : m) % num_ptr_a;
|
||||
Value* ptra;
|
||||
if(K==0){
|
||||
if(inc == 0) {
|
||||
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
else {
|
||||
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
|
||||
}
|
||||
}
|
||||
else
|
||||
ptra = ptr_a[offidx];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1)
|
||||
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha);
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
has[{m, K}] = {ha00, ha01};
|
||||
register_lds(has, m, K, inc, ha00, ha01, A);
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
has[{m, K+4}] = {ha10, ha11};
|
||||
register_lds(has, m, K+4, inc, ha10, ha11, A);
|
||||
else
|
||||
has[{m+1, K}] = {ha10, ha11};
|
||||
register_lds(has, m+1, K, inc, ha10, ha11, A);
|
||||
}
|
||||
}
|
||||
if(hbs.find({n, K}) == hbs.end()){
|
||||
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
|
||||
};
|
||||
|
||||
auto load_b = [&](int n, int K, int inc){
|
||||
int offidx = (is_b_row? n : K/4) % num_ptr_b;
|
||||
Value* ptrb;
|
||||
if(K==0){
|
||||
if(inc == 0) {
|
||||
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
|
||||
}
|
||||
else {
|
||||
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
|
||||
}
|
||||
}
|
||||
else
|
||||
ptrb = ptr_b[offidx];
|
||||
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
// record lds that needs to be moved
|
||||
if (K == 0 && inc == 1)
|
||||
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb);
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
hbs[{n, K}] = {hb00, hb01};
|
||||
register_lds(hbs, n, K, inc, hb00, hb01, B);
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
hbs[{n+1, K}] = {hb10, hb11};
|
||||
register_lds(hbs, n+1, K, inc, hb10, hb11, B);
|
||||
else
|
||||
hbs[{n, K+4}] = {hb10, hb11};
|
||||
register_lds(hbs, n, K+4, inc, hb10, hb11, B);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// create phis
|
||||
builder_->SetInsertPoint(curr_bb->getFirstNonPHI());
|
||||
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) {
|
||||
has[{m, 0}].first = phi(f16x2_ty, 2);
|
||||
has[{m, 0}].second = phi(f16x2_ty, 2);
|
||||
if (!is_a_row && vec_a>4) {
|
||||
has[{m+1, 0}].first = phi(f16x2_ty, 2);
|
||||
has[{m+1, 0}].second = phi(f16x2_ty, 2);
|
||||
}
|
||||
}
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) {
|
||||
hbs[{n, 0}].first = phi(f16x2_ty, 2);
|
||||
hbs[{n, 0}].second = phi(f16x2_ty, 2);
|
||||
if (is_b_row && vec_b>4) {
|
||||
hbs[{n+1, 0}].first = phi(f16x2_ty, 2);
|
||||
hbs[{n+1, 0}].second = phi(f16x2_ty, 2);
|
||||
}
|
||||
}
|
||||
|
||||
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
|
||||
for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2)
|
||||
load_a(m, 0, 0);
|
||||
for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1)
|
||||
load_b(n, 0, 0);
|
||||
|
||||
// update accumulators
|
||||
builder_->SetInsertPoint(curr_bb);
|
||||
for (unsigned K = 0; K < NK; K += 4) {
|
||||
int NEXTK = (K + 4) % NK;
|
||||
// prefetch A
|
||||
for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2)
|
||||
load_a(m, NEXTK, 1);
|
||||
// prefetch B
|
||||
for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1)
|
||||
load_b(n, NEXTK, 1);
|
||||
// tensor core ops
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++){
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
}
|
||||
}
|
||||
} else { // not prefetched
|
||||
for(unsigned K = 0; K < NK; K += 4)
|
||||
for(unsigned m = 0; m < num_m/2; m++)
|
||||
for(unsigned n = 0; n < num_n/2; n++) {
|
||||
if(has.find({m, K}) == has.end()){
|
||||
Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a];
|
||||
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
|
||||
int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K;
|
||||
Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak));
|
||||
Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3)));
|
||||
Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty);
|
||||
Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty);
|
||||
has[{m, K}] = {ha00, ha01};
|
||||
if(vec_a > 4){
|
||||
Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty);
|
||||
Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty);
|
||||
if(is_a_row)
|
||||
has[{m, K+4}] = {ha10, ha11};
|
||||
else
|
||||
has[{m+1, K}] = {ha10, ha11};
|
||||
}
|
||||
}
|
||||
if(hbs.find({n, K}) == hbs.end()){
|
||||
Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b];
|
||||
int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
|
||||
int stepbk = is_b_row ? K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b);
|
||||
Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk));
|
||||
Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3)));
|
||||
Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty);
|
||||
Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty);
|
||||
hbs[{n, K}] = {hb00, hb01};
|
||||
if(vec_b > 4){
|
||||
Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty);
|
||||
Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty);
|
||||
if(is_b_row)
|
||||
hbs[{n+1, K}] = {hb10, hb11};
|
||||
else
|
||||
hbs[{n, K+4}] = {hb10, hb11};
|
||||
}
|
||||
}
|
||||
auto ha = has[{m, K}];
|
||||
auto hb = hbs[{n, K}];
|
||||
// arguments
|
||||
std::vector<size_t> idx = {
|
||||
(m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m,
|
||||
(m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m,
|
||||
(m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m,
|
||||
(m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m
|
||||
};
|
||||
std::vector<Value*> args = {ha.first, ha.second, hb.first, hb.second};
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
args.push_back(acc[idx[i]]);
|
||||
// execute mma
|
||||
Value *nc = call(mma, args);
|
||||
// unpack
|
||||
for(unsigned i = 0; i < 8; i++)
|
||||
acc[idx[i]] = extract_val(nc, {i});
|
||||
}
|
||||
}
|
||||
|
||||
// write back accumulators
|
||||
@@ -1827,6 +1970,40 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
|
||||
add_barrier();
|
||||
}
|
||||
|
||||
/// Lower a prefetch_s instruction.
///
/// inc == 0: nothing is emitted here.
/// otherwise (expected to be 1): the LLVM values recorded in
/// prefetch_latch_to_bb_ for the operand, together with all of their
/// transitive users, are moved to the end of the builder's current insert
/// block.
void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
  ir::value *v = i->get_operand(0);
  int inc = i->get_inc();
  // First-iteration prefetch: nothing to do at this point.
  if (inc == 0)
    return;

  // Latch prefetch: the entries must already have been recorded for `v`.
  assert(inc == 1);
  assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
         "dot hasn't be visited");

  // Worklist seeded with the recorded values; users are pushed as they are
  // discovered so that everything depending on a moved value moves with it.
  std::stack<Value*> work_stack;
  for (Value *value : prefetch_latch_to_bb_[v])
    work_stack.push(value);

  while (!work_stack.empty()) {
    Value *m = work_stack.top();
    work_stack.pop();

    for (auto u : m->users())
      work_stack.push(u);

    assert(isa<Instruction>(m));
    auto m_instr = static_cast<Instruction*>(m);

    // Detach the instruction and re-attach it after the current last
    // instruction of the insert block.
    m_instr->removeFromParent();
    m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
    assert(m_instr->getParent() == &*builder_->GetInsertBlock());
    // Keep the builder positioned at the end of that block.
    builder_->SetInsertPoint(m_instr->getParent());
  }
}
|
||||
|
||||
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
|
||||
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
|
||||
@@ -2144,6 +2321,7 @@ void generator::visit_basic_block(ir::basic_block * block) {
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
visit_value(i);
|
||||
}
|
||||
// Update ir bb -> llvm bb mapping
|
||||
bbs_[block] = builder_->GetInsertBlock();
|
||||
}
|
||||
|
||||
@@ -2247,6 +2425,8 @@ void generator::finalize_function(ir::function *fn) {
|
||||
for(ir::instruction *inst: block->get_inst_list())
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(inst))
|
||||
finalize_phi_node(phi);
|
||||
for(auto& x: lazy_phi_incs_)
|
||||
std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]);
|
||||
}
|
||||
|
||||
void generator::finalize_phi_node(ir::phi_node *x) {
|
||||
|
@@ -96,7 +96,15 @@ void membar::transfer(ir::basic_block *block,
|
||||
}
|
||||
}
|
||||
// RAW, WAR
|
||||
if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
|
||||
bool is_i_double_buffered = i->get_type()->is_block_ty() &&
|
||||
layouts_->get(i)->to_shared() &&
|
||||
layouts_->get(i)->to_shared()->get_double_buffer();
|
||||
// WAR barrier is not required when data is double-buffered
|
||||
// TODO: how about other patterns, like WWAR?
|
||||
if(!intersect_with(read, sync_write).empty() ||
|
||||
(!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
|
||||
// force WAR barrier on A100
|
||||
(!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
|
||||
builder.set_insert_point(i);
|
||||
barrier = (ir::barrier_inst*)builder.create_barrier();
|
||||
inserted = true;
|
||||
@@ -133,6 +141,7 @@ void membar::run(ir::module &mod) {
|
||||
}
|
||||
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// TODO: (dyan) we need DominatorTree here.
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
std::map<ir::basic_block*, val_vec_t> async_writes;
|
||||
std::map<ir::basic_block*, val_set_t> sync_writes;
|
||||
@@ -166,4 +175,4 @@ void membar::run(ir::module &mod) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
105
lib/codegen/transform/prefetch.cc
Normal file
105
lib/codegen/transform/prefetch.cc
Normal file
@@ -0,0 +1,105 @@
|
||||
#include "triton/codegen/transform/prefetch.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace triton::codegen::transform {
|
||||
|
||||
/// find defs till phis
|
||||
static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
|
||||
ir::instruction *i = dynamic_cast<ir::instruction*>(v);
|
||||
if (!i || i->get_parent() != bb)
|
||||
return;
|
||||
if (i->get_id() == ir::INST_PHI)
|
||||
return;
|
||||
ret.push_back(i);
|
||||
for (ir::value *op : i->ops())
|
||||
recursive_defs(op, bb, ret);
|
||||
}
|
||||
|
||||
/// Run the pass on `mod`.
///
/// 1. Collect dot instructions whose first operand has an fp16 scalar type
///    (only when the target reports sm() < 80) and whose first two operands
///    are phi nodes with incoming block 1 equal to the phi's own block.
/// 2. Mark each collected dot as prefetched, then, with the builder positioned
///    on the last instruction of incoming block 0 and of the phi's block
///    respectively, create a barrier followed by one prefetch_s per operand
///    (inc = 0 for incoming block 0, inc = 1 for the phi's block).
/// 3. For sm() < 80, in every block that has exactly two predecessors, the
///    second being the block itself, move masked loads and the same-block
///    instructions they depend on in front of the first instruction that is
///    neither a phi nor one of the moved instructions, keeping their
///    original relative order.
void prefetch::run(ir::module &mod) {
  // 1. collect dot that can be prefetched
  std::vector<ir::dot_inst*> to_prefetch;
  ir::for_each_instruction(mod, [&](ir::instruction *i) {
    if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
      // Now only do prefetching when dot is fp16 & volta/turing
      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID ||
          tgt_->as_nvidia()->sm() >= 80)
        return;
      auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
      auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
      if (a && a->get_incoming_block(1) == a->get_parent() &&
          b && b->get_incoming_block(1) == b->get_parent())
        to_prefetch.push_back(dot);
    }
  });

  assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots");
  ir::builder &builder = mod.get_builder();
  // 2. do the prefetching
  for (ir::dot_inst* dot : to_prefetch) {
    auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
    auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
    // Checked before the first dereference; step 1 only keeps dots for which
    // both operands are phi nodes.
    assert(a && b);
    assert(a->get_incoming_block(0) == b->get_incoming_block(0));
    ir::basic_block *loop_header = a->get_incoming_block(0);
    ir::basic_block *loop_body = a->get_parent();

    // mark as prefetched
    dot->set_prefetched(true);

    // 1. in the loop header (first iteration)
    builder.set_insert_point(loop_header->get_inst_list().back());
    builder.create_barrier();
    builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
    builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);

    // 2. at the end of the loop body (next iteration)
    builder.set_insert_point(loop_body->get_inst_list().back());
    builder.create_barrier();
    builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
    builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
  }

  // move loads to the beginning of the loop
  if (tgt_->as_nvidia()->sm() < 80) {
    for (ir::function *fn : mod.get_function_list())
    for (ir::basic_block *bb : fn->blocks()) {
      // only apply to loop body
      if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
        continue;
      // record loads (& dependency) to move
      std::vector<ir::instruction*> loads;
      // record original inst order
      std::map<ir::instruction*, size_t> idx_map;
      size_t idx = 0;
      for (ir::instruction *inst : bb->get_inst_list()) {
        if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
          recursive_defs(i, bb, loads);
        idx_map[inst] = idx;
        idx++;
      }

      // remove duplicates & keep the original input order
      std::sort(loads.begin(), loads.end());
      loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
      // Every entry of `loads` belongs to `bb`, hence has an index; at() keeps
      // the comparator from inserting into the map while sorting.
      std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
        return idx_map.at(a) < idx_map.at(b);
      });

      // Anchor the insertion on the first instruction that is neither a phi
      // nor scheduled to move: each moved instruction is erased from the block
      // before being re-inserted, so the insert position must not be one of
      // them.
      ir::instruction *anchor = nullptr;
      for (ir::instruction *inst : bb->get_inst_list()) {
        if (inst->get_id() == ir::INST_PHI)
          continue;
        if (std::find(loads.begin(), loads.end(), inst) != loads.end())
          continue;
        anchor = inst;
        break;
      }
      // No anchor: the block holds nothing but phis and the instructions that
      // would be moved, so there is nothing to reorder.
      if (!anchor)
        continue;

      builder.set_insert_point(anchor);
      for (ir::instruction *i : loads) {
        bb->erase(i);
        builder.insert(i);
      }
    }
  }
}
|
||||
} // namespace triton::codegen::transform
|
@@ -209,7 +209,8 @@ static std::map<int, int> vptx = {
|
||||
{10020, 65},
|
||||
{11000, 70},
|
||||
{11010, 71},
|
||||
{11020, 72}
|
||||
{11020, 72},
|
||||
{11030, 73},
|
||||
};
|
||||
|
||||
std::string cu_module::compile_llvm_module(llvm::Module* module, driver::device* device) {
|
||||
|
@@ -381,6 +381,9 @@ value *builder::create_async_wait(int N) {
|
||||
return insert(async_wait_inst::create(ctx_, N));
|
||||
}
|
||||
|
||||
/// Create a prefetch_s instruction for `arg` tagged with `inc` and pass it to
/// insert(), whose result is returned.
value *builder::create_prefetch_s(value *arg, int inc) {
  auto *prefetch = prefetch_s_inst::create(ctx_, arg, inc);
  return insert(prefetch);
}
|
||||
|
||||
|
||||
}
|
||||
|
@@ -832,6 +832,10 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string
|
||||
return new async_wait_inst(ctx, N, name, next);
|
||||
}
|
||||
|
||||
// prefetch_s
|
||||
prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, const std::string &name, instruction *next) {
|
||||
return new prefetch_s_inst(ctx, arg, inc, name, next);
|
||||
}
|
||||
|
||||
//// nv_dynamic_program_idx
|
||||
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
|
Reference in New Issue
Block a user