[DRIVER] Simplified Driver API by substantially removing reliance on driver::context

This commit is contained in:
Philippe Tillet
2020-11-26 00:27:12 -05:00
parent f42b04d925
commit 4f08d87fed
24 changed files with 167 additions and 194 deletions

View File

@@ -134,7 +134,7 @@ std::map<std::tuple<driver::module*, std::string>, driver::kernel*> backend::ker
void backend::streams::init(std::list<driver::context*> const & contexts){
for(driver::context* ctx : contexts)
if(cache_.find(ctx)==cache_.end())
cache_.insert(std::make_pair(ctx, std::vector<driver::stream*>{driver::stream::create(ctx)}));
cache_.insert(std::make_pair(ctx, std::vector<driver::stream*>{driver::stream::create(ctx->backend())}));
}
void backend::streams::release(){

View File

@@ -35,16 +35,11 @@ namespace driver
//
buffer::buffer(driver::context* ctx, size_t size, CUdeviceptr cu, bool take_ownership)
: polymorphic_resource(cu, take_ownership), context_(ctx), size_(size) { }
buffer::buffer(size_t size, CUdeviceptr cu, bool take_ownership)
: polymorphic_resource(cu, take_ownership), size_(size) { }
buffer::buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership)
: polymorphic_resource(hst, take_ownership), context_(ctx), size_(size) { }
driver::context* buffer::context() {
return context_;
}
buffer::buffer(size_t size, host_buffer_t hst, bool take_ownership)
: polymorphic_resource(hst, take_ownership), size_(size) { }
size_t buffer::size() {
return size_;
@@ -61,35 +56,32 @@ uintptr_t buffer::addr_as_uintptr_t() {
buffer* buffer::create(driver::context* ctx, size_t size) {
switch(ctx->backend()){
case CUDA: return new cu_buffer(ctx, size);
case Host: return new host_buffer(ctx, size);
case CUDA: return new cu_buffer(size);
case Host: return new host_buffer(size);
default: throw std::runtime_error("unknown backend");
}
}
//
host_buffer::host_buffer(driver::context *context, size_t size)
: buffer(context, size, host_buffer_t(), true){
host_buffer::host_buffer(size_t size)
: buffer(size, host_buffer_t(), true){
hst_->data = new char[size];
}
//
cu_buffer::cu_buffer(driver::context* context, size_t size)
: buffer(context, size, CUdeviceptr(), true) {
cu_context::context_switcher ctx_switch(*context_);
cu_buffer::cu_buffer(size_t size)
: buffer(size, CUdeviceptr(), true) {
dispatch::cuMemAlloc(&*cu_, size);
}
cu_buffer::cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership)
: buffer(context, size, cu, take_ownership){
cu_buffer::cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership)
: buffer(size, cu, take_ownership){
}
void cu_buffer::set_zero(driver::stream* queue, size_t size)
{
cu_context::context_switcher ctx_switch(*context_);
void cu_buffer::set_zero(driver::stream* queue, size_t size){
dispatch::cuMemsetD8Async(*cu_, 0, size, *queue->cu());
}

View File

@@ -121,7 +121,7 @@ cu_context::cu_context(CUcontext context, bool take_ownership): driver::context(
cu_context::cu_context(driver::device* device): context(device, CUcontext(), true){
dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, *((driver::cu_device*)dev_)->cu());
dispatch::cuCtxPopCurrent_v2(NULL);
// dispatch::cuCtxPopCurrent_v2(NULL);
}

View File

@@ -154,6 +154,7 @@ CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
@@ -223,6 +224,7 @@ void* dispatch::cuCtxCreate_v2_;
void* dispatch::cuModuleGetFunction_;
void* dispatch::cuStreamSynchronize_;
void* dispatch::cuStreamDestroy_v2_;
void* dispatch::cuStreamGetCtx_;
void* dispatch::cuEventDestroy_v2_;
void* dispatch::cuMemAlloc_v2_;
void* dispatch::cuPointerGetAttribute_;

View File

@@ -62,22 +62,19 @@ void module::init_llvm() {
}
}
module::module(driver::context* ctx, CUmodule mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
module::module(CUmodule mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership) {
}
module::module(driver::context* ctx, host_module_t mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership), ctx_(ctx) {
module::module(host_module_t mod, bool has_ownership)
: polymorphic_resource(mod, has_ownership) {
}
driver::context* module::context() const {
return ctx_;
}
module* module::create(driver::context* ctx, std::unique_ptr<llvm::Module> src) {
switch(ctx->backend()){
case CUDA: return new cu_module(ctx, std::move(src));
case Host: return new host_module(ctx, std::move(src));
module* module::create(driver::device* device, std::unique_ptr<llvm::Module> src) {
switch(device->backend()){
case CUDA: return new cu_module(device, std::move(src));
case Host: return new host_module(std::move(src));
default: throw std::runtime_error("unknown backend");
}
}
@@ -130,7 +127,7 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
// Host //
/* ------------------------ */
host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_t(), true) {
init_llvm();
// create kernel wrapper
llvm::LLVMContext &ctx = src->getContext();
@@ -269,10 +266,9 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
}
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
cu_context::context_switcher ctx(*context);
cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
@@ -285,6 +281,7 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
std::cout << source << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
std::cerr << errbuf << std::endl;
// exit(1);
//#endif
throw;
}
@@ -294,7 +291,7 @@ std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
CUdeviceptr handle;
size_t size;
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
std::unique_ptr<buffer> res(new cu_buffer(ctx_, size, handle, false));
std::unique_ptr<buffer> res(new cu_buffer(size, handle, false));
return std::move(res);
}

View File

@@ -43,32 +43,29 @@ namespace driver
// Base //
/* ------------------------ */
stream::stream(driver::context *ctx, CUstream cu, bool has_ownership)
: polymorphic_resource(cu, has_ownership), ctx_(ctx) {
stream::stream(CUstream cu, bool has_ownership)
: polymorphic_resource(cu, has_ownership) {
}
stream::stream(driver::context *ctx, host_stream_t cl, bool has_ownership)
: polymorphic_resource(cl, has_ownership), ctx_(ctx) {
stream::stream(host_stream_t cl, bool has_ownership)
: polymorphic_resource(cl, has_ownership) {
}
driver::stream* stream::create(driver::context* ctx) {
switch(ctx->backend()){
case CUDA: return new cu_stream(ctx);
case Host: return new host_stream(ctx);
driver::stream* stream::create(backend_t backend) {
switch(backend){
case CUDA: return new cu_stream();
case Host: return new host_stream();
default: throw std::runtime_error("unknown backend");
}
}
driver::context* stream::context() const {
return ctx_;
}
/* ------------------------ */
// Host //
/* ------------------------ */
host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
host_stream::host_stream(): stream(host_stream_t(), true) {
hst_->pool.reset(new ThreadPool(1));
hst_->futures.reset(new std::vector<std::future<void>>());
}
@@ -104,28 +101,20 @@ void host_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset
// CUDA //
/* ------------------------ */
inline CUcontext get_context() {
CUcontext result;
dispatch::cuCtxGetCurrent(&result);
return result;
}
cu_stream::cu_stream(CUstream str, bool take_ownership):
stream(backend::contexts::import(get_context()), str, take_ownership) {
stream(str, take_ownership) {
}
cu_stream::cu_stream(driver::context *context): stream((driver::cu_context*)context, CUstream(), true) {
cu_context::context_switcher ctx_switch(*ctx_);
cu_stream::cu_stream(): stream(CUstream(), true) {
dispatch::cuStreamCreate(&*cu_, 0);
}
void cu_stream::synchronize() {
cu_context::context_switcher ctx_switch(*ctx_);
dispatch::cuStreamSynchronize(*cu_);
}
void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** args, size_t args_size) {
cu_context::context_switcher ctx_switch(*ctx_);
void *config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER, args,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
@@ -139,7 +128,6 @@ void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std:
}
void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
cu_context::context_switcher ctx_switch(*ctx_);
if(blocking)
dispatch::cuMemcpyHtoD(*buffer->cu() + offset, ptr, size);
else
@@ -147,7 +135,6 @@ void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset,
}
void cu_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
cu_context::context_switcher ctx_switch(*ctx_);
if(blocking)
dispatch::cuMemcpyDtoH(ptr, *buffer->cu() + offset, size);
else

View File

@@ -122,7 +122,7 @@ void function::caller::write(std::ofstream &ofs) {
ofs << source;
}
void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
void function::caller::read(std::ifstream &ifs) {
// read name
std::getline(ifs, name_);
// read signature
@@ -136,14 +136,14 @@ void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
// read module
std::string src((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());
parent_.reset(new driver::cu_module(ctx, src));
parent_.reset(new driver::cu_module(src));
bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
}
function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
function::caller::caller(std::ifstream &ifs, const options_t& opt)
: opt_(opt) {
read(ctx, ifs);
read(ifs);
}
function::caller::caller(ir::function *ir,
@@ -163,7 +163,12 @@ function::caller::caller(ir::function *ir,
}
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const {
void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size, const std::map<std::string, std::vector<char>>& csts) const {
// copy constants
for(const auto& cst: csts){
std::unique_ptr<driver::buffer> buffer = parent()->symbol(cst.first.c_str());
stream->write(&*buffer, true, 0, cst.second);
}
// set grid
if(_grid.size() > 3)
throw std::runtime_error("grid size must be no greater than 3");
@@ -188,10 +193,8 @@ std::unique_ptr<ir::module> function::make_ir(Parser& parser) {
}
// create Binary from Triton-IR
std::unique_ptr<driver::module> function::make_bin(ir::module &module,
driver::context *context,
const options_t& opt) {
std::unique_ptr<codegen::target> target = context->device()->make_target();
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::device* device, const options_t& opt) {
std::unique_ptr<codegen::target> target = device->make_target();
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
@@ -236,17 +239,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
layouts.run(module);
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory())
if(allocation.allocated_size() > device->max_shared_memory())
throw std::runtime_error("using too much shared memory");
barriers.run(module);
isel.visit(module, *llvm);
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
std::unique_ptr<driver::module> res(driver::module::create(device, std::move(llvm)));
return res;
}
// create Binary from options
void function::make(driver::stream *stream, options_t opt) {
void function::make(driver::device *device, options_t opt) {
if(callers_.find(opt) != callers_.end())
return;
// pre-process
@@ -263,25 +266,17 @@ void function::make(driver::stream *stream, options_t opt) {
// triton-ir -> binary
std::unique_ptr<driver::module> bin;
// try{
bin = make_bin(*ir, stream->context(), opt);
bin = make_bin(*ir, device, opt);
// }catch(const std::runtime_error&){
// return nullptr;
// }
// create callable
ir::function *tmp = ir->get_function_list()[0];
callers_[opt].reset(new caller(tmp, std::move(bin), opt));
auto& call = callers_[opt];
// copy constants
if(call)
for(const auto& cst: cst_){
std::unique_ptr<driver::buffer> buffer = call->parent()->symbol(cst.first.c_str());
stream->write(&*buffer, true, 0, cst.second);
}
}
// precompile all kernels spanned by given options space
void function::precompile(driver::stream* stream,
const options_space_t& space) {
void function::precompile(driver::device* device, const options_space_t& space) {
// all ranges
std::vector<size_t> ranges;
ranges.push_back(space.num_warps.size());
@@ -296,7 +291,7 @@ void function::precompile(driver::stream* stream,
for(auto D: space.defines)
opt.defines[D.first] = D.second[params[i++]];
// compile
make(stream, opt);
make(device, opt);
};
// multi-threaded compilation
_loop_nest(ranges, do_make);
@@ -304,8 +299,8 @@ void function::precompile(driver::stream* stream,
throw std::runtime_error("could not compile kernel");
}
std::string function::ptx(driver::stream* stream, const options_t& opt) {
make(stream, opt);
std::string function::ptx(driver::device* device, const options_t& opt) {
make(device, opt);
const auto& fn = callers_.at(opt);
if(!fn)
return "";
@@ -325,7 +320,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
if(x.second == nullptr)
throw std::runtime_error("configuration not compiled");
caller* current = &*x.second;
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size); },
double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size, cst_); },
stream, true);
ret = (ts < best_ts) ? current : ret;
best_ts = std::min(ts, best_ts);
@@ -422,14 +417,14 @@ function::function(const std::string &src,
src_ = preheader() + src_;
}
void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream, driver::device *device) {
// pre-compile kernels
if(callers_.empty()){
precompile(stream, opt_);
precompile(device, opt_);
}
// re-tuning key
cache_key_t key;
key.first = stream->context()->device();
key.first = device;
key.second = callers_.begin()->second->retune();
// auto-tune if necessary
auto it = cache_.find(key);
@@ -438,14 +433,14 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_
it = cache_.insert({key, best}).first;
}
// run
(*it->second)(stream, grid_fn(it->second->opt()), args, args_size);
(*it->second)(stream, grid_fn(it->second->opt()), args, args_size, cst_);
}
void function::operator()(void** args,
size_t args_size,
const grid_t& grid,
driver::stream *stream) {
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
driver::stream* stream, driver::device *device) {
return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream, device);
}