[code generation] bugfix in single buffering
This commit is contained in:
@@ -21,11 +21,10 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];
|
||||
fp32 a[TM, TK] = *pa;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
for(int32 k = K; k > 0;){
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, b, C);
|
||||
pa = pa + TK*M;
|
||||
pb = pb + TK*K;
|
||||
k = k - TK;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
@@ -164,7 +163,7 @@ int main() {
|
||||
};
|
||||
// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4};
|
||||
|
||||
jit.autotune(src, benchmark);
|
||||
// jit.autotune(src, benchmark);
|
||||
jit.add_module(src, params);
|
||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
|
@@ -26,13 +26,14 @@ private:
|
||||
typedef std::vector<interval_t> interval_vec_t;
|
||||
|
||||
private:
|
||||
interval_vec_t join(const std::vector<interval_vec_t>& intervals);
|
||||
void insert_barrier(ir::instruction *instr, ir::builder &builder);
|
||||
bool intersect(const interval_vec_t &X, interval_t x);
|
||||
bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
|
||||
void add_reference(ir::value *v, interval_vec_t &res);
|
||||
void get_read_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
void get_written_intervals(ir::instruction *i, interval_vec_t &res);
|
||||
void add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder);
|
||||
std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set<ir::instruction *> &insert_loc);
|
||||
|
||||
public:
|
||||
barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {}
|
||||
|
@@ -19,9 +19,11 @@ public:
|
||||
void run(ir::module &mod);
|
||||
// queries
|
||||
bool is_double(ir::value *x);
|
||||
void add_shared(ir::value *v);
|
||||
bool is_shared(ir::value *x);
|
||||
bool is_loop_latch(ir::phi_node *phi, ir::value *terminator);
|
||||
ir::value *get_reference(ir::value *x);
|
||||
void replace(ir::value* before, ir::value *after);
|
||||
|
||||
|
||||
private:
|
||||
|
@@ -58,6 +58,7 @@ public:
|
||||
|
||||
// predecessors
|
||||
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
|
||||
const std::vector<basic_block*>& get_successors() const { return succs_; }
|
||||
void add_predecessor(basic_block* pred);
|
||||
|
||||
// factory functions
|
||||
@@ -68,6 +69,7 @@ private:
|
||||
std::string name_;
|
||||
function *parent_;
|
||||
std::vector<basic_block*> preds_;
|
||||
std::vector<basic_block*> succs_;
|
||||
inst_list_t inst_list_;
|
||||
};
|
||||
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <memory>
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
@@ -54,10 +55,12 @@ public:
|
||||
// generate ptx
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
|
@@ -29,7 +29,7 @@ void allocation::run(){
|
||||
std::vector<ir::value *> J = I;
|
||||
|
||||
triples_map_type H;
|
||||
H.insert({0, segment{0, 100}});
|
||||
H.insert({0, segment{0, 1024}});
|
||||
|
||||
std::vector<ir::value *> V;
|
||||
std::map<ir::value *, unsigned> starts;
|
||||
@@ -116,6 +116,9 @@ void allocation::run(){
|
||||
for(auto &x: offsets_){
|
||||
allocated_size_ = std::max<size_t>(allocated_size_, x.second + get_num_bytes(x.first));
|
||||
}
|
||||
std::cout << "Allocated: " << allocated_size_ << std::endl;
|
||||
for(auto &x: offsets_)
|
||||
std::cout << x.first->get_name() << " " << x.second << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
@@ -62,27 +63,76 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) {
|
||||
}
|
||||
}
|
||||
|
||||
void barriers::add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder) {
|
||||
barriers::interval_vec_t barriers::join(const std::vector<interval_vec_t>& intervals) {
|
||||
barriers::interval_vec_t result;
|
||||
for(auto x: intervals)
|
||||
for(interval_t i: x)
|
||||
result.push_back(i);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<barriers::interval_vec_t,
|
||||
barriers::interval_vec_t> barriers::transfer(ir::basic_block *block,
|
||||
const interval_vec_t &written_to,
|
||||
const interval_vec_t &read_from,
|
||||
std::set<ir::instruction*>& insert_loc) {
|
||||
ir::basic_block::inst_list_t instructions = block->get_inst_list();
|
||||
interval_vec_t new_written_to = written_to;
|
||||
interval_vec_t new_read_from = read_from;
|
||||
for(ir::instruction *i: instructions){
|
||||
interval_vec_t read, written;
|
||||
get_read_intervals(i, read);
|
||||
get_written_intervals(i, written);
|
||||
if(intersect(not_synced, read)) {
|
||||
not_synced.clear();
|
||||
insert_barrier(i, builder);
|
||||
bool read_while_written = intersect(new_written_to, read);
|
||||
bool written_while_read = intersect(new_read_from, written);
|
||||
// double buffering: write and phi-node read won't intersect
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i) &&
|
||||
buffer_info_->is_double(buffer_info_->get_reference(i)))
|
||||
written_while_read = false;
|
||||
if(read_while_written || written_while_read) {
|
||||
insert_loc.insert(i);
|
||||
new_written_to.clear();
|
||||
new_read_from.clear();
|
||||
}
|
||||
std::copy(written.begin(), written.end(), std::back_inserter(not_synced));
|
||||
std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
|
||||
std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
|
||||
}
|
||||
return std::make_pair(new_written_to, new_read_from);
|
||||
}
|
||||
|
||||
void barriers::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// find barrier location
|
||||
interval_vec_t not_synced;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
add(block, not_synced, builder);
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
std::map<ir::basic_block*, interval_vec_t> written_to;
|
||||
std::map<ir::basic_block*, interval_vec_t> read_from;
|
||||
std::set<ir::instruction*> insert_locs;
|
||||
size_t n_inserted_im1 = 0;
|
||||
bool done = false;
|
||||
do{
|
||||
// find barrier location
|
||||
for(ir::basic_block *block: rpo){
|
||||
// written to
|
||||
std::vector<interval_vec_t> pred_written_to;
|
||||
for(ir::basic_block* pred: block->get_predecessors())
|
||||
pred_written_to.push_back(written_to[pred]);
|
||||
// read from
|
||||
std::vector<interval_vec_t> pred_read_from;
|
||||
for(ir::basic_block* pred: block->get_predecessors())
|
||||
pred_read_from.push_back(read_from[pred]);
|
||||
// apply transfer function
|
||||
auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs);
|
||||
written_to[block] = result.first;
|
||||
read_from[block] = result.second;
|
||||
}
|
||||
size_t n_inserted_i = insert_locs.size();
|
||||
done = (n_inserted_im1 == n_inserted_i);
|
||||
n_inserted_im1 = n_inserted_i;
|
||||
}while(!done);
|
||||
for(ir::instruction* i: insert_locs){
|
||||
std::cout << i->get_name() << std::endl;
|
||||
insert_barrier(i, builder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -21,6 +21,16 @@ bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
void buffer_info_pass::replace(ir::value* before, ir::value *after) {
|
||||
shared_.erase(before);
|
||||
shared_.insert(after);
|
||||
if(refs_.find(before) != refs_.end()){
|
||||
ir::value* v = refs_.at(before);
|
||||
refs_.erase(before);
|
||||
refs_.insert({after, v});
|
||||
}
|
||||
}
|
||||
|
||||
void buffer_info_pass::run(ir::module &mod) {
|
||||
// Find which buffers are shared
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
|
@@ -11,30 +11,43 @@ namespace codegen{
|
||||
|
||||
|
||||
// Entry point
|
||||
void liveness::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()){
|
||||
index += 1;
|
||||
indices_.insert({instr, index});
|
||||
}
|
||||
// Liveness analysis
|
||||
// Creates live intervals
|
||||
for(auto i: indices_){
|
||||
ir::value *v = i.first;
|
||||
if(!info_->is_shared(v) || info_->get_reference(v))
|
||||
continue;
|
||||
unsigned start = i.second;
|
||||
unsigned end = start;
|
||||
for(ir::value *u: v->get_users()){
|
||||
start = std::min(start, indices_.at(u));
|
||||
end = std::max(end, indices_.at(u));
|
||||
}
|
||||
intervals_[v] = segment{start, end};
|
||||
inline bool is_shared(ir::value* v) {
|
||||
if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
return true;
|
||||
if(auto x = dynamic_cast<ir::phi_node*>(v)){
|
||||
bool res = true;
|
||||
for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
|
||||
res = res && is_shared(x->get_incoming_value(inc));
|
||||
return res;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void liveness::run(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Assigns index to each instruction
|
||||
slot_index index = 0;
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *instr: block->get_inst_list()){
|
||||
index += 1;
|
||||
indices_.insert({instr, index});
|
||||
}
|
||||
// Liveness analysis
|
||||
// Creates live intervals
|
||||
for(auto i: indices_){
|
||||
ir::value *v = i.first;
|
||||
if(!info_->is_shared(v) || info_->get_reference(v))
|
||||
continue;
|
||||
unsigned start = i.second;
|
||||
unsigned end = start;
|
||||
for(ir::value *u: v->get_users()){
|
||||
start = std::min(start, indices_.at(u));
|
||||
end = std::max(end, indices_.at(u));
|
||||
}
|
||||
intervals_[v] = segment{start, end};
|
||||
}
|
||||
std::cout << "Number of intervals: " << intervals_.size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -748,8 +748,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
indices_t b_idx = {idx[1], builder.getInt32(K)};
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
// a = ConstantFP::get(builder.getFloatTy(), 1);
|
||||
// b = ConstantFP::get(builder.getFloatTy(), 1);
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
}
|
||||
result->set_value(idx, res);
|
||||
@@ -846,6 +844,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
// create grids
|
||||
init_grids(fn, dst_builder, sh_mem_ptr);
|
||||
|
||||
|
||||
// iterate through block
|
||||
std::map<ir::basic_block*, BasicBlock*> last_block;
|
||||
for(ir::basic_block *block: fn->blocks()) {
|
||||
@@ -854,10 +853,10 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
BasicBlock *current = dst_builder.GetInsertBlock();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
|
||||
if(phi_inserted)
|
||||
dst_builder.SetInsertPoint(&*current->getFirstInsertionPt());
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
dst_builder.SetInsertPoint(&*current->getFirstNonPHI());
|
||||
lower_instruction(i, dst_builder);
|
||||
if(phi_inserted)
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
dst_builder.SetInsertPoint(current);
|
||||
last_block[block] = dst_builder.GetInsertBlock();
|
||||
}
|
||||
|
@@ -28,6 +28,12 @@ void place_shared_copy::run(ir::module &mod) {
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
if(info_->is_shared(i) && !info_->is_double(i))
|
||||
add_copy(i, builder);
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(i))
|
||||
info_->replace(cts->get_operand(0), cts);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -109,6 +109,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
|
||||
llvm::SmallVectorImpl<char> &buffer,
|
||||
std::vector<std::string> paths) {
|
||||
init_llvm();
|
||||
// llvm::legacy::PassManager passes;
|
||||
// passes.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
// passes.add(llvm::createVerifierPass());
|
||||
// passes.run(*module);
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
|
@@ -21,6 +21,8 @@ basic_block* basic_block::create(context &ctx, const std::string &name, function
|
||||
|
||||
void basic_block::add_predecessor(basic_block *pred) {
|
||||
preds_.push_back(pred);
|
||||
if(pred)
|
||||
pred->succs_.push_back(this);
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user