more fixes
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
#include <iostream>
|
||||||
#include "triton/codegen/analysis/liveness.h"
|
#include "triton/codegen/analysis/liveness.h"
|
||||||
#include "triton/codegen/analysis/meminfo.h"
|
#include "triton/codegen/analysis/meminfo.h"
|
||||||
#include "triton/ir/basic_block.h"
|
#include "triton/ir/basic_block.h"
|
||||||
|
@@ -82,6 +82,10 @@ void add_copy(ir::value *x, ir::builder &builder) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void meminfo::run(ir::module &mod) {
|
void meminfo::run(ir::module &mod) {
|
||||||
|
// shared_.clear();
|
||||||
|
// refs_.clear();
|
||||||
|
// double_.clear();
|
||||||
|
|
||||||
// Add shared copies
|
// Add shared copies
|
||||||
for(ir::function *fn: mod.get_function_list()){
|
for(ir::function *fn: mod.get_function_list()){
|
||||||
ir::builder builder(mod.get_context());
|
ir::builder builder(mod.get_context());
|
||||||
|
@@ -88,8 +88,10 @@ void coalesce::run(ir::module &mod) {
|
|||||||
builder.set_insert_point(it);
|
builder.set_insert_point(it);
|
||||||
// found a load; write to shared memory and stop recursion
|
// found a load; write to shared memory and stop recursion
|
||||||
ir::instruction *n_op = nullptr;
|
ir::instruction *n_op = nullptr;
|
||||||
if(mem_->is_shared(i_op))
|
if(mem_->is_shared(i_op)){
|
||||||
|
i_op->add_use(cloned);
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
if(auto* ld = dynamic_cast<ir::load_inst*>(i_op))
|
if(auto* ld = dynamic_cast<ir::load_inst*>(i_op))
|
||||||
n_op = ir::copy_to_shared_inst::create(ld);
|
n_op = ir::copy_to_shared_inst::create(ld);
|
||||||
// not a load; rematerialize and add to worklist
|
// not a load; rematerialize and add to worklist
|
||||||
|
@@ -229,7 +229,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
dce.run(module);
|
dce.run(module);
|
||||||
vectorize.run(module);
|
vectorize.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
ir::print(module, std::cout);
|
// ir::print(module, std::cout);
|
||||||
// generate llvm code
|
// generate llvm code
|
||||||
llvm::LLVMContext ctx;
|
llvm::LLVMContext ctx;
|
||||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||||
|
@@ -45,10 +45,10 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
|||||||
opt.defines.push_back({"TYPE", {ty}});
|
opt.defines.push_back({"TYPE", {ty}});
|
||||||
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
opt.defines.push_back({"AT", {AT?"1":"0"}});
|
||||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||||
opt.defines.push_back({"TM", {"128"}});
|
opt.defines.push_back({"TM", {"64", "128"}});
|
||||||
opt.defines.push_back({"TN", {"128"}});
|
opt.defines.push_back({"TN", {"64", "128"}});
|
||||||
opt.defines.push_back({"TK", {"8"}});
|
opt.defines.push_back({"TK", {"8"}});
|
||||||
opt.num_warps = {8};
|
opt.num_warps = {2, 4, 8};
|
||||||
// create function
|
// create function
|
||||||
rt::function function(src::dot, opt);
|
rt::function function(src::dot, opt);
|
||||||
// benchmark available libraries
|
// benchmark available libraries
|
||||||
|
Reference in New Issue
Block a user