[codegen] [membar] views do not write to shared memory
This commit is contained in:
@@ -291,6 +291,7 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
||||
}
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
order = col;
|
||||
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
||||
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
||||
if(is_nonhmma_dot_a)
|
||||
@@ -329,7 +330,9 @@ void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
||||
return x->get_type()->get_tile_ranks1() <
|
||||
y->get_type()->get_tile_ranks1();
|
||||
};
|
||||
ir::value *largest = *std::max_element(values.begin(), values.end(), cmp);
|
||||
std::vector<ir::value*> lvalue = values;
|
||||
std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
|
||||
ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
|
||||
const auto& axes = axes_->get(largest);
|
||||
const auto& shapes = largest->get_type()->get_tile_shapes();
|
||||
auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
|
||||
|
@@ -77,12 +77,7 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
|
||||
std::map<ir::value *, tile *>& tmap)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
|
||||
|
||||
auto order = layout_->order;
|
||||
auto shapes = layout_->shapes;
|
||||
shapes[order[0]] += layout_->pad;
|
||||
|
||||
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
||||
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
|
||||
// double-buffered
|
||||
if(layout_->double_buffer) {
|
||||
|
@@ -52,7 +52,7 @@ void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
|
||||
}
|
||||
|
||||
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
|
||||
if(!dynamic_cast<ir::phi_node*>(i))
|
||||
if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
|
||||
add_reference(i, res);
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ std::pair<membar::interval_vec_t,
|
||||
get_written_intervals(i, written);
|
||||
bool read_after_write = intersect(new_written_to, read);
|
||||
bool write_after_read = intersect(new_read_from, written);
|
||||
// double buffering: write and phi-node read won't intersect
|
||||
// double buffering
|
||||
if(safe_war.find(i) != safe_war.end()){
|
||||
write_after_read = false;
|
||||
read_after_write = false;
|
||||
@@ -125,8 +125,9 @@ void membar::run(ir::module &mod) {
|
||||
for(const auto& x: layouts_->get_all()){
|
||||
if(x.second->double_buffer){
|
||||
auto info = *x.second->double_buffer;
|
||||
safe_war.insert(info.first);
|
||||
safe_war.insert(info.latch);
|
||||
for(ir::value *v: x.second->values)
|
||||
if(v != info.phi)
|
||||
safe_war.insert(v);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -242,6 +242,7 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
// exit(EXIT_FAILURE);
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -243,6 +243,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
// exit(EXIT_FAILURE);
|
||||
isel.visit(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -13,7 +13,7 @@ int main() {
|
||||
for(int TM: std::vector<int>{32, 64})
|
||||
for(int TN: std::vector<int>{32, 64})
|
||||
for(int TK: std::vector<int>{8})
|
||||
for(int nwarps: std::vector<int>{1, 4})
|
||||
for(int nwarps: std::vector<int>{8})
|
||||
for(bool AT: std::array<bool, 2>{false, true})
|
||||
for(bool BT: std::array<bool, 2>{false, true}){
|
||||
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
||||
|
Reference in New Issue
Block a user