[code generation] search space pruning
This commit is contained in:
33
lib/jit.cpp
33
lib/jit.cpp
@@ -5,6 +5,7 @@
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
@@ -71,6 +72,7 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_w
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
auto &launch_info_map = launch_info_map_[result->getName()];
|
||||
launch_info_map.global_range_size.clear();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
launch_info_map.num_threads = passes.tune.get_num_threads();
|
||||
@@ -104,12 +106,8 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
auto mps = passes.tune.get_params(tt_module);
|
||||
// create parameter ranges
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
for(ir::metaparameter *mp: mps){
|
||||
std::vector<unsigned> current;
|
||||
for(unsigned x = mp->get_lo(); x <= mp->get_hi(); x*=2)
|
||||
current.push_back(x);
|
||||
ranges.push_back(current);
|
||||
}
|
||||
for(ir::metaparameter *mp: mps)
|
||||
ranges.push_back(mp->get_space());
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
double best = 0;
|
||||
@@ -132,22 +130,23 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
}
|
||||
passes.tune.init(tt_module);
|
||||
passes.init(tt_module);
|
||||
// driver::device* device = driver_context_->device();
|
||||
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
// return;
|
||||
// if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
||||
// return;
|
||||
driver::device* device = driver_context_->device();
|
||||
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
return;
|
||||
if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
||||
return;
|
||||
// Compile
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
driver::module* module = driver::module::create(driver_context_, &*ll_module);
|
||||
driver::kernel* kernel = driver::kernel::create(module, "matmul");
|
||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), "matmul"));
|
||||
launch_information info = launch_info_map_.at("matmul");
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
double perf = benchmark(kernel, info);
|
||||
double perf;
|
||||
perf = benchmark(kernel.get(), info);
|
||||
best = std::max(perf, best);
|
||||
std::cout << perf << " [ " << best << " ] " << std::endl;
|
||||
});
|
||||
@@ -167,9 +166,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
passes.tune.check_constraints(errors);
|
||||
if(errors.size())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
// driver::device* device = driver_context_->device();
|
||||
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
// throw std::runtime_error("invalid parameters");
|
||||
driver::device* device = driver_context_->device();
|
||||
if(passes.allocation.get_allocated_size() > device->max_shared_memory())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
// triton module -> llvm module
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
// llvm module -> machine code
|
||||
|
Reference in New Issue
Block a user