[CODEGEN] Performance improvement on A100 (#125)
Improved codegen for Ampere GPUs:

* Make the layout pass recognize the multistage pipelined pattern.
* The pipeline pass now automates the multistage pipelining transformation.
* Remove extra barriers (from the prefetch pass & WAR) on Ampere.
* Update the code generator (generator.cc) so that Triton generates n-buffered shared memory loads/stores.
committed by Philippe Tillet
parent 5a51f3e529
commit d8d6b715c8
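The multistage ("N-buffer") scheme generalizes the existing double-buffering: instead of two shared-memory stages, a tile rotates through num_stages stages so that cp.async copies for future loop iterations overlap with the dot on the current one. Below is a minimal sketch of the stage-index rotation that the generated code maintains; all names are illustrative, not the actual Triton API (the real logic is emitted as LLVM IR by generator.cc and finalize_shared_layout):

    // Illustrative sketch only (hypothetical names).
    #include <cstddef>

    struct n_buffer_indices {
      int num_stages;          // e.g. 3 or 4 stages on A100
      int read_idx  = 0;       // stage currently consumed by the dot
      int write_idx;           // stage currently filled by cp.async
      explicit n_buffer_indices(int stages)
        : num_stages(stages), write_idx(stages - 1) {}

      // same update the emitted IR performs: if (idx == #stages-1) idx = 0; else idx++;
      static int next(int idx, int stages) { return idx == stages - 1 ? 0 : idx + 1; }

      void advance() {                       // once per loop iteration
        read_idx  = next(read_idx, num_stages);
        write_idx = next(write_idx, num_stages);
      }

      // byte offset of a stage inside the shared-memory allocation
      size_t stage_offset(size_t per_stage_size, int idx) const {
        return per_stage_size * static_cast<size_t>(idx);
      }
    };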
@@ -141,10 +141,19 @@ struct double_buffer_info_t {
   ir::phi_node* phi;
 };
 
+struct N_buffer_info_t {
+  std::vector<ir::value*> firsts; // not necessarily ordered as input order
+  ir::value* latch;
+  ir::phi_node* phi;
+  std::map<ir::value*, int> firsts_idx;
+};
+
+// abstract for dot and coresponding smem values
 class shared_layout: public data_layout {
 private:
   static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
   static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
+  static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
 
 public:
   shared_layout(data_layout *arg,

@@ -158,6 +167,10 @@ public:
   size_t get_size() { return size_; }
   ir::type* get_type() { return ty_; }
   double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
+  N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
+  int get_num_stages() const;
+  size_t get_per_stage_size() const { return size_ / get_num_stages(); }
+  size_t get_per_stage_elements() const;
   size_t get_num_per_phase() { return num_per_phase_; }
   ir::value* hmma_dot_a() { return hmma_dot_a_; }
   ir::value* hmma_dot_b() { return hmma_dot_b_; }

@@ -169,6 +182,7 @@ private:
   size_t size_;
   ir::type *ty_;
   std::shared_ptr<double_buffer_info_t> double_buffer_;
+  std::shared_ptr<N_buffer_info_t> N_buffer_;
   size_t num_per_phase_;
   ir::value* hmma_dot_a_;
   ir::value* hmma_dot_b_;
@@ -21,7 +21,7 @@ namespace codegen{
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps,
+void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages,
                             driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
 
 
@@ -223,6 +223,10 @@ private:
   std::map<ir::value*, Value*> shoffs_;
   std::map<ir::value*, std::vector<indices_t>> idxs_;
   std::map<ir::value*, std::map<indices_t, Value*>> vals_;
+  /// idx for multi-stage pipeline
+  std::map<analysis::data_layout*, Value*> read_smem_idx_;
+  std::map<analysis::data_layout*, Value*> write_smem_idx_;
+
   /// triton bb -> llvm bb
   std::map<ir::value*, BasicBlock *> bbs_;
   std::map<ir::value*, std::vector<int>> ords_;
@@ -32,6 +32,8 @@ class shared_layout;
 
 namespace transform{
 
+class prefetch;
+
 class membar {
 private:
   typedef std::pair<unsigned, unsigned> interval_t;

@@ -40,6 +42,7 @@ private:
 
 private:
   bool intersect(const val_set_t &X, const val_set_t &Y);
+  bool check_safe_war(ir::instruction* i);
   int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
   bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
   val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);

@@ -47,14 +50,16 @@ private:
                 std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
 
 public:
-  membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt):
-    liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {}
+  membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
+         transform::prefetch *prefetch, target* tgt):
+    liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
   void run(ir::module &mod);
 
 private:
   analysis::liveness *liveness_;
   analysis::layouts *layouts_;
   analysis::allocation *alloc_;
+  transform::prefetch *prefetch_;
 
   target* tgt_;
 };
@@ -14,11 +14,13 @@ namespace transform {
 
 class pipeline {
 public:
-  pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
+  pipeline(bool has_copy_async, int num_stages)
+    : has_copy_async_(has_copy_async), num_stages_(num_stages) {}
   void run(ir::module &module);
 
 private:
   bool has_copy_async_;
+  int num_stages_;
 };
 
 } // namespace transform
@@ -1,9 +1,12 @@
 #ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
 #define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
 
+#include <set>
+
 // forward dclaration
 namespace triton::ir{
 class module;
+class value;
 }
 
 namespace triton::codegen {

@@ -13,9 +16,11 @@ class target;
 namespace triton::codegen::transform {
 class prefetch {
   target* tgt_;
+  std::set<ir::value*> prefetched_vals_;
 public:
   prefetch(target *tgt) : tgt_(tgt) {}
   void run(ir::module &module);
+  bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
 };
 }
 
@@ -832,6 +832,7 @@ public:
   static async_wait_inst* create(context &ctx, int N,
                                  const std::string &name = "", instruction *next = nullptr);
   int get_N() { return N_; }
+  void set_N(int n) { N_ = n; }
 
 private:
   int N_;
@@ -1,5 +1,3 @@
-#pragma once
-
 #ifndef _TRITON_IR_PRINT_H_
 #define _TRITON_IR_PRINT_H_
 

@@ -9,8 +7,14 @@ namespace triton{
 namespace ir{
 
 class module;
+class function;
+class basic_block;
+class instruction;
 
 void print(module &mod, std::ostream& os);
+void print(function &func, std::ostream& os);
+void print(basic_block &bb, std::ostream& os);
+void print(instruction &instr, std::ostream& os);
 
 }
 }
@@ -7,6 +7,7 @@
 #include "triton/ir/function.h"
 #include "triton/ir/module.h"
 #include "triton/ir/utils.h"
+// #include "triton/ir/type.h"
 
 namespace triton{
 namespace codegen{
@@ -273,6 +274,81 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
   res.reset(new double_buffer_info_t{value_1, value_0, phi});
 }
 
+static bool is_smem(ir::value* v) {
+  if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
+      dynamic_cast<ir::masked_load_async_inst*>(v))
+    return true;
+  else
+    return false;
+}
+
+/// param:
+///    value_1: next_value
+static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
+                                   std::vector<ir::value*>& values_0, ir::value*& value_1) {
+  ir::value* next = phi;
+  while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
+    // smem from previous bb & phi/smem from current bb
+    ir::value* c0 = cphi->get_incoming_value(0);
+    ir::value* c1 = cphi->get_incoming_value(1);
+    ir::basic_block *cbb0 = cphi->get_incoming_block(0);
+    ir::basic_block *cbb1 = cphi->get_incoming_block(1);
+
+    if (is_smem(c0)) {
+      assert(cbb0 == bb0);
+      values_0.push_back(c0);
+      if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
+        next = phi1;
+        continue;
+      } else {
+        if (is_smem(c1)) {
+          value_1 = c1;
+          assert(cbb1 == bb1);
+          return true;
+        } else {
+          return false;
+        }
+      }
+    } else
+      return false;
+  }
+}
+
+void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
+  auto* phi = dynamic_cast<ir::phi_node*>(v);
+  // if the phi node is nested
+  if (!phi)
+    return;
+
+  ir::basic_block *bb0 = phi->get_incoming_block(0);
+  ir::basic_block *bb1 = phi->get_incoming_block(1);
+
+  std::vector<ir::value*> values_0;
+  ir::value* value_1;
+
+  if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
+    return;
+
+  // double-buffer is a special case
+  if (values_0.size() == 1)
+    return;
+
+  // compute original values_0 input order
+  std::map<ir::value*, int> order;
+  int idx = 0;
+  for (ir::instruction* instr : *bb0) {
+    if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
+      order[static_cast<ir::value*>(instr)] = idx++;
+  }
+  assert(order.size() == values_0.size() && "order size incorrect");
+
+  int curr_stages = values_0.size() + 1;
+  if (curr_stages > prev_stages) {
+    res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
+    prev_stages = curr_stages;
+  }
+}
+
 
 shared_layout::shared_layout(data_layout *arg,
                              const std::vector<int>& axes,
@@ -284,9 +360,15 @@ shared_layout::shared_layout(data_layout *arg,
   size_ = 0;
   arg_layout_ = arg;
 
+  // N-stage buffering
+  int prev_stages = 0;
+  for (ir::value *v : values)
+    extract_N_bufferable(v, N_buffer_, prev_stages);
+
   // double-buffering
-  for(ir::value *v: values)
-    extract_double_bufferable(v, double_buffer_);
+  if (!N_buffer_)
+    for(ir::value *v: values)
+      extract_double_bufferable(v, double_buffer_);
 
   // order
   std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
@@ -311,8 +393,22 @@ shared_layout::shared_layout(data_layout *arg,
     size_ *= s;
   if(double_buffer_)
     size_ *= 2;
+  if (N_buffer_) {
+    size_ *= (N_buffer_->firsts.size() + 1);
+  }
 }
 
+int shared_layout::get_num_stages() const {
+  if (double_buffer_)
+    return 2;
+  if (N_buffer_)
+    return N_buffer_->firsts.size() + 1;
+  return 1;
+}
+
+size_t shared_layout::get_per_stage_elements() const {
+  return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8);
+}
 
 /* -------------------------------- *
  * ---- Layouts Inference Pass ---- *
@@ -403,7 +499,6 @@ void layouts::run(ir::module &mod) {
   for(const auto& x: values_)
     create(x.first, x.second);
 
-
   // create temporaries
   size_t id = values_.size();
   ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
@@ -26,7 +26,7 @@ namespace codegen {
 
 // TODO:
 // There should be a proper pass manager there!
-void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
+void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages,
                             driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
   // generate llvm code
   llvm::LLVMContext ctx;

@@ -39,26 +39,27 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
   codegen::analysis::align align;
   codegen::analysis::axes axes;
   codegen::transform::cts cts(cts_use_async);
-  codegen::transform::pipeline pipeline(cts_use_async);
+  codegen::transform::pipeline pipeline(cts_use_async, num_stages);
   codegen::transform::disassociate disassociate;
   codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
   codegen::analysis::liveness liveness(&layouts);
   codegen::analysis::swizzle swizzle(&layouts, target.get());
   codegen::analysis::allocation allocation(&liveness);
-  codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get());
   codegen::transform::dce dce;
   codegen::transform::peephole peephole(target.get(), &layouts);
   // codegen::transform::reassociate reassociate;
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::transform::prefetch prefetch_s(target.get());
+  codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
   codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps);
   // run passes
   dce.run(ir);
   peephole.run(ir);
   dce.run(ir);
+  // ir::print(ir, std::cout);
   pipeline.run(ir);
   dce.run(ir);
-  //ir::print(ir, std::cout);
+  // ir::print(ir, std::cout);
   disassociate.run(ir);
   dce.run(ir);
   align.run(ir);
@@ -212,18 +212,41 @@ void generator::visit_value(ir::value* v) {
     return;
   if(v->get_type()->is_block_ty()){
     if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
-      auto double_buffer = layout->get_double_buffer();
+      analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
+      analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();
+
       // offset
       Value *offset = nullptr;
-      if(double_buffer && v == double_buffer->phi)
-        offset = shared_off_[layout];
       // base pointer
       Value *ptr = shared_ptr_[layout];
-      if(double_buffer && v == double_buffer->latch)
-        ptr = shared_next_ptr_[layout];
-      else if(double_buffer && v == double_buffer->first)
-        ptr = shared_pre_ptr_[layout];
+
+      if (n_buffer) {
+        // ptr = base (shared_ptr_[layout]) + smem_idx * size
+        // read_smem_idx
+        if (v == n_buffer->phi) {
+          ptr = shared_ptr_[layout];
+        }
+        // write_smem_idx
+        if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
+          int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v);
+          int elements = write_smem_idx * layout->get_per_stage_elements();
+          ptr = gep(shared_pre_ptr_[layout], i32(elements));
+        } else if (v == n_buffer->latch) {
+          Value* write_smem_idx = write_smem_idx_[layout];
+          Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements()));
+          ptr = gep(shared_pre_ptr_[layout], elements);
+        }
+      } else if (double_buffer) {
+        if(v == double_buffer->phi)
+          offset = shared_off_[layout];
+        if(v == double_buffer->latch)
+          ptr = shared_next_ptr_[layout];
+        else if(v == double_buffer->first)
+          ptr = shared_pre_ptr_[layout];
+      } // else do nothing
+      // what visit_dot & vist_cts & ... see
       shmems_[v] = ptr;
+      // now only latches have offset (PHINode), only used by finalize_share_layout()
       shoffs_[v] = offset;
     }
   }
@@ -1223,24 +1246,21 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
  * \brief Code Generation for `mma.16816` (A100)
  */
 //TODO: clean-up
-void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
-  const auto& shapes = dot->get_type()->get_block_shapes();
+void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) {
+  const std::vector<unsigned>& shapes = C->get_type()->get_block_shapes();
 
   std::map<std::vector<Value*>, std::vector<Value*>> fcs;
-  for(indices_t idx: idxs_.at(dot)){
+  for(indices_t idx: idxs_.at(C)){
     std::vector<Value*> key(idx.size() - 2);
     std::copy(idx.begin() + 2, idx.end(), key.begin());
     fcs[key].push_back(vals_[D][idx]);
   };
 
   auto shape_a = A->get_type()->get_block_shapes();
   auto shape_b = B->get_type()->get_block_shapes();
   auto ord_a = layouts_->get(A)->get_order();
   auto ord_b = layouts_->get(B)->get_order();
-  analysis::mma_layout* layout = layouts_->get(dot)->to_mma();
-  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(dot->get_operand(0));
-  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(dot->get_operand(1));
+  analysis::mma_layout* layout = layouts_->get(C)->to_mma();
+  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
+  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
   bool is_a_row = ord_a[0] == 1;
   bool is_b_row = ord_b[0] == 1;
   std::string a_trans = is_a_row ? "" : ".trans";

@@ -1264,8 +1284,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
   int vec_a = 8;
   int vec_b = 8;
 
-
-
   Type *fp32_ty = f32_ty;
   Type *fp16x2_ty = vec_ty(f16_ty, 2);
   Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});

@@ -1276,7 +1294,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
   std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
   std::map<std::pair<unsigned, unsigned>, Value*> hb;
 
-
   BasicBlock* CurrBB = builder_->GetInsertBlock();
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
   builder_->SetInsertPoint(FirstBB->getTerminator());
@@ -1339,66 +1356,167 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir
                                             "{$0, $1, $2, $3}, "
                                             "{$4, $5, $6, $7}, "
                                             "{$8, $9}, "
-                                            "{$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", false);
+                                            "{$10, $11, $12, $13};",
+                                            "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
 
   unsigned num_rep_0 = shapes[0] / layout->spt(0);
   unsigned num_rep_1 = shapes[1] / layout->spt(1);
-  for(unsigned K = 0; K < NK; K += 16)
-    for(unsigned m = 0; m < num_rep_0; m++)
-      for(unsigned n = 0; n < num_rep_1; n++){
-        if(ha.find({m, K}) == ha.end()){
-          Value* ptra = ptrs_a[(is_a_row ? K/16 : m) % num_ptr_a];
+
+  // create mma & unpack result
+  auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
+    unsigned cols_per_thread = num_rep_0 * 2;
+    std::vector<size_t> idx = {
+      (m*2 + 0) + (n*2 + 0)*cols_per_thread,
+      (m*2 + 0) + (n*2 + 1)*cols_per_thread,
+      (m*2 + 1) + (n*2 + 0)*cols_per_thread,
+      (m*2 + 1) + (n*2 + 1)*cols_per_thread
+    };
+    Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
+                                      hb[{n, K}], hb[{n, K+8}],
+                                      fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
+    fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
+    fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
+    fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
+    fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
+  };
+
+  ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
+  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
+
+  auto register_lds =
+    [&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
+      if (K <= 8 && is_prefetch) {
+        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
+      } else
+        vals[{m, K}] = {val0, val1};
+    };
+
+  auto register_lds2 =
+    [&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
+      if (K <= 8 && is_prefetch) {
+        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
+      } else
+        vals[{m, K}] = val;
+    };
+
+  auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
+    int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
+    Value* ptra;
+    if(K == 0 && is_prefetch){
+      if(inc == 0)
+        ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
+      else
+        ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
+    }
+    else
+      ptra = ptrs_a[offidx];
     int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
     int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
     InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
-                                                   "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", "=r,=r,=r,=r,r", false);
+                                                   "{$0, $1, $2, $3}, [$4 + " +
+                                                   std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
+                                                   "=r,=r,=r,=r,r", true);
     Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
+    if(K == 0 && inc == 1 && is_prefetch)
+      prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
     Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
     Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
     Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
     Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
-    ha[{m, K}] = std::make_pair(ha0, ha1);
-    ha[{m, K+8}] = std::make_pair(ha2, ha3);
-  }
-  if(hb.find({n, K})==hb.end()){
-    Value* ptrb = ptrs_b[(is_b_row ? n : K/16) % num_ptr_b];
+    register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
+    register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
+  };
+
+  auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
+    int offidx = (is_b_row ? n : K/16) % num_ptr_b;
+    Value* ptrb;
+    if(K == 0 && is_prefetch){
+      if(inc == 0)
+        ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
+      else
+        ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
+    }
+    else
+      ptrb = ptrs_b[offidx];
     int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
     int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
     InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
-                                                  "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", "=r,=r,=r,=r,r", false);
+                                                  "{$0, $1, $2, $3}, [$4 + " +
+                                                  std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
+                                                  "=r,=r,=r,=r,r", true);
     Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
+    if(K == 0 && inc == 1 && is_prefetch)
+      prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
     Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
     Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
    Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
     Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
-    hb[{n, K}] = hb0;
-    hb[{n+1, K}] = hb2;
-    hb[{n, K+8}] = hb1;
-    hb[{n+1, K+8}] = hb3;
-  }
-  unsigned cols_per_thread = num_rep_0 * 2;
-  std::vector<size_t> idx = {
-    (m*2 + 0) + (n*2 + 0)*cols_per_thread,
-    (m*2 + 0) + (n*2 + 1)*cols_per_thread,
-    (m*2 + 1) + (n*2 + 0)*cols_per_thread,
-    (m*2 + 1) + (n*2 + 1)*cols_per_thread
-  };
-  Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
-                                    hb[{n, K}], hb[{n, K+8}],
-                                    fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
-  fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
-  fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
-  fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
-  fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
-}
+    register_lds2(hb, n, K, inc, hb0, is_prefetch);
+    register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
+    register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
+    register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
+  };
 
+  if (C->is_prefetched()) {
+    // create phis
+    builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
+    for(unsigned m = 0; m < num_rep_0; m++){
+      ha[{m, 0}].first = phi(fp16x2_ty, 2);
+      ha[{m, 0}].second = phi(fp16x2_ty, 2);
+      ha[{m, 8}].first = phi(fp16x2_ty, 2);
+      ha[{m, 8}].second = phi(fp16x2_ty, 2);
+    }
+    for(unsigned n = 0; n < num_rep_1; n+=2){
+      hb[{n, 0}] = phi(fp16x2_ty, 2);
+      hb[{n+1, 0}] = phi(fp16x2_ty, 2);
+      hb[{n, 8}] = phi(fp16x2_ty, 2);
+      hb[{n+1, 8}] = phi(fp16x2_ty, 2);
+    }
+    // insert prefetched lds at the end of loop header
+    builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
+    for(unsigned m = 0; m < num_rep_0; m++)
+      load_a(m, 0, 0, true);
+    for(unsigned n = 0; n < num_rep_1; n+=2)
+      load_b(n, 0, 0, true);
+    // update accumulators
+    builder_->SetInsertPoint(CurrBB);
+    for(unsigned K = 0; K < NK; K += 16){
+      int NEXTK = (K + 16) % NK;
+      // prefetch A
+      for(unsigned m = 0; m < num_rep_0; m++)
+        load_a(m, NEXTK, 1, true);
+      // prefetch B
+      for(unsigned n = 0; n < num_rep_1; n+=2)
+        load_b(n, NEXTK, 1, true);
+      // tensor core ops
+      for(unsigned m = 0; m < num_rep_0; m++)
+      for(unsigned n = 0; n < num_rep_1; n++){
+        call_mma(m, n, K);
+      }
+    }
+  }
+  else{
+    for(unsigned K = 0; K < NK; K += 16)
+      for(unsigned m = 0; m < num_rep_0; m++)
+      for(unsigned n = 0; n < num_rep_1; n++){
+        if(ha.find({m, K}) == ha.end())
+          load_a(m, K, 0, false);
+        if(hb.find({n, K})==hb.end())
+          load_b(n, K, 0, false);
+        call_mma(m, n, K);
+      }
+  }
+
   // write back
   unsigned i = 0;
-  for(indices_t idx: idxs_.at(dot)){
+  for(indices_t idx: idxs_.at(C)){
     std::vector<Value*> key(idx.size() - 2);
     std::copy(idx.begin() + 2, idx.end(), key.begin());
     if(i >= fcs.at(key).size())
       i = 0;
-    vals_[dot][idx] = fcs.at(key)[i++];
+    vals_[C][idx] = fcs.at(key)[i++];
   };
 }
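The prefetched path above restructures the K loop so that ldmatrix loads for slice K+16 are issued before the mma ops that consume slice K. A self-contained sketch of that schedule (the callbacks are placeholders, not Triton functions):

    #include <functional>

    void pipelined_k_loop(unsigned NK,
                          const std::function<void(unsigned)>& load_frags,  // ldmatrix for one K slice
                          const std::function<void(unsigned)>& mma_ops) {   // tensor-core ops for one K slice
      // slice 0 is loaded in the loop header (load_a/load_b with inc == 0)
      load_frags(0);
      for (unsigned K = 0; K < NK; K += 16) {
        unsigned NEXTK = (K + 16) % NK;
        load_frags(NEXTK);   // prefetch operands for the next slice
        mma_ops(K);          // consume the current slice
      }
    }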
@@ -2252,8 +2370,35 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) {
 void generator::visit_layout_shared(analysis::shared_layout* layout) {
   Type* ty = cvt(layout->get_type());
   PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace());
-  // double-buffered
-  if(layout->get_double_buffer()) {
+  if (layout->get_N_buffer()) {
+    // create pointers
+    shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout)));
+    shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty);
+
+    BasicBlock *current = builder_->GetInsertBlock();
+
+    auto info = *layout->get_N_buffer();
+    ir::phi_node *phi = info.phi;
+    BasicBlock *parent = bbs_.at(phi->get_parent());
+    if(parent->empty())
+      builder_->SetInsertPoint(parent);
+    else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) {
+      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
+    } else
+      builder_->SetInsertPoint(parent);
+
+    // create smem_idx
+    read_smem_idx_[layout] = phi(i32_ty, 2);
+    write_smem_idx_[layout] = phi(i32_ty, 2);
+
+    // create pointers
+    // ptr of the current iteration
+    shared_ptr_[layout] = phi(ptr_ty, 2);
+    // ptr of the next iteration
+    shared_next_ptr_[layout] = phi(ptr_ty, 2);
+
+    builder_->SetInsertPoint(current);
+  } else if(layout->get_double_buffer()) {
     BasicBlock *current = builder_->GetInsertBlock();
     auto info = *layout->get_double_buffer();
     ir::phi_node *phi = info.phi;

@@ -2269,8 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) {
     shared_off_[layout] = phi(i32_ty, 2);
     shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr");
     builder_->SetInsertPoint(current);
-  }
-  else{
+  } else{
     size_t offset = alloc_->offset(layout);
     shared_ptr_[layout] = gep(shmem_, i32(offset));
     shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty);
@@ -2354,7 +2498,67 @@ void generator::init_idx(ir::value *v) {
 }
 
 void generator::finalize_shared_layout(analysis::shared_layout *shared) {
-  if(shared->get_double_buffer()) {
+  if (auto n_buffer = shared->get_N_buffer()) {
+    // if (*_smem_idx == #stages-1) {
+    //   *_smem_idx = 0;
+    // } else *_smem_idx++;
+    auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) {
+      // insert point
+      Value *idx = smem_idx[shared];
+      builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator());
+      Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1));
+      PHINode *_ret = phi(i32_ty, 2);
+      Instruction *then_term = nullptr;
+      Instruction *else_term = nullptr;
+      Instruction *dummy = builder_->CreateRet(nullptr);
+      llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr);
+      dummy->removeFromParent();
+      builder_->SetInsertPoint(then_term);
+      Value *zero_smem_idx = i32(0);
+      builder_->SetInsertPoint(else_term);
+      Value *inc_smem_idx = add(idx, i32(1));
+      builder_->SetInsertPoint(_ret->getParent());
+      _ret->addIncoming(zero_smem_idx, then_term->getParent());
+      _ret->addIncoming(inc_smem_idx, else_term->getParent());
+      // update ir::bb -> llvm::bb mapping
+      bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock();
+      // idx = init_stage;
+      // loop: ...
+      if (auto idx_phi = llvm::dyn_cast<PHINode>(smem_idx[shared])) {
+        idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0)));
+        idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1)));
+      } else
+        throw std::runtime_error("Should be PHINode");
+    };
+
+    // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2
+    finalize_smem_idx(read_smem_idx_, 2);
+    finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1);
+
+    // finalize pointers
+    ir::phi_node *pn = n_buffer->phi;
+    BasicBlock *header = bbs_.at(pn->get_incoming_block(0));
+    BasicBlock *loop = bbs_.at(pn->get_incoming_block(1));
+    // %curr_ptr = phi %shared_pre_ptr, %next_ptr
+    // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size))
+    if (auto curr_ptr = dyn_cast<PHINode>(shared_ptr_[shared])) {
+      curr_ptr->addIncoming(shared_pre_ptr_[shared], header);
+      curr_ptr->addIncoming(shared_next_ptr_[shared], loop);
+    } else
+      throw std::runtime_error("Should be PHINode");
+
+    BasicBlock *current = builder_->GetInsertBlock();
+    builder_->SetInsertPoint(header->getTerminator());
+    Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements()));
+    builder_->SetInsertPoint(current->getTerminator());
+
+    assert(isa<PHINode>(shared_next_ptr_[shared]));
+    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header);
+
+    Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements()));
+    Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset);
+    static_cast<PHINode*>(shared_next_ptr_[shared])->addIncoming(next_ptr, loop);
+  } else if(shared->get_double_buffer()) {
     auto info = *shared->get_double_buffer();
     ir::phi_node *phi = info.phi;
     PHINode *ptr = (PHINode*)shmems_[phi];
@@ -4,6 +4,7 @@
 #include "triton/codegen/analysis/layout.h"
 #include "triton/codegen/analysis/allocation.h"
 #include "triton/codegen/transform/membar.h"
+#include "triton/codegen/transform/prefetch.h"
 #include "triton/ir/module.h"
 #include "triton/ir/function.h"
 #include "triton/ir/basic_block.h"

@@ -20,9 +21,14 @@ namespace transform{
 int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
   if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
     analysis::shared_layout* layout = layouts_->get(v)->to_shared();
-    analysis::double_buffer_info_t* info = layout->get_double_buffer();
-    if(info)
+    if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
       return group_of(info->first, async_write);
+    else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
+      if (v == info->phi)
+        return group_of(info->firsts[0], async_write);
+      else // prefetched value
+        return group_of(info->firsts[1], async_write);
+    }
     std::vector<int> groups(phi->get_num_operands());
     std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
     return *std::max_element(groups.begin(), groups.end());
@@ -69,12 +75,31 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
   return ret;
 }
 
+bool membar::check_safe_war(ir::instruction* i) {
+  bool is_i_shared_block = i->get_type()->is_block_ty() &&
+                           layouts_->get(i)->to_shared();
+  bool is_i_double_buffered = is_i_shared_block &&
+                              layouts_->get(i)->to_shared()->get_double_buffer();
+  bool is_i_n_buffered = is_i_shared_block &&
+                         layouts_->get(i)->to_shared()->get_N_buffer();
+
+  if (is_i_double_buffered || is_i_n_buffered) {
+    // with async copy & prefetch_s disabled, WARs are not safe
+    if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
+      return false;
+    else
+      return true;
+  }
+  return false;
+}
+
 void membar::transfer(ir::basic_block *block,
                       val_vec_t& async_write,
                       val_set_t& sync_write,
                       val_set_t& sync_read,
                       std::set<ir::value*>& safe_war,
                       bool& inserted, ir::builder& builder) {
+  std::vector<ir::async_wait_inst*> async_waits;
   ir::basic_block::inst_list_t instructions = block->get_inst_list();
   for(ir::instruction *i: instructions){
     if(dynamic_cast<ir::phi_node*>(i))

@@ -105,18 +130,14 @@ void membar::transfer(ir::basic_block *block,
         async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
         barrier = (ir::barrier_inst*)builder.create_barrier();
         inserted = true;
+        async_waits.push_back(async_wait);
       }
     }
     // RAW, WAR
-    bool is_i_double_buffered = i->get_type()->is_block_ty() &&
-                                layouts_->get(i)->to_shared() &&
-                                layouts_->get(i)->to_shared()->get_double_buffer();
+    bool is_safe_war = check_safe_war(i);
     // WAR barrier is not required when data is double-buffered
-    // TODO: how about other patterns, like WWAR?
-    if(!intersect_with(read, sync_write).empty() ||
-       (!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) ||
-       // force WAR barrier on A100
-       (!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){
+    if(!intersect_with(read, sync_write).empty() ||
+       (!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
       builder.set_insert_point(i);
       barrier = (ir::barrier_inst*)builder.create_barrier();
      inserted = true;
@@ -132,7 +153,41 @@ void membar::transfer(ir::basic_block *block,
       sync_read.clear();
     }
     sync_read.insert(read.begin(), read.end());
+  }
+
+  // coalesce barriers
+  // fixme: to support more general cases
+  if (async_waits.size() == 2) {
+    // (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
+    for (int idx=0; idx<async_waits.size()-1; ++idx) {
+      ir::async_wait_inst *first_async_wait = async_waits[idx];
+      std::vector<ir::instruction*> to_erase;
+      ir::basic_block::inst_list_t instructions = block->get_inst_list();
+      for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
+        ir::instruction *i = *iter;
+        if (static_cast<ir::instruction*>(first_async_wait) == i) {
+          // peak next 5 instructions
+          auto peak_iter = std::next(iter);
+          if (std::distance(peak_iter, instructions.end()) >= 5) {
+            auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
+            auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++);
+            auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++);
+            auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++);
+            auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter);
+            if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
+              int first_n = first_async_wait->get_N();
+              int second_n = second_async_wait->get_N();
+              to_erase.push_back(second_async_wait);
+              to_erase.push_back(second_bar);
+              first_async_wait->set_N(second_n);
+            }
+          } else
+            break;
+          for (ir::instruction *i : to_erase)
+            block->erase(i);
+        }
+      }
+    }
   }
 }
 

@@ -144,7 +199,7 @@ void membar::run(ir::module &mod) {
   std::set<ir::value*> safe_war;
   for(const auto& x: layouts_->get_all()){
     analysis::shared_layout* layout = x.second->to_shared();
-    if(!layout || !layout->get_double_buffer())
+    if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
      continue;
     for(ir::value *v: layout->get_values())
       if(v != layout->get_double_buffer()->phi){

@@ -153,7 +208,6 @@ void membar::run(ir::module &mod) {
   }
 
   for(ir::function *fn: mod.get_function_list()){
-    // TODO: (dyan) we need DominatorTree here.
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    std::map<ir::basic_block*, val_vec_t> async_writes;
    std::map<ir::basic_block*, val_set_t> sync_writes;
@@ -23,6 +23,60 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
       recursive_deps(u, block, ret);
 }
 
+/// assume incoming block is 1
+ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
+                              std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
+  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
+  if(!i)
+    return v;
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
+    if (prev_phi_vals.find(phi) == prev_phi_vals.end())
+      throw std::runtime_error("Don't have that phi node\n");
+    return prev_phi_vals.at(phi);
+  }
+
+  std::vector<ir::value*> new_ops;
+  for(ir::value* op: i->ops()){
+    new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
+  }
+  ir::instruction* ret = i->clone();
+  for(size_t k = 0; k < new_ops.size(); k++)
+    ret->set_operand(k, new_ops[k]);
+  builder.insert(ret);
+  return ret;
+}
+
+void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
+  auto instr = dynamic_cast<ir::instruction*>(cond);
+  for (auto op : instr->ops()) {
+    if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
+      phis.insert(phi_op);
+      return;
+    }
+    if (dynamic_cast<ir::instruction*>(op))
+      get_induction_vars(op, phis);
+  }
+}
+
+/// Returns phi_val if sees a phi node
+ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) {
+  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
+  if(!i)
+    return v;
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
+    return phi_val;
+
+  std::vector<ir::value*> new_ops;
+  for(ir::value* op: i->ops()){
+    new_ops.push_back(rematerialize_val(builder, op, phi_val));
+  }
+  ir::instruction* ret = i->clone();
+  for(size_t k = 0; k < new_ops.size(); k++)
+    ret->set_operand(k, new_ops[k]);
+  builder.insert(ret);
+  return ret;
+}
+
 ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
   ir::instruction* i = dynamic_cast<ir::instruction*>(v);
   if(!i)

@@ -41,6 +95,28 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
   return ret;
 }
 
+/// moving the prev phi vals to the next iteration
+void update_prev_phi_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
+  for (auto& [phi, val] : prev_phi_vals) {
+    // TODO: handling nested phis
+    val = rematerialize_val(builder, phi->get_incoming_value(1), val);
+  }
+}
+
+void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
+                      std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
+  for (auto& [phi, val] : load_ivs) {
+    if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
+      ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs);
+      assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
+      new_phi->add_incoming(next_k, phi->get_incoming_block(1));
+      // cache next_k (to be used by next_mask)
+      next_load_ivs[phi] = next_k;
+    } else
+      throw std::runtime_error("must be phi");
+  }
+}
+
 void pipeline::run(ir::module &mod) {
   // *Very* conservative heuristics for pre-fetching.
   // A load instruction can be pipelined if:
|
|||||||
// do the pipelining
|
// do the pipelining
|
||||||
std::vector<ir::phi_node*> new_loads;
|
std::vector<ir::phi_node*> new_loads;
|
||||||
ir::builder &builder = mod.get_builder();
|
ir::builder &builder = mod.get_builder();
|
||||||
|
const int num_stages = num_stages_;
|
||||||
|
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
|
||||||
for(auto info: to_pipeline){
|
for(auto info: to_pipeline){
|
||||||
ir::load_inst* load = info.first;
|
ir::load_inst* load = info.first;
|
||||||
ir::phi_node* ptr = info.second;
|
ir::phi_node* ptr = info.second;
|
||||||
@@ -70,40 +148,155 @@ void pipeline::run(ir::module &mod) {
|
|||||||
assert(block_br);
|
assert(block_br);
|
||||||
assert(header_br);
|
assert(header_br);
|
||||||
ir::type* ty = load->get_type();
|
ir::type* ty = load->get_type();
|
||||||
// pre-fetch first iteration
|
// multi-stage pipe
|
||||||
builder.set_insert_point(header->get_inst_list().back());
|
if (has_copy_async_ && num_stages > 2) {
|
||||||
ir::value* first_ptr = ptr->get_value_for_block(header);
|
ir::value* header_cond = header_br->get_cond();
|
||||||
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
ir::value* block_cond = block_br->get_cond();
|
||||||
ir::value* false_value;
|
// 1. collect induction variables
|
||||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
std::set<ir::phi_node*> induction_vars;
|
||||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
get_induction_vars(block_cond, induction_vars);
|
||||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
|
||||||
first_mask = builder.create_and(first_mask, remat_mask);
|
std::vector<ir::value*> first_ptrs(num_stages-1);
|
||||||
false_value = remat_false_value;
|
std::vector<ir::value*> first_loads(num_stages-1);
|
||||||
|
std::vector<ir::value*> first_masks(num_stages-1);
|
||||||
|
std::vector<ir::value*> loop_conds(num_stages-1);
|
||||||
|
|
||||||
|
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
|
||||||
|
// initialize prev_phi_vals
|
||||||
|
// note: we assume that ptr & other values only depend on ptr & iv (phis)
|
||||||
|
// TODO: can we just add all phis here?
|
||||||
|
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
|
||||||
|
for (ir::phi_node* iv : induction_vars)
|
||||||
|
prev_phi_vals[iv] = iv->get_value_for_block(header);
|
||||||
|
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
|
||||||
|
|
||||||
|
builder.set_insert_point(header->get_inst_list().back());
|
||||||
|
first_ptrs[0] = ptr->get_value_for_block(header);
|
||||||
|
loop_conds[0] = header_cond;
|
||||||
|
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
|
||||||
|
ir::value* false_value = nullptr;
|
||||||
|
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||||
|
ir::value* remat_mask =rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals) ;
|
||||||
|
ir::value* remat_false_value =
|
||||||
|
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||||
|
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
|
||||||
|
false_value = remat_false_value;
|
||||||
|
} else
|
||||||
|
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||||
|
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
|
||||||
|
|
||||||
|
for (int stage = 1; stage < num_stages-1; ++stage) {
|
||||||
|
// mask is the loop condition of the previous iteration
|
||||||
|
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
|
||||||
|
update_prev_phi_vals(builder, prev_phi_vals);
|
||||||
|
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
|
||||||
|
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
|
||||||
|
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||||
|
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
|
||||||
|
ir::value* remat_false_value =
|
||||||
|
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
|
||||||
|
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
|
||||||
|
false_value = remat_false_value;
|
||||||
|
}
|
||||||
|
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new phis for induction variables
|
||||||
|
builder.set_insert_point(block->get_first_non_phi());
|
||||||
|
std::map<ir::phi_node*, ir::value*> load_ivs;
|
||||||
|
std::map<ir::phi_node*, ir::value*> next_load_ivs;
|
||||||
|
for (ir::phi_node* iv : induction_vars) {
|
||||||
|
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
|
||||||
|
pn->add_incoming(prev_phi_vals[iv], header);
|
||||||
|
load_ivs[iv] = pn;
|
||||||
|
}
|
||||||
|
// add incoming for phis & update next_load_ivs
|
||||||
|
finalize_iv_vals(builder, load_ivs, next_load_ivs);
|
||||||
|
|
||||||
|
// pre-fetch next iteration
|
||||||
|
builder.set_insert_point(block->get_inst_list().back());
|
||||||
|
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||||
|
ir::value* next_mask = builder.create_splat(
|
||||||
|
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
|
||||||
|
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
|
||||||
|
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
|
||||||
|
// TODO: false may depends on some other phi nodes
|
||||||
|
ir::value* remat_false_value =
|
||||||
|
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
|
||||||
|
next_mask = builder.create_and(next_mask, remat_mask);
|
||||||
|
false_value = remat_false_value;
|
||||||
|
}
|
||||||
|
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||||
|
|
||||||
|
|
||||||
|
// phi node
|
||||||
|
ptr->set_incoming_value(0, first_ptrs.back());
|
||||||
|
builder.set_insert_point(block->get_first_non_phi());
|
||||||
|
// nested phis for load
|
||||||
|
std::vector<ir::phi_node*> new_load_phis(num_stages-1);
|
||||||
|
for (auto& pn : new_load_phis)
|
||||||
|
pn = builder.create_phi(ty, 2);
|
||||||
|
for (int i=0; i<num_stages-2; ++i) {
|
||||||
|
new_load_phis[i]->add_incoming(first_loads[i], header);
|
||||||
|
new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
|
||||||
|
}
|
||||||
|
new_load_phis.back()->add_incoming(first_loads.back(), header);
|
||||||
|
new_load_phis.back()->add_incoming(next_load, block);
|
||||||
|
load->replace_all_uses_with(new_load_phis.front());
|
||||||
|
new_loads.push_back(new_load_phis.back());
|
||||||
|
|
||||||
|
// record first_loads to reorder them
|
||||||
|
preheader_loads.push_back({new_load_phis.front(), first_loads});
|
||||||
|
} else {
|
||||||
|
// pre-fetch first iteration
|
||||||
|
builder.set_insert_point(header->get_inst_list().back());
|
||||||
|
ir::value* first_ptr = ptr->get_value_for_block(header);
|
||||||
|
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
|
||||||
|
ir::value* false_value;
|
||||||
|
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||||
|
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
|
||||||
|
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
|
||||||
|
first_mask = builder.create_and(first_mask, remat_mask);
|
||||||
|
false_value = remat_false_value;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
||||||
|
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
||||||
|
// pre-fetch next iteration
|
||||||
|
builder.set_insert_point(block->get_inst_list().back());
|
||||||
|
ir::value* next_ptr = ptr->get_value_for_block(block);
|
||||||
|
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
||||||
|
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
||||||
|
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
||||||
|
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
||||||
|
next_mask = builder.create_and(next_mask, remat_mask);
|
||||||
|
false_value = remat_false_value;
|
||||||
|
}
|
||||||
|
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
||||||
|
// phi node
|
||||||
|
builder.set_insert_point(block->get_first_non_phi());
|
||||||
|
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
||||||
|
new_load->add_incoming(first_load, header);
|
||||||
|
new_load->add_incoming(next_load, block);
|
||||||
|
load->replace_all_uses_with(new_load);
|
||||||
|
new_loads.push_back(new_load);
|
||||||
}
|
}
|
||||||
else
|
|
||||||
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
|
|
||||||
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
|
|
||||||
// pre-fetch next iteration
|
|
||||||
builder.set_insert_point(block->get_inst_list().back());
|
|
||||||
ir::value* next_ptr = ptr->get_value_for_block(block);
|
|
||||||
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
|
|
||||||
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
|
|
||||||
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
|
|
||||||
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
|
|
||||||
next_mask = builder.create_and(next_mask, remat_mask);
|
|
||||||
false_value = remat_false_value;
|
|
||||||
}
|
|
||||||
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
|
|
||||||
// phi node
|
|
||||||
builder.set_insert_point(block->get_first_non_phi());
|
|
||||||
ir::phi_node* new_load = builder.create_phi(ty, 2);
|
|
||||||
new_load->add_incoming(first_load, header);
|
|
||||||
new_load->add_incoming(next_load, block);
|
|
||||||
load->replace_all_uses_with(new_load);
|
|
||||||
new_loads.push_back(new_load);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to
|
||||||
|
// a0, b0, a1, b1, ...
|
||||||
|
if (!preheader_loads.empty()) {
|
||||||
|
ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
|
||||||
|
builder.set_insert_point(header->get_inst_list().back());
|
||||||
|
for (int i=1; i<num_stages-1; ++i) {
|
||||||
|
for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
|
||||||
|
ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
|
||||||
|
ir::instruction* moved_load = original_load->clone();
|
||||||
|
builder.insert(moved_load);
|
||||||
|
original_load->replace_all_uses_with(moved_load);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// try to move dot_inst after loads
|
// try to move dot_inst after loads
|
||||||
// for better overlap of io and compute
|
// for better overlap of io and compute
|
||||||
|
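Taken together, the multi-stage path turns the double-buffered loop into an N-stage pipeline: num_stages - 1 loads are issued in the pre-header, and each iteration consumes the oldest in-flight value while issuing one new load, with the nested phis carrying the rotating buffer. A minimal Python sketch of that schedule (illustrative only; ptrs, load and compute are hypothetical stand-ins, not Triton API):

def pipelined_loop(ptrs, num_stages, load, compute):
    # pre-header: prefetch the first num_stages - 1 iterations (cf. first_loads above)
    n = len(ptrs)
    in_flight = [load(ptrs[i]) for i in range(min(num_stages - 1, n))]
    for i in range(n):
        data = in_flight.pop(0)                # value fetched num_stages - 1 iterations ago
        nxt = i + num_stages - 1
        if nxt < n:
            in_flight.append(load(ptrs[nxt]))  # issue the next load (cf. next_load above)
        compute(data)                          # compute overlaps with the loads still in flight
    # with num_stages == 2 this degenerates to the classic double-buffering path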
@@ -25,13 +25,12 @@ static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::in
 }
 
 void prefetch::run(ir::module &mod) {
-  // 1. collect dot that can be prefethced
+  // 1. collect dots that can be prefetched
   std::vector<ir::dot_inst*> to_prefetch;
   ir::for_each_instruction(mod, [&](ir::instruction *i) {
     if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
-      // Now only do prefetching when dot is fp16 & volta/turing
-      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID ||
-          tgt_->as_nvidia()->sm() >= 80)
+      // Now only do prefetching when dot is fp16
+      if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID)
         return;
       auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
       auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
@@ -56,16 +55,31 @@ void prefetch::run(ir::module &mod) {
 
     // 1. in the loop header (first iteration)
     builder.set_insert_point(loop_header->get_inst_list().back());
-    builder.create_barrier();
     assert(a && b);
     builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
     builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
 
     // 2. at the end of the loop body (next iteration)
     builder.set_insert_point(loop_body->get_inst_list().back());
-    builder.create_barrier();
     builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
     builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
+
+    prefetched_vals_.insert(a->get_incoming_value(0));
+    prefetched_vals_.insert(b->get_incoming_value(0));
+    // nested phis
+    ir::value* next_a = a->get_incoming_value(1);
+    while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
+      prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
+      next_a = next_a_phi->get_incoming_value(1);
+    }
+    prefetched_vals_.insert(next_a);
+
+    ir::value* next_b = b->get_incoming_value(1);
+    while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
+      prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
+      next_b = next_b_phi->get_incoming_value(1);
+    }
+    prefetched_vals_.insert(next_b);
   }
 
   // move loads to the beginning of the loop
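Because an N-stage pipelined operand is now a chain of nested phis (stage 0 receives stage 1, which eventually receives the freshly issued load), the pass records every value along that chain as already prefetched. A small Python sketch of the walk (illustrative only; Phi is a hypothetical stand-in for ir::phi_node):

class Phi:
    """Stand-in for a phi node with one pre-header edge and one loop-carried edge."""
    def __init__(self, from_header, from_loop):
        self.from_header = from_header
        self.from_loop = from_loop

def collect_prefetched(load_phi):
    vals = {load_phi.from_header}
    cur = load_phi.from_loop
    while isinstance(cur, Phi):       # follow the nested-phi chain created by N-buffering
        vals.add(cur.from_header)
        cur = cur.from_loop
    vals.add(cur)                     # the load issued at the end of the loop body
    return vals

# a 3-stage chain: stage0 <- stage1 <- "next_load"
assert collect_prefetched(Phi("a0", Phi("a1", "next_load"))) == {"a0", "a1", "next_load"}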
@@ -324,7 +324,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device)
   }
   catch(exception::cuda::invalid_ptx const &){
 //#ifdef TRITON_LOG_PTX_ERROR
-    std::cout << ptx << std::endl;
+    // std::cout << ptx << std::endl;
     std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
 //    exit(1);
 //#endif
@@ -77,6 +77,54 @@ void print(module &mod, std::ostream& os) {
   }
 }
 
+void print(function &fn, std::ostream &os) {
+  //
+}
+
+void print(basic_block &bb, std::ostream &os) {
+  auto const &predecessors = bb.get_predecessors();
+  os << bb.get_name() << ":";
+  if(!predecessors.empty()){
+    os << " ";
+    os << "; preds = ";
+    for(ir::basic_block *pred: predecessors)
+      os << pred->get_name() << (pred != predecessors.back() ? ", " : "");
+  }
+  os << std::endl;
+  for(ir::instruction *inst: bb.get_inst_list()){
+    print(*inst, os);
+  }
+}
+
+void print(instruction &instr, std::ostream &os) {
+  instruction *inst = &instr;
+  os << "  ";
+  if(!inst->get_type()->is_void_ty()){
+    os << instr.get_name();
+    os << " = ";
+  }
+  ir::type* type = inst->get_type();
+  os << inst->repr() << " " << type->repr();
+  ir::instruction::ops_t ops = inst->ops();
+  size_t num_ops = inst->get_num_operands();
+  if(num_ops > 0)
+    os << " ";
+  for(unsigned i = 0; i < num_ops; i++){
+    if(auto *x = dynamic_cast<ir::constant*>(ops[i]))
+      os << x->repr();
+    else
+      os << ops[i]->get_name();
+    os << (i < num_ops - 1 ? ", " : "");
+  }
+  os << ";";
+  // os << " (";
+  // for(ir::user* usr: inst->get_users())
+  //   os << get_name(usr, cnt++) << ", ";
+  // os << " )";
+  os << std::endl;
+}
+
 }
 }
@@ -33,7 +33,10 @@ void init_triton_driver(py::module &&m) {
         CUdevice handle;
         drv::dispatch::cuDeviceGet(&handle, dev_id);
         return new drv::cu_device(handle, take_ownership);
-      }));
+      }))
+      .def("max_shared_memory", [](drv::cu_device *self) {
+        return self->max_shared_memory();
+      });
   // host device
   py::class_<drv::host_device, drv::device>(m, "host_device")
       .def(py::init<>());
@@ -75,11 +78,11 @@ void init_triton_driver(py::module &&m) {
 
 void init_triton_codegen(py::module &&m) {
   m.def(
-      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) {
+      "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps, int num_stages) {
         drv::module *mod;
         drv::kernel *ker;
         size_t shared_mem;
-        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem);
+        triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem);
         std::stringstream ss;
         ir::print(ir, ss);
         return std::make_tuple(mod, ker, shared_mem, ss.str());
@@ -27,7 +27,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B)
     ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
-    rc = op(ra, rb)
+    rc = triton.testing.catch_oor(lambda: op(ra, rb), pytest)
     # torch result
     ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
     tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
@@ -5,56 +5,69 @@ import torch
 
 
 @pytest.mark.parametrize(
-    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE",
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE",
     itertools.chain(
         *[
            [
                # 1 warp
-               (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
-               (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
-               (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
-               (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
-               (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
-               (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
-               (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-               (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
-               (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
+               (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
+               (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE),
                # 2 warp
-               (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
-               (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
-               (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
-               (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
-               (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
-               (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
+               (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+               (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+               (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+               (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+               (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
+               (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE),
                # 4 warp
-               (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
-               (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
-               (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
-               (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
-               (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
-               (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
+               (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+               (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+               (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+               (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+               (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
+               (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE),
                # 8 warp
-               (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
-               (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
-               (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
-               # # split-k
-               (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
-               (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
-               (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
-               # # variable input
-               (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE),
-               (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
-               (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
-               (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
+               (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+               (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+               (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE),
+               # split-k
+               (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE),
+               (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE),
+               (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE),
+               # variable input
+               (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE),
+               (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE),
+               (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE),
+               (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE),
            ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
+        ],
+        # n-stage
+        *[
+           [
+               (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+               (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+               (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+               (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+               (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE),
+               # split-k
+               (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE),
+               (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE),
+           ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
         ]
     ),
 )
-def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
     META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
-    configs = [triton.Config(meta=META, num_warps=NWARP)]
+    configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
     kernel = triton.ops._matmul.kernel
     decorators = kernel.kernel_decorators
     kernel.kernel_decorators = []
@@ -72,5 +85,5 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE):
     b = b.t() if BT else b
     # run test
     th_c = torch.matmul(a, b)
-    tt_c = triton.ops.matmul(a, b)
+    tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest)
     assert triton.testing.allclose(th_c, tt_c)
@@ -407,13 +407,14 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, shared_mem, ir_asm):
+    def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm):
         # cache ir asm
         self.ir_asm = ir_asm
         self.module = module
         self.kernel = kernel
         self.shared_mem = shared_mem
         self.num_warps = num_warps
+        self.num_stages = num_stages
         self.sass = None
 
     def asm(self, mode):
@@ -447,6 +448,13 @@ class CompilationError(Exception):
         self.message += '\n Error: ' + str(err)
         super().__init__(self.message)
 
+
+class OutOfResources(Exception):
+    def __init__(self, required, limit, name):
+        self.message = f'out of resource: {name}, '\
+                       f'Required: {required}, '\
+                       f'Hardware limit: {limit}'
+        super().__init__(self.message)
+
 
 class Kernel:
     @staticmethod
@@ -513,7 +521,7 @@ class Kernel:
     def __init__(self, fn):
         self.fn = fn
 
-    def _compile(self, *wargs, device, attributes, constants, num_warps, **meta):
+    def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta):
         # explicitly set device
         torch.cuda.set_device(device.index)
         # create IR module
@@ -535,10 +543,12 @@
             raise CompilationError(self.fn.src, node, e)
         tt_device = _triton.driver.cu_device(device.index, False)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps)
-        return Binary(mod, ker, num_warps, shared_mem, ir_asm)
+        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages)
+        if shared_mem > tt_device.max_shared_memory():
+            raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")
+        return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm)
 
-    def __call__(self, *wargs, grid, num_warps=4, **meta):
+    def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta):
         # device inference
         tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
         if len(tensor_idxs) == 0:
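On the Python side, the new num_stages keyword flows from the kernel launch through _compile into the codegen passes, and launches whose N-buffered shared memory would exceed the device limit now raise OutOfResources instead of producing an invalid binary. A hedged usage sketch (the kernel object, its arguments, and the retry policy are illustrative, not part of this diff):

from triton.code_gen import OutOfResources

def try_launch(kernel, args, grid):
    # `kernel` is assumed to be a @triton.jit function; only the num_warps/num_stages
    # keywords and the OutOfResources exception come from this change.
    try:
        return kernel(*args, grid=grid, num_warps=4, num_stages=3)
    except OutOfResources:
        # 3-stage shared-memory buffering did not fit; fall back to double buffering
        return kernel(*args, grid=grid, num_warps=4, num_stages=2)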
@@ -554,12 +564,12 @@
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
-        key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key)
+        key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
         cache = self.fn.cache
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta
+                *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta
             )
         # pack arguments
         fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)])
@@ -585,7 +595,7 @@ class Launcher:
 class Autotuner:
     def __init__(self, kernel, arg_names, configs, key):
         if not configs:
-            self.configs = [Config(dict(), num_warps=4)]
+            self.configs = [Config(dict(), num_warps=4, num_stages=2)]
         else:
             self.configs = configs
         self.key_idx = [arg_names.index(k) for k in key]
@@ -603,7 +613,7 @@ class Autotuner:
         )
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.meta)
-        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current)
+        kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)
 
     def __call__(self, *args, **meta):
@@ -616,7 +626,7 @@ class Autotuner:
             config = self.cache[key]
         else:
             config = self.configs[0]
-        return self.kernel(*args, num_warps=config.num_warps, **meta, **config.meta)
+        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta)
 
 
 class JITFunction:
@@ -671,9 +681,10 @@ class JITFunction:
 
 
 class Config:
-    def __init__(self, meta, num_warps=4):
+    def __init__(self, meta, num_warps=4, num_stages=2):
         self.meta = meta
         self.num_warps = num_warps
+        self.num_stages = num_stages
 
 
 def autotune(configs, key):
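Since triton.Config now carries num_stages, an autotuning space can sweep the pipelining depth alongside the tile shape. A sketch under the assumption that the kernel's meta-parameters look like the matmul ones used in the tests above (the decorator usage in the trailing comment is illustrative):

import triton

# sweep double buffering (2) and multi-stage pipelining (3, 4) on Ampere
configs = [
    triton.Config(meta={'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8},
                  num_warps=4, num_stages=s)
    for s in [2, 3, 4]
]

# applied to a @triton.jit kernel, e.g.:
# @triton.autotune(configs=configs, key=['M', 'N', 'K'])
# @triton.jit
# def matmul_kernel(...): ...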
@@ -1,5 +1,6 @@
 import torch
 import os
+from .code_gen import OutOfResources
 
 try:
     import triton._C.libtriton.cutlass as _cutlass
@@ -8,6 +9,15 @@ except ImportError:
     _cutlass = None
     has_cutlass = False
 
+
+def catch_oor(kernel, pytest_handle=None):
+    try:
+        res = kernel()
+    except OutOfResources as e:
+        if pytest_handle:
+            pytest_handle.skip(str(e))
+        return None
+    return res
+
 
 def sparsify_tensor(x, mask, block):
     ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)