From 4f08d87fedfc6956758e9a729c5e93c4624eadcc Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 26 Nov 2020 00:27:12 -0500
Subject: [PATCH] [DRIVER] Simplified Driver API by substantially removing
 reliance on driver::context

---
 include/triton/driver/buffer.h    | 11 +++---
 include/triton/driver/dispatch.h  |  2 ++
 include/triton/driver/module.h    | 17 ++++-----
 include/triton/driver/stream.h    | 13 +++----
 include/triton/runtime/function.h | 18 +++++-----
 lib/driver/backend.cc             |  2 +-
 lib/driver/buffer.cc              | 34 +++++++-----------
 lib/driver/context.cc             |  2 +-
 lib/driver/dispatch.cc            |  2 ++
 lib/driver/module.cc              | 29 +++++++--------
 lib/driver/stream.cc              | 35 ++++++------------
 lib/runtime/function.cc           | 59 ++++++++++++++-----------------
 python/src/bindings.cc            |  5 +--
 python/src/launch.cc              | 16 ++++++---
 tests/bench/conv.cc               |  4 +--
 tests/bench/copy.cc               |  4 +--
 tests/bench/dot.cc                |  4 +--
 tests/common/conv.h               | 14 ++++----
 tests/common/copy.h               | 20 +++++------
 tests/common/dot.h                | 46 ++++++++++++------------
 tests/common/reduce.h             | 12 +++----
 tests/unit/copy.cc                |  4 +--
 tests/unit/dot.cc                 |  4 +--
 tests/unit/reduce.cc              |  4 +--
 24 files changed, 167 insertions(+), 194 deletions(-)

diff --git a/include/triton/driver/buffer.h b/include/triton/driver/buffer.h
index f64e81fc2..a8d588640 100755
--- a/include/triton/driver/buffer.h
+++ b/include/triton/driver/buffer.h
@@ -16,15 +16,14 @@ class stream;
 // Base
 class buffer : public polymorphic_resource {
 public:
-  buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
-  buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
+  buffer(size_t size, CUdeviceptr cl, bool take_ownership);
+  buffer(size_t size, host_buffer_t hst, bool take_ownership);
   uintptr_t addr_as_uintptr_t();
   static buffer* create(driver::context* ctx, size_t size);
   driver::context* context();
   size_t size();

 protected:
-  driver::context* context_;
   size_t size_;
 };

@@ -32,15 +31,15 @@ protected:
 class host_buffer: public buffer {
 public:
-  host_buffer(driver::context* context, size_t size);
+  host_buffer(size_t size);
 };

 // CUDA
 class cu_buffer: public buffer {
 public:
-  cu_buffer(driver::context* context, size_t size);
-  cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
+  cu_buffer(size_t size);
+  cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership);
   void set_zero(triton::driver::stream *queue, size_t size);
 };

diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h
index 2518c8005..3be6c0f7a 100755
--- a/include/triton/driver/dispatch.h
+++ b/include/triton/driver/dispatch.h
@@ -93,6 +93,7 @@ public:
   static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
   static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
   static CUresult cuStreamSynchronize(CUstream hStream);
+  static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
   static CUresult cuStreamDestroy_v2(CUstream hStream);
   static CUresult cuEventDestroy_v2(CUevent hEvent);
   static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
@@ -154,6 +155,7 @@ private:
   static void* cuModuleGetFunction_;
   static void* cuStreamSynchronize_;
   static void* cuStreamDestroy_v2_;
+  static void* cuStreamGetCtx_;
   static void* cuEventDestroy_v2_;
   static void* cuMemAlloc_v2_;
   static void* cuPointerGetAttribute_;

diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h
index 474b48cab..760c54575 100755
--- a/include/triton/driver/module.h
+++ b/include/triton/driver/module.h
@@ -35,26 +35,21 @@ protected:
   };

 public:
-  module(driver::context* ctx, CUmodule mod, bool has_ownership);
-  module(driver::context* ctx, host_module_t mod, bool has_ownership);
-  static module* create(driver::context* ctx, std::unique_ptr src);
-  driver::context* context() const;
+  module(CUmodule mod, bool has_ownership);
+  module(host_module_t mod, bool has_ownership);
+  static module* create(driver::device* device, std::unique_ptr src);
   void compile_llvm_module(std::unique_ptr module, const std::string& triple,
                            const std::string &proc, std::string layout,
                            llvm::SmallVectorImpl &buffer,
                            const std::string &features,
                            file_type_t file_type);
   virtual std::unique_ptr symbol(const char * name) const = 0;
-
-
-protected:
-  driver::context* ctx_;
 };

 // CPU
 class host_module: public module{
 public:
-  host_module(driver::context* context, std::unique_ptr module);
+  host_module(std::unique_ptr module);
   std::unique_ptr symbol(const char * name) const;
 };

@@ -63,8 +58,8 @@ class cu_module: public module {
   std::string compile_llvm_module(std::unique_ptr module, driver::device* device);

 public:
-  cu_module(driver::context* context, std::unique_ptr module);
-  cu_module(driver::context* context, const std::string& source);
+  cu_module(driver::device* device, std::unique_ptr module);
+  cu_module(const std::string& source);
   std::unique_ptr symbol(const char * name) const;
   const std::string& source() const { return source_; }

diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h
index df4c6ad5f..9874d2a60 100755
--- a/include/triton/driver/stream.h
+++ b/include/triton/driver/stream.h
@@ -23,10 +23,10 @@ class cu_buffer;
 // Base
 class stream: public polymorphic_resource {
 public:
-  stream(driver::context *ctx, CUstream, bool has_ownership);
-  stream(driver::context *ctx, host_stream_t, bool has_ownership);
+  stream(CUstream, bool has_ownership);
+  stream(host_stream_t, bool has_ownership);
   // factory
-  static driver::stream* create(driver::context* ctx);
+  static driver::stream* create(backend_t backend);
   // accessors
   driver::context* context() const;
   // methods
@@ -39,16 +39,13 @@ public:
   { write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
   template
   void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector& x)
   { read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
-
-protected:
-  driver::context *ctx_;
 };

 // Host
 class host_stream: public stream {
 public:
   // Constructors
-  host_stream(driver::context *ctx);
+  host_stream();

   // Overridden
   void synchronize();
@@ -62,7 +59,7 @@ class cu_stream: public stream {
 public:
   // Constructors
   cu_stream(CUstream str, bool take_ownership);
-  cu_stream(driver::context* context);
+  cu_stream();

   // Overridden
   void synchronize();

diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 84c0aaa03..778b47d7f 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -87,11 +87,11 @@ private:
   class caller {
   public:
     // constructors
-    caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
+    caller(std::ifstream& ifs, const options_t& opt);
     caller(ir::function *ir, std::shared_ptr program, const options_t& opt);
     // serialization
     void write(std::ofstream& ofs);
-    void read(driver::context* ctx, std::ifstream& ifs);
+    void read(std::ifstream& ifs);
     // accessors
     const options_t opt() const { return opt_; }
     const driver::module* parent() const { return &*parent_; }
@@ -101,7 +101,7 @@ private:
     std::vector retune() const { return retune_; }
     // entry points
-    void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;
+    void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size, const std::map>& = {}) const;

   private:
     std::shared_ptr bin_;
@@ -121,9 +121,9 @@ private:
   // make
   triton::lang::translation_unit *make_ast(const std::string &src);
   std::unique_ptr make_ir(Parser &parser);
-  std::unique_ptr make_bin(ir::module &function, driver::context *context, const options_t &opt);
-  void make(driver::stream *stream, options_t opt);
-  void precompile(driver::stream *stream, const options_space_t& tuning_space);
+  std::unique_ptr make_bin(ir::module &function, driver::device *device, const options_t &opt);
+  void make(driver::device *device, options_t opt);
+  void precompile(driver::device *device, const options_space_t& tuning_space);
   // autotune
   caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);

@@ -132,10 +132,10 @@ public:
 public:
   function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
-  void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
-  void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
+  void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device);
+  void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device);
   void set_cst(const char* name, void* data, size_t n_bytes);
-  std::string ptx(driver::stream *stream, const options_t& opt);
+  std::string ptx(driver::device *device, const options_t& opt);

 private:
   std::map> cst_;
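Taken together, the header changes above mean that buffers, streams and modules no longer capture a driver::context at construction time; a context is only used to pick a backend or device. The following sketch is illustrative only and is not part of the patch (run_example is a hypothetical helper); it uses only declarations that appear in the headers above.

// Illustrative sketch only -- not part of this patch. `run_example` is hypothetical;
// stream::create(backend_t), cu_buffer(size_t) and set_zero() are declared above.
#include "triton/driver/backend.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"

void run_example() {
  namespace drv = triton::driver;
  // a context is still used to select the backend...
  drv::context* ctx = drv::backend::contexts::get_default();
  // ...but streams are now created from a backend tag rather than a context
  drv::stream* stream = drv::stream::create(ctx->backend());
  // buffers now take only a size (or an existing handle), no context
  drv::cu_buffer buf(1024);
  buf.set_zero(stream, 1024);
  stream->synchronize();
}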
diff --git a/lib/driver/backend.cc b/lib/driver/backend.cc
index 036697fba..a52231ff6 100755
--- a/lib/driver/backend.cc
+++ b/lib/driver/backend.cc
@@ -134,7 +134,7 @@ std::map, driver::kernel*> backend::ker
 void backend::streams::init(std::list const & contexts){
   for(driver::context* ctx : contexts)
     if(cache_.find(ctx)==cache_.end())
-      cache_.insert(std::make_pair(ctx, std::vector{driver::stream::create(ctx)}));
+      cache_.insert(std::make_pair(ctx, std::vector{driver::stream::create(ctx->backend())}));
 }

 void backend::streams::release(){
diff --git a/lib/driver/buffer.cc b/lib/driver/buffer.cc
index 7cbefad45..70b8e465d 100755
--- a/lib/driver/buffer.cc
+++ b/lib/driver/buffer.cc
@@ -35,16 +35,11 @@ namespace driver
 //
-buffer::buffer(driver::context* ctx, size_t size, CUdeviceptr cu, bool take_ownership)
-  : polymorphic_resource(cu, take_ownership), context_(ctx), size_(size) { }
+buffer::buffer(size_t size, CUdeviceptr cu, bool take_ownership)
+  : polymorphic_resource(cu, take_ownership), size_(size) { }

-buffer::buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership)
-  : polymorphic_resource(hst, take_ownership), context_(ctx), size_(size) { }
-
-
-driver::context* buffer::context() {
-  return context_;
-}
+buffer::buffer(size_t size, host_buffer_t hst, bool take_ownership)
+  : polymorphic_resource(hst, take_ownership), size_(size) { }

 size_t buffer::size() {
   return size_;
@@ -61,35 +56,32 @@ uintptr_t buffer::addr_as_uintptr_t() {

 buffer* buffer::create(driver::context* ctx, size_t size) {
   switch(ctx->backend()){
-  case CUDA: return new cu_buffer(ctx, size);
-  case Host: return new host_buffer(ctx, size);
+  case CUDA: return new cu_buffer(size);
+  case Host: return new host_buffer(size);
   default: throw std::runtime_error("unknown backend");
   }
 }

 //
-host_buffer::host_buffer(driver::context *context, size_t size)
-  : buffer(context, size, host_buffer_t(), true){
+host_buffer::host_buffer(size_t size)
+  : buffer(size, host_buffer_t(), true){
   hst_->data = new char[size];
 }

 //
-cu_buffer::cu_buffer(driver::context* context, size_t size)
-  : buffer(context, size, CUdeviceptr(), true) {
-  cu_context::context_switcher ctx_switch(*context_);
+cu_buffer::cu_buffer(size_t size)
+  : buffer(size, CUdeviceptr(), true) {
   dispatch::cuMemAlloc(&*cu_, size);
 }

-cu_buffer::cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership)
-  : buffer(context, size, cu, take_ownership){
+cu_buffer::cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership)
+  : buffer(size, cu, take_ownership){
 }

-void cu_buffer::set_zero(driver::stream* queue, size_t size)
-{
-  cu_context::context_switcher ctx_switch(*context_);
+void cu_buffer::set_zero(driver::stream* queue, size_t size){
   dispatch::cuMemsetD8Async(*cu_, 0, size, *queue->cu());
 }
diff --git a/lib/driver/context.cc b/lib/driver/context.cc
index 8f538cae2..b6091e0ac 100755
--- a/lib/driver/context.cc
+++ b/lib/driver/context.cc
@@ -121,7 +121,7 @@ cu_context::cu_context(CUcontext context, bool take_ownership): driver::context(

 cu_context::cu_context(driver::device* device): context(device, CUcontext(), true){
   dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, *((driver::cu_device*)dev_)->cu());
-  dispatch::cuCtxPopCurrent_v2(NULL);
+//  dispatch::cuCtxPopCurrent_v2(NULL);
 }
diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc
index d62d9ec18..3b3af5596 100755
--- a/lib/driver/dispatch.cc
+++ b/lib/driver/dispatch.cc
@@ -154,6 +154,7 @@ CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
 CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
 CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
 CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
+CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
 CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
 CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
 CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
@@ -223,6 +224,7 @@ void* dispatch::cuCtxCreate_v2_;
 void* dispatch::cuModuleGetFunction_;
 void* dispatch::cuStreamSynchronize_;
 void* dispatch::cuStreamDestroy_v2_;
+void* dispatch::cuStreamGetCtx_;
 void* dispatch::cuEventDestroy_v2_;
 void* dispatch::cuMemAlloc_v2_;
 void* dispatch::cuPointerGetAttribute_;
diff --git a/lib/driver/module.cc b/lib/driver/module.cc
index 1f9b97a92..d2fe645c8 100755
--- a/lib/driver/module.cc
+++ b/lib/driver/module.cc
@@ -62,22 +62,19 @@ void module::init_llvm() {
   }
 }

-module::module(driver::context* ctx, CUmodule mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership), ctx_(ctx) {
+module::module(CUmodule mod, bool has_ownership)
+  : polymorphic_resource(mod, has_ownership) {
 }

-module::module(driver::context* ctx, host_module_t mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership), ctx_(ctx) {
+module::module(host_module_t mod, bool has_ownership)
+  : polymorphic_resource(mod, has_ownership) {
 }

-driver::context* module::context() const {
-  return ctx_;
-}
-module* module::create(driver::context* ctx, std::unique_ptr src) {
-  switch(ctx->backend()){
-  case CUDA: return new cu_module(ctx, std::move(src));
-  case Host: return new host_module(ctx, std::move(src));
+module* module::create(driver::device* device, std::unique_ptr src) {
+  switch(device->backend()){
+  case CUDA: return new cu_module(device, std::move(src));
+  case Host: return new host_module(std::move(src));
   default: throw std::runtime_error("unknown backend");
   }
 }
@@ -130,7 +127,7 @@ void module::compile_llvm_module(std::unique_ptr module, const std
 //           Host             //
 /* ------------------------ */

-host_module::host_module(driver::context * context, std::unique_ptr src): module(context, host_module_t(), true) {
+host_module::host_module(std::unique_ptr src): module(host_module_t(), true) {
   init_llvm();
   // create kernel wrapper
   llvm::LLVMContext &ctx = src->getContext();
@@ -269,10 +266,9 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module,
 }

-cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
+cu_module::cu_module(driver::device* device, std::unique_ptr ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { }

-cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-  cu_context::context_switcher ctx(*context);
+cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
   unsigned int errbufsize = 8096;
@@ -285,6 +281,7 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
     std::cout << source << std::endl;
     std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
     std::cerr << errbuf << std::endl;
+//    exit(1);
 //#endif
     throw;
   }
@@ -294,7 +291,7 @@ std::unique_ptr cu_module::symbol(const char *name) const{
   CUdeviceptr handle;
   size_t size;
   dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
-  std::unique_ptr res(new cu_buffer(ctx_, size, handle, false));
+  std::unique_ptr res(new cu_buffer(size, handle, false));
   return std::move(res);
 }
diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc
index 8d6762767..4fd9e7436 100755
--- a/lib/driver/stream.cc
+++ b/lib/driver/stream.cc
@@ -43,32 +43,29 @@ namespace driver
 //           Base             //
 /* ------------------------ */

-stream::stream(driver::context *ctx, CUstream cu, bool has_ownership)
-  : polymorphic_resource(cu, has_ownership), ctx_(ctx) {
+stream::stream(CUstream cu, bool has_ownership)
+  : polymorphic_resource(cu, has_ownership) {
 }

-stream::stream(driver::context *ctx, host_stream_t cl, bool has_ownership)
-  : polymorphic_resource(cl, has_ownership), ctx_(ctx) {
+stream::stream(host_stream_t cl, bool has_ownership)
+  : polymorphic_resource(cl, has_ownership) {
 }

-driver::stream* stream::create(driver::context* ctx) {
-  switch(ctx->backend()){
-  case CUDA: return new cu_stream(ctx);
-  case Host: return new host_stream(ctx);
+driver::stream* stream::create(backend_t backend) {
+  switch(backend){
+  case CUDA: return new cu_stream();
+  case Host: return new host_stream();
   default: throw std::runtime_error("unknown backend");
   }
 }

-driver::context* stream::context() const {
-  return ctx_;
-}

 /* ------------------------ */
 //           Host            //
 /* ------------------------ */

-host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
+host_stream::host_stream(): stream(host_stream_t(), true) {
   hst_->pool.reset(new ThreadPool(1));
   hst_->futures.reset(new std::vector>());
 }
@@ -104,28 +101,20 @@ void host_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset
 //           CUDA             //
 /* ------------------------ */

-inline CUcontext get_context() {
-  CUcontext result;
-  dispatch::cuCtxGetCurrent(&result);
-  return result;
-}

 cu_stream::cu_stream(CUstream str, bool take_ownership):
-  stream(backend::contexts::import(get_context()), str, take_ownership) {
+  stream(str, take_ownership) {
 }

-cu_stream::cu_stream(driver::context *context): stream((driver::cu_context*)context, CUstream(), true) {
-  cu_context::context_switcher ctx_switch(*ctx_);
+cu_stream::cu_stream(): stream(CUstream(), true) {
   dispatch::cuStreamCreate(&*cu_, 0);
 }

 void cu_stream::synchronize() {
-  cu_context::context_switcher ctx_switch(*ctx_);
   dispatch::cuStreamSynchronize(*cu_);
 }

 void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void** args, size_t args_size) {
-  cu_context::context_switcher ctx_switch(*ctx_);
   void *config[] = {
     CU_LAUNCH_PARAM_BUFFER_POINTER, args,
     CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
@@ -139,7 +128,6 @@ void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std:
 }

 void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
-  cu_context::context_switcher ctx_switch(*ctx_);
   if(blocking)
     dispatch::cuMemcpyHtoD(*buffer->cu() + offset, ptr, size);
   else
@@ -147,7 +135,6 @@ void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset,
 }

 void cu_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
-  cu_context::context_switcher ctx_switch(*ctx_);
   if(blocking)
     dispatch::cuMemcpyDtoH(ptr, *buffer->cu() + offset, size);
   else
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index b60074a0f..96a1dbd97 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -122,7 +122,7 @@ void function::caller::write(std::ofstream &ofs) {
   ofs << source;
 }

-void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
+void function::caller::read(std::ifstream &ifs) {
   // read name
   std::getline(ifs, name_);
   // read signature
@@ -136,14 +136,14 @@ void function::caller::read(driver::context* ctx, std::ifstream &ifs) {
   // read module
   std::string src((std::istreambuf_iterator(ifs)),
                    std::istreambuf_iterator());
-  parent_.reset(new driver::cu_module(ctx, src));
+  parent_.reset(new driver::cu_module(src));
   bin_.reset(driver::kernel::create(&*parent_, name_.c_str()));
 }

-function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt)
+function::caller::caller(std::ifstream &ifs, const options_t& opt)
   : opt_(opt) {
-  read(ctx, ifs);
+  read(ifs);
 }

 function::caller::caller(ir::function *ir,
@@ -163,7 +163,12 @@ function::caller::caller(ir::function *ir,
 }

-void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const {
+void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size, const std::map>& csts) const {
+  // copy constants
+  for(const auto& cst: csts){
+    std::unique_ptr buffer = parent()->symbol(cst.first.c_str());
+    stream->write(&*buffer, true, 0, cst.second);
+  }
   // set grid
   if(_grid.size() > 3)
     throw std::runtime_error("grid size must be no greater than 3");
@@ -188,10 +193,8 @@ std::unique_ptr function::make_ir(Parser& parser) {
 }

 // create Binary from Triton-IR
-std::unique_ptr function::make_bin(ir::module &module,
-                                   driver::context *context,
-                                   const options_t& opt) {
-  std::unique_ptr target = context->device()->make_target();
+std::unique_ptr function::make_bin(ir::module &module, driver::device* device, const options_t& opt) {
+  std::unique_ptr target = device->make_target();
   // generate llvm code
   llvm::LLVMContext ctx;
   std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx));
@@ -236,17 +239,17 @@ std::unique_ptr function::make_bin(ir::module &module,
   layouts.run(module);
   liveness.run(module);
   allocation.run(module);
-  if(allocation.allocated_size() > context->device()->max_shared_memory())
+  if(allocation.allocated_size() > device->max_shared_memory())
     throw std::runtime_error("using too much shared memory");
   barriers.run(module);
   isel.visit(module, *llvm);
-  std::unique_ptr res(driver::module::create(context, std::move(llvm)));
+  std::unique_ptr res(driver::module::create(device, std::move(llvm)));
   return res;
 }

 // create Binary from options
-void function::make(driver::stream *stream, options_t opt) {
+void function::make(driver::device *device, options_t opt) {
   if(callers_.find(opt) != callers_.end())
     return;
   // pre-process
@@ -263,25 +266,17 @@ void function::make(driver::stream *stream, options_t opt) {
   // triton-ir -> binary
   std::unique_ptr bin;
 //  try{
-    bin = make_bin(*ir, stream->context(), opt);
+    bin = make_bin(*ir, device, opt);
 //  }catch(const std::runtime_error&){
 //    return nullptr;
 //  }
   // create callable
   ir::function *tmp = ir->get_function_list()[0];
   callers_[opt].reset(new caller(tmp, std::move(bin), opt));
-  auto& call = callers_[opt];
-  // copy constants
-  if(call)
-    for(const auto& cst: cst_){
-      std::unique_ptr buffer = call->parent()->symbol(cst.first.c_str());
-      stream->write(&*buffer, true, 0, cst.second);
-    }
 }

 // precompile all kernels spanned by given options space
-void function::precompile(driver::stream* stream,
-                          const options_space_t& space) {
+void function::precompile(driver::device* device, const options_space_t& space) {
   // all ranges
   std::vector ranges;
   ranges.push_back(space.num_warps.size());
@@ -296,7 +291,7 @@ void function::precompile(driver::stream* stream,
     for(auto D: space.defines)
       opt.defines[D.first] = D.second[params[i++]];
     // compile
-    make(stream, opt);
+    make(device, opt);
   };
   // multi-threaded compilation
   _loop_nest(ranges, do_make);
@@ -304,8 +299,8 @@ void function::precompile(driver::stream* stream,
     throw std::runtime_error("could not compile kernel");
 }

-std::string function::ptx(driver::stream* stream, const options_t& opt) {
-  make(stream, opt);
+std::string function::ptx(driver::device* device, const options_t& opt) {
+  make(device, opt);
   const auto& fn = callers_.at(opt);
   if(!fn)
     return "";
@@ -325,7 +320,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
     if(x.second == nullptr)
       throw std::runtime_error("configuration not compiled");
     caller* current = &*x.second;
-    double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size); },
+    double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size, cst_); },
                              stream, true);
     ret = (ts < best_ts) ? current : ret;
     best_ts = std::min(ts, best_ts);
@@ -422,14 +417,14 @@ function::function(const std::string &src,
   src_ = preheader() + src_;
 }

-void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
+void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream, driver::device *device) {
   // pre-compile kernels
   if(callers_.empty()){
-    precompile(stream, opt_);
+    precompile(device, opt_);
   }
   // re-tuning key
   cache_key_t key;
-  key.first = stream->context()->device();
+  key.first = device;
   key.second = callers_.begin()->second->retune();
   // auto-tune if necessary
   auto it = cache_.find(key);
@@ -438,14 +433,14 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_
     it = cache_.insert({key, best}).first;
   }
   // run
-  (*it->second)(stream, grid_fn(it->second->opt()), args, args_size);
+  (*it->second)(stream, grid_fn(it->second->opt()), args, args_size, cst_);
 }

 void function::operator()(void** args, size_t args_size, const grid_t& grid,
-                          driver::stream *stream) {
-  return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream);
+                          driver::stream* stream, driver::device *device) {
+  return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream, device);
 }
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 040bcaa7e..d67ddc533 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -21,6 +21,7 @@ std::map> id_grid_map;
 std::map> id_fn_map;

 CUstream torch_get_cuda_stream(int64_t dev_id);
+CUdevice torch_get_cuda_device(int64_t dev_id);

 /* Grid utilities */

@@ -47,8 +48,8 @@ void delete_fn(const map_key_t& key) {
 }

 std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) {
-  triton::driver::cu_stream stream(torch_get_cuda_stream(key.second), false);
-  return id_fn_map[key]->ptx(&stream, opt);
+  triton::driver::cu_device device(torch_get_cuda_device(key.second), false);
+  return id_fn_map[key]->ptx(&device, opt);
 }

 void cleanup() {
diff --git a/python/src/launch.cc b/python/src/launch.cc
index 9f2426dd0..5bb0bf669 100644
--- a/python/src/launch.cc
+++ b/python/src/launch.cc
@@ -31,12 +31,18 @@ void init_host_stream() {
   if(!host_stream){
     host_device.reset(new drv::host_device());
     host_context.reset(drv::context::create(&*host_device));
-    host_stream.reset(drv::stream::create(&*host_context));
+    host_stream.reset(drv::stream::create(host_context->backend()));
   }
 }

 CUstream torch_get_cuda_stream(int64_t dev_id) {
-  return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
+  return (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream();
+}
+
+CUdeviceptr torch_get_cuda_device(int64_t dev_id) {
+  CUdevice ret;
+  triton::driver::dispatch::cuDeviceGet(&ret, dev_id);
+  return ret;
 }

 void synchronize(int64_t dev_id) {
@@ -60,12 +66,12 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
   }
   if(dev_id == -1){
     init_host_stream();
-    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
+    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream, &*host_device);
   }
   else{
     triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
-    triton::driver::context* ctx = stream.context();
-    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
+    triton::driver::cu_device device(torch_get_cuda_device(dev_id), false);
+    (*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream, &device);
   }
 }
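The runtime-facing half of the change is visible in launch.cc above: a kernel launch now receives both the stream and the device explicitly instead of recovering the device through stream->context(). Below is a minimal sketch of that CUDA launch path; it is illustrative only and not part of the patch, and launch_on_cuda with its parameters is a hypothetical stand-in for the caller's state.

// Illustrative sketch only -- not part of this patch. Mirrors the new call pattern
// in python/src/launch.cc; `launch_on_cuda` and its parameters are hypothetical.
#include <string>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"

namespace drv = triton::driver;
namespace rt = triton::runtime;

void launch_on_cuda(rt::function& fn, const std::string& packed_args,
                    const rt::function::grid_t& grid,
                    CUstream raw_stream, CUdevice raw_device) {
  // non-owning wrappers around handles obtained elsewhere (e.g. from PyTorch)
  drv::cu_stream stream(raw_stream, /*take_ownership=*/false);
  drv::cu_device device(raw_device, /*take_ownership=*/false);
  // the explicit device replaces the old stream->context()->device() lookup
  fn((void**)packed_args.c_str(), packed_args.size(), grid, &stream, &device);
}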
diff --git a/tests/bench/conv.cc b/tests/bench/conv.cc
index 93596708d..078029473 100644
--- a/tests/bench/conv.cc
+++ b/tests/bench/conv.cc
@@ -5,7 +5,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream* stream = triton::driver::stream::create(context);
+  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to benchmark
   typedef std::tuple config_t;
   std::vector configs = {
@@ -32,7 +32,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c;
     std::cout << "// " << c ;
-    for(auto perf: bench_conv(stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
+    for(auto perf: bench_conv(context, stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/bench/copy.cc b/tests/bench/copy.cc
index c8c56210a..09869f5cf 100644
--- a/tests/bench/copy.cc
+++ b/tests/bench/copy.cc
@@ -7,7 +7,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream* stream = triton::driver::stream::create(context);
+  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to benchmark
   typedef std::tuple, std::vector, std::vector> config_t;
   std::vector configs = {
@@ -29,7 +29,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(shape, ord_x, ord_y) = c;
     std::cout << "// " << c << std::flush;
-    for(auto perf: bench_copy_nd(stream, HALF, shape, ord_x, ord_y))
+    for(auto perf: bench_copy_nd(context, stream, HALF, shape, ord_x, ord_y))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index 6ec66ecff..9996310ab 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -5,7 +5,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream* stream = triton::driver::stream::create(context);
+  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to benchmark
   typedef std::tuple, bool, bool, int, int, int> config_t;
   std::vector configs;
@@ -65,7 +65,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c ;
-    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord))
      std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/common/conv.h b/tests/common/conv.h
index 006b5dc0a..4e194e154 100644
--- a/tests/common/conv.h
+++ b/tests/common/conv.h
@@ -66,13 +66,13 @@ template<> struct to_string{
 };

 template
-void triton_conv(drv::stream* stream,
+void triton_conv(drv::context* context, drv::stream* stream,
                  int Z, int CI, int H, int W, int CO, int R, int S,
                  int pad_h, int pad_w, int stride_h, int stride_w,
                  run_mode_t mode, std::vector& bench, bool &test){
   std::string ty = to_string::value;
   size_t dt_nbytes = sizeof(T);
-  drv::context* context = stream->context();
+  drv::device* device = context->device();
   int P = (H + 2*pad_h - R)/stride_h + 1;
   int Q = (W + 2*pad_w - S)/stride_w + 1;
@@ -131,19 +131,19 @@ void triton_conv(drv::stream* stream,
                              (size_t)x.D("TZ")};
   };
   auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
-  double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
+  double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
   bench.push_back(tflops(triton_ns));
 }

-std::vector bench_conv(drv::stream* stream, dtype_t dtype,
+std::vector bench_conv(drv::context* context, drv::stream* stream, dtype_t dtype,
                        int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S,
                        int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) {
   std::vector bench;
   bool test;
   switch(dtype){
-    case HALF:   triton_conv(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
-    case FLOAT:  triton_conv(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
-    case DOUBLE: triton_conv(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
+    case HALF:   triton_conv(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
+    case FLOAT:  triton_conv(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
+    case DOUBLE: triton_conv(context, stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
     default: break;
   }
   return bench;
diff --git a/tests/common/copy.h b/tests/common/copy.h
index aac462789..09c30a952 100644
--- a/tests/common/copy.h
+++ b/tests/common/copy.h
@@ -79,13 +79,13 @@ template<> struct to_string{
 };

 template
-void triton_copy_nd(drv::stream* stream, const std::vector& shape,
+void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vector& shape,
                     const std::vector& x_order, const std::vector& y_order,
                     std::vector> TS,
                     run_mode_t mode, std::vector& bench, bool &test) {
   std::string ty = to_string::value;
   size_t dtsize = sizeof(T);
-  drv::context* context = stream->context();
+  drv::device* device = context->device();

   // rank
   size_t rank = shape.size();
@@ -133,7 +133,7 @@ void triton_copy_nd(drv::stream* stream, const std::vector& shape,
   // metrics
   if(mode == BENCH){
     auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
-    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
+    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
     bench.push_back(gbps(triton_ns));
   }

@@ -145,7 +145,7 @@ void triton_copy_nd(drv::stream* stream, const std::vector& shape,
     for(size_t i = 0; i < hx.size(); i++)
       hx[i] = static_cast((float)rand()/RAND_MAX);
     stream->write(&*dx, true, 0, hx);
-    function((void**)&args, sizeof(args), grid, stream);
+    function((void**)&args, sizeof(args), grid, stream, device);
     stream->synchronize();
     stream->read(&*dy, true, 0, hy);
     cc_copy_nd(hx, ry, shape, x_order, y_order);
   }
 }

-std::vector bench_copy_nd(drv::stream* stream, dtype_t dtype, const std::vector& shape,
+std::vector bench_copy_nd(drv::context* context, drv::stream* stream, dtype_t dtype, const std::vector& shape,
                           const std::vector& x_order, const std::vector& y_order) {
   std::vector bench;
   bool test;
   switch(dtype){
     case HALF:
-      triton_copy_nd(stream, shape, x_order, y_order, {}, BENCH, bench, test);
+      triton_copy_nd(context, stream, shape, x_order, y_order, {}, BENCH, bench, test);
       break;
     case FLOAT:
-      triton_copy_nd(stream, shape, x_order, y_order, {}, BENCH, bench, test);
+      triton_copy_nd(context, stream, shape, x_order, y_order, {}, BENCH, bench, test);
       break;
     default: break;
   }
   return bench;
 }

-bool test_copy_nd(drv::stream* stream, dtype_t dtype, const std::vector& shape,
+bool test_copy_nd(drv::context* context, drv::stream* stream, dtype_t dtype, const std::vector& shape,
                   const std::vector& TS,
                   const std::vector& x_order, const std::vector& y_order) {
   std::vector bench;
@@ -179,10 +179,10 @@ bool test_copy_nd(drv::stream* stream, dtype_t dtype, const std::vector
     TSS.push_back({std::to_string(d)});
   switch(dtype){
     case HALF:
-      triton_copy_nd(stream, shape, x_order, y_order, TSS, TEST, bench, test);
+      triton_copy_nd(context, stream, shape, x_order, y_order, TSS, TEST, bench, test);
       break;
     case FLOAT:
-      triton_copy_nd(stream, shape, x_order, y_order, TSS, TEST, bench, test);
+      triton_copy_nd(context, stream, shape, x_order, y_order, TSS, TEST, bench, test);
       break;
     default: break;
   }
diff --git a/tests/common/dot.h b/tests/common/dot.h
index 6d46add14..28424764e 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -79,14 +79,14 @@ template<> struct to_string{
 };

 template
-void triton_dot(drv::stream* stream, bool AT, bool BT,
+void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
                 int32_t M, int32_t N, int32_t K,
                 int32_t TM, int32_t TN, int32_t TK, int32_t nwarp,
                 const std::vector& a_order, const std::vector& b_order,
                 run_mode_t mode, std::vector& bench, bool &test){
   std::string ty = to_string::value;
   size_t dt_nbytes = sizeof(T);
-  drv::context* context = stream->context();
+  drv::device* device = context->device();
   int32_t lda = (AT ^ a_order[0]==1) ? K : M;
   int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
   int32_t ldc = N;
@@ -148,20 +148,20 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   // metrics
   if(mode == BENCH){
     auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
-    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
+    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
     bench.push_back(tflops(triton_ns));

-    // cublas
-    if(cublas::cublasinit()){
-      T alpha(static_cast(1));
-      T beta(static_cast(0));
-      cublasGemmAlgo_t fastest;
-      cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
-                                                                 &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
-                                                                 ldc, nullptr, fastest); }, stream);
-      bench.push_back(tflops(cublas_ms));
-    }
+//    // cublas
+//    if(cublas::cublasinit()){
+//      T alpha(static_cast(1));
+//      T beta(static_cast(0));
+//      cublasGemmAlgo_t fastest;
+//      cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
+//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
+//                                                                 &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+//                                                                 ldc, nullptr, fastest); }, stream);
+//      bench.push_back(tflops(cublas_ms));
+//    }
   }

   // test triton
@@ -179,7 +179,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
     stream->write(&*da, true, 0, ha);
     stream->write(&*db, true, 0, hb);
     // run kernel
-    function((void**)&args, sizeof(args), grid, stream);
+    function((void**)&args, sizeof(args), grid, stream, device);
     // write back
     stream->synchronize();
     // compare with CPU
@@ -190,21 +190,21 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   }
 }

-std::vector bench_dot(drv::stream* stream,
+std::vector bench_dot(drv::context* context, drv::stream* stream,
                       dtype_t dtype, bool AT, bool BT, int32_t M, int32_t N, int32_t K,
                       const std::vector& a_order, const std::vector& b_order) {
   std::vector bench;
   bool test;
   switch(dtype){
-    case HALF:   triton_dot(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
-    case FLOAT:  triton_dot(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
-    case DOUBLE: triton_dot(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
+    case HALF:   triton_dot(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
+    case FLOAT:  triton_dot(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
+    case DOUBLE: triton_dot(context, stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
     default: break;
   }
   return bench;
 }

-bool test_dot(drv::stream* stream,
+bool test_dot(drv::context* context, drv::stream* stream,
               dtype_t dtype, bool AT, bool BT,
               int32_t M, int32_t N, int32_t K,
               const std::vector& a_order, const std::vector& b_order,
@@ -212,9 +212,9 @@ bool test_dot(drv::stream* stream,
   std::vector bench;
   bool test = false;
   switch(dtype){
-    case HALF:   triton_dot(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
-    case FLOAT:  triton_dot(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
-    case DOUBLE: triton_dot(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
+    case HALF:   triton_dot(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
+    case FLOAT:  triton_dot(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
+    case DOUBLE: triton_dot(context, stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
     default: break;
   }
   return test;
diff --git a/tests/common/reduce.h b/tests/common/reduce.h
index 60923274e..504676ec8 100644
--- a/tests/common/reduce.h
+++ b/tests/common/reduce.h
@@ -53,7 +53,7 @@ enum run_mode_t {
   TEST
 };

-void triton_reduce_nd(drv::stream* stream, const std::vector& shape_x,
+void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vector& shape_x,
                       int axis, reduce_op_t op,
                       const std::vector& x_order, const std::vector& y_order,
                       std::vector> TS,
@@ -61,7 +61,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector& shape_x,
   typedef float NumericT;
   std::string ty = "float";
   size_t dtsize = sizeof(NumericT);
-  drv::context* context = stream->context();
+  drv::device* device = context->device();


@@ -141,7 +141,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector& shape_x,
   // metrics
   if(mode == BENCH){
     auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
-    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream);}, stream);
+    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
     bench.push_back(gbps(triton_ns));
   }

@@ -153,7 +153,7 @@ void triton_reduce_nd(drv::stream* stream, const std::vector& shape_x,
     init_zeros(hy);
     init_rand(hx);
     stream->write(&*dx, true, 0, hx);
-    function((void**)&args, sizeof(args), grid, stream);
+    function((void**)&args, sizeof(args), grid, stream, device);
     stream->synchronize();
     stream->read(&*dy, true, 0, hy);
     cc_reduce_nd(ry, hx, op, axis, shape_x);
   }
 }

-bool do_test(drv::stream* stream, std::vector shape, int axis, reduce_op_t op, int nwarp){
+bool do_test(drv::context* context, drv::stream* stream, std::vector shape, int axis, reduce_op_t op, int nwarp){
   std::vector bench;
   bool test;
   std::vector> TSS;
   for(int32_t d: shape)
     TSS.push_back({std::to_string(d)});
-  triton_reduce_nd(stream, shape, axis, op, {0, 1, 2}, {0, 1, 2}, TSS, TEST, bench, test);
+  triton_reduce_nd(context, stream, shape, axis, op, {0, 1, 2}, {0, 1, 2}, TSS, TEST, bench, test);
   return test;
 }
diff --git a/tests/unit/copy.cc b/tests/unit/copy.cc
index 0598ff21f..13a7b6270 100644
--- a/tests/unit/copy.cc
+++ b/tests/unit/copy.cc
@@ -8,7 +8,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream* stream = triton::driver::stream::create(context);
+  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to benchmark
   typedef std::tuple, std::vector, std::vector, std::vector> config_t;
   std::vector configs;
@@ -50,7 +50,7 @@ int main() {
   bool result = true;
   for(const auto& c: configs){
     std::tie(shape, tile, ord_x, ord_y) = c;
-    bool pass = test_copy_nd(stream, FLOAT, shape, tile, ord_x, ord_y);
+    bool pass = test_copy_nd(context, stream, FLOAT, shape, tile, ord_x, ord_y);
     result = result && pass;
     std::cout << "// " << c << ", " << pass << std::endl;
   }
diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc
index 300ca9427..896906da1 100644
--- a/tests/unit/dot.cc
+++ b/tests/unit/dot.cc
@@ -6,7 +6,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream* stream = triton::driver::stream::create(context);
+  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to test
   typedef std::tuple config_t;
   std::vector configs;
@@ -25,7 +25,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
     std::cout << "Testing " << c << " ... " << std::flush;
-    if(test_dot(stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
+    if(test_dot(context, stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
       std::cout << " Pass! " << std::endl;
     else{
       std::cout << " Fail! " << std::endl;
diff --git a/tests/unit/reduce.cc b/tests/unit/reduce.cc
index 96f2d89f9..24f760c1f 100644
--- a/tests/unit/reduce.cc
+++ b/tests/unit/reduce.cc
@@ -16,7 +16,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  triton::driver::stream* stream = triton::driver::stream::create(context);
+  triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to benchmark
   typedef std::tuple, int, reduce_op_t> config_t;
   std::vector configs = {
@@ -34,7 +34,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(shape, axis, op) = c;
     std::cout << "Testing " << c << " ... " << std::flush;
-    if(do_test(stream, shape, axis, op, 1))
+    if(do_test(context, stream, shape, axis, op, 1))
       std::cout << " Pass! " << std::endl;
     else
       std::cout << " Fail! " << std::endl;
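One behavioural consequence worth noting from lib/runtime/function.cc above: constants registered through function::set_cst are no longer uploaded once inside function::make; the staged bytes in cst_ are now forwarded to caller::operator() and written to the module symbol on every launch. A minimal sketch of what that looks like from the caller's side follows; it is illustrative only and not part of the patch, and run_with_lut together with the constant name "LUT" are hypothetical.

// Illustrative sketch only -- not part of this patch. `run_with_lut` and the
// constant name "LUT" are hypothetical; set_cst and operator() are declared in
// include/triton/runtime/function.h as modified above.
#include <string>
#include <vector>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"

namespace rt = triton::runtime;
namespace drv = triton::driver;

void run_with_lut(rt::function& fn, const std::string& packed_args,
                  const rt::function::grid_t& grid,
                  drv::stream* stream, drv::device* device) {
  std::vector<int> lut(256, 7);
  // staged on the host inside the function object (cst_) ...
  fn.set_cst("LUT", lut.data(), lut.size() * sizeof(int));
  // ... and copied to the kernel's "LUT" symbol as part of this launch
  fn((void**)packed_args.c_str(), packed_args.size(), grid, stream, device);
}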