[code generation] search space pruning

This commit is contained in:
Philippe Tillet
2019-03-25 14:10:24 -07:00
parent deb7a1cc5c
commit 8d35c98920
14 changed files with 131 additions and 118 deletions

View File

@@ -5,6 +5,7 @@
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/driver/device.h"
#include "triton/driver/error.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
@@ -71,6 +72,7 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
passes.selection.run(module, *result);
// launch information
auto &launch_info_map = launch_info_map_[result->getName()];
launch_info_map.global_range_size.clear();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
launch_info_map.num_threads = passes.tune.get_num_threads();
@@ -104,12 +106,8 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps){
std::vector<unsigned> current;
for(unsigned x = mp->get_lo(); x <= mp->get_hi(); x*=2)
current.push_back(x);
ranges.push_back(current);
}
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// iterate over parameters
unsigned i;
double best = 0;
@@ -132,22 +130,23 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
}
passes.tune.init(tt_module);
passes.init(tt_module);
// driver::device* device = driver_context_->device();
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
// return;
// if(passes.tune.get_num_threads() > device->max_threads_per_block())
// return;
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
return;
if(passes.tune.get_num_threads() > device->max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
driver::module* module = driver::module::create(driver_context_, &*ll_module);
driver::kernel* kernel = driver::kernel::create(module, "matmul");
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
launch_information info = launch_info_map_.at("matmul");
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
double perf = benchmark(kernel, info);
double perf;
perf = benchmark(kernel.get(), info);
best = std::max(perf, best);
std::cout << perf << " [ " << best << " ] " << std::endl;
});
@@ -167,9 +166,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
passes.tune.check_constraints(errors);
if(errors.size())
throw std::runtime_error("invalid parameters");
// driver::device* device = driver_context_->device();
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
// throw std::runtime_error("invalid parameters");
driver::device* device = driver_context_->device();
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
throw std::runtime_error("invalid parameters");
// triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code