[DRIVER] Simplified Driver API by substantially removing reliance on driver::context
This commit is contained in:
@@ -122,7 +122,7 @@ void function::caller::write(std::ofstream &ofs) {
|
||||
ofs << source;
|
||||
}
|
||||
|
||||
void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
|
||||
void function::caller::read(std::ifstream &ifs) {
|
||||
// read name
|
||||
std::getline(ifs, name_);
|
||||
// read signature
|
||||
@@ -136,14 +136,14 @@ void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
|
||||
// read module
|
||||
std::string src((std::istreambuf_iterator<char>(ifs)),
|
||||
std::istreambuf_iterator<char>());
|
||||
parent_.reset(new driver::cu_module(ctx, src));
|
||||
parent_.reset(new driver::cu_module(src));
|
||||
bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
|
||||
|
||||
}
|
||||
|
||||
function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
|
||||
function::caller::caller(std::ifstream &ifs, const options_t& opt)
|
||||
: opt_(opt) {
|
||||
read(ctx, ifs);
|
||||
read(ifs);
|
||||
}
|
||||
|
||||
function::caller::caller(ir::function *ir,
|
||||
@@ -163,7 +163,12 @@ function::caller::caller(ir::function *ir,
|
||||
}
|
||||
|
||||
|
||||
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const {
|
||||
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size, const std::map<std::string, std::vector<char>>& csts) const {
|
||||
// copy constants
|
||||
for(const auto& cst: csts){
|
||||
std::unique_ptr<driver::buffer> buffer = parent()->symbol(cst.first.c_str());
|
||||
stream->write(&*buffer, true, 0, cst.second);
|
||||
}
|
||||
// set grid
|
||||
if(_grid.size() > 3)
|
||||
throw std::runtime_error("grid size must be no greater than 3");
|
||||
@@ -188,10 +193,8 @@ std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
|
||||
}
|
||||
|
||||
// create Binary from Triton-IR
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
driver::context *context,
|
||||
const options_t& opt) {
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::device* device, const options_t& opt) {
|
||||
std::unique_ptr<codegen::target> target = device->make_target();
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
@@ -236,17 +239,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
||||
layouts.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run(module);
|
||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||
if(allocation.allocated_size() > device->max_shared_memory())
|
||||
throw std::runtime_error("using too much shared memory");
|
||||
barriers.run(module);
|
||||
isel.visit(module, *llvm);
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// create Binary from options
|
||||
void function::make(driver::stream *stream, options_t opt) {
|
||||
void function::make(driver::device *device, options_t opt) {
|
||||
if(callers_.find(opt) != callers_.end())
|
||||
return;
|
||||
// pre-process
|
||||
@@ -263,25 +266,17 @@ void function::make(driver::stream *stream, options_t opt) {
|
||||
// triton-ir -> binary
|
||||
std::unique_ptr<driver::module> bin;
|
||||
// try{
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
bin = make_bin(*ir, device, opt);
|
||||
// }catch(const std::runtime_error&){
|
||||
// return nullptr;
|
||||
// }
|
||||
// create callable
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
|
||||
auto& call = callers_[opt];
|
||||
// copy constants
|
||||
if(call)
|
||||
for(const auto& cst: cst_){
|
||||
std::unique_ptr<driver::buffer> buffer = call->parent()->symbol(cst.first.c_str());
|
||||
stream->write(&*buffer, true, 0, cst.second);
|
||||
}
|
||||
}
|
||||
|
||||
// precompile all kernels spanned by given options space
|
||||
void function::precompile(driver::stream* stream,
|
||||
const options_space_t& space) {
|
||||
void function::precompile(driver::device* device, const options_space_t& space) {
|
||||
// all ranges
|
||||
std::vector<size_t> ranges;
|
||||
ranges.push_back(space.num_warps.size());
|
||||
@@ -296,7 +291,7 @@ void function::precompile(driver::stream* stream,
|
||||
for(auto D: space.defines)
|
||||
opt.defines[D.first] = D.second[params[i++]];
|
||||
// compile
|
||||
make(stream, opt);
|
||||
make(device, opt);
|
||||
};
|
||||
// multi-threaded compilation
|
||||
_loop_nest(ranges, do_make);
|
||||
@@ -304,8 +299,8 @@ void function::precompile(driver::stream* stream,
|
||||
throw std::runtime_error("could not compile kernel");
|
||||
}
|
||||
|
||||
std::string function::ptx(driver::stream* stream, const options_t& opt) {
|
||||
make(stream, opt);
|
||||
std::string function::ptx(driver::device* device, const options_t& opt) {
|
||||
make(device, opt);
|
||||
const auto& fn = callers_.at(opt);
|
||||
if(!fn)
|
||||
return "";
|
||||
@@ -325,7 +320,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
|
||||
if(x.second == nullptr)
|
||||
throw std::runtime_error("configuration not compiled");
|
||||
caller* current = &*x.second;
|
||||
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size); },
|
||||
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size, cst_); },
|
||||
stream, true);
|
||||
ret = (ts < best_ts) ? current : ret;
|
||||
best_ts = std::min(ts, best_ts);
|
||||
@@ -422,14 +417,14 @@ function::function(const std::string &src,
|
||||
src_ = preheader() + src_;
|
||||
}
|
||||
|
||||
void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream, driver::device *device) {
|
||||
// pre-compile kernels
|
||||
if(callers_.empty()){
|
||||
precompile(stream, opt_);
|
||||
precompile(device, opt_);
|
||||
}
|
||||
// re-tuning key
|
||||
cache_key_t key;
|
||||
key.first = stream->context()->device();
|
||||
key.first = device;
|
||||
key.second = callers_.begin()->second->retune();
|
||||
// auto-tune if necessary
|
||||
auto it = cache_.find(key);
|
||||
@@ -438,14 +433,14 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_
|
||||
it = cache_.insert({key, best}).first;
|
||||
}
|
||||
// run
|
||||
(*it->second)(stream, grid_fn(it->second->opt()), args, args_size);
|
||||
(*it->second)(stream, grid_fn(it->second->opt()), args, args_size, cst_);
|
||||
}
|
||||
|
||||
void function::operator()(void** args,
|
||||
size_t args_size,
|
||||
const grid_t& grid,
|
||||
driver::stream *stream) {
|
||||
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
|
||||
driver::stream* stream, driver::device *device) {
|
||||
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream, device);
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user