[codegen] [membar] view do not write to shared memory

This commit is contained in:
Philippe Tillet
2019-10-17 22:38:41 -04:00
parent cf4fbfefee
commit b43454c9b7
6 changed files with 12 additions and 11 deletions

View File

@@ -291,6 +291,7 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
}
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order = col;
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
@@ -329,7 +330,9 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
return x->get_type()->get_tile_ranks1() <
y->get_type()->get_tile_ranks1();
};
ir::value *largest = *std::max_element(values.begin(), values.end(), cmp);
std::vector<ir::value*> lvalue = values;
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
const auto& axes = axes_->get(largest);
const auto& shapes = largest->get_type()->get_tile_shapes();
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {

View File

@@ -77,12 +77,7 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
std::map<ir::value *, tile *>& tmap)
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
auto order = layout_->order;
auto shapes = layout_->shapes;
shapes[order[0]] += layout_->pad;
Type* ty = llvm_type(layout_->ty, builder_->getContext());
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
// double-buffered
if(layout_->double_buffer) {

View File

@@ -52,7 +52,7 @@ void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
}
// Records instruction `i` into `res` through add_reference(), unless `i` is
// an ir::phi_node or an ir::trans_inst, in which case `res` is left untouched.
// NOTE(review): the commit title describes the excluded trans_inst case as a
// "view" that does not write to shared memory -- confirm against ir::trans_inst.
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
  // Either cast succeeding means this instruction is excluded.
  if(dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::trans_inst*>(i))
    return;
  add_reference(i, res);
}
@@ -99,7 +99,7 @@ std::pair<membar::interval_vec_t,
get_written_intervals(i, written);
bool read_after_write = intersect(new_written_to, read);
bool write_after_read = intersect(new_read_from, written);
// double buffering: write and phi-node read won't intersect
// double buffering
if(safe_war.find(i) != safe_war.end()){
write_after_read = false;
read_after_write = false;
@@ -125,8 +125,9 @@ void membar::run(ir::module &mod) {
for(const auto& x: layouts_->get_all()){
if(x.second->double_buffer){
auto info = *x.second->double_buffer;
safe_war.insert(info.first);
safe_war.insert(info.latch);
for(ir::value *v: x.second->values)
if(v != info.phi)
safe_war.insert(v);
}
}

View File

@@ -242,6 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
// exit(EXIT_FAILURE);
cu_context::context_switcher ctx(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -243,6 +243,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
return std::unique_ptr<driver::module>();
barriers.run(module);
// ir::print(module, std::cout);
// exit(EXIT_FAILURE);
isel.visit(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -13,7 +13,7 @@ int main() {
for(int TM: std::vector<int>{32, 64})
for(int TN: std::vector<int>{32, 64})
for(int TK: std::vector<int>{8})
for(int nwarps: std::vector<int>{1, 4})
for(int nwarps: std::vector<int>{8})
for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true}){
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});