[CORE] Fixed bug for Multi-GPU
This commit is contained in:
committed by
Philippe Tillet
parent
24586e60aa
commit
609ef3a24d
@@ -22,38 +22,39 @@ using namespace triton;
|
|||||||
|
|
||||||
namespace rt = triton::runtime;
|
namespace rt = triton::runtime;
|
||||||
|
|
||||||
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
typedef std::pair<size_t, size_t> map_key_t;
|
||||||
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||||
|
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||||
std::map<size_t, double> fp64scalar_map;
|
std::map<size_t, double> fp64scalar_map;
|
||||||
std::map<size_t, int64_t> i64scalar_map;
|
std::map<size_t, int64_t> i64scalar_map;
|
||||||
|
|
||||||
/* Grid map */
|
/* Grid map */
|
||||||
|
|
||||||
void register_grid(size_t id,
|
void register_grid(const map_key_t& key,
|
||||||
const rt::function::grid_fn_ty& grid_fn) {
|
const rt::function::grid_fn_ty& grid_fn) {
|
||||||
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
|
id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
|
||||||
}
|
}
|
||||||
|
|
||||||
void delete_grid(size_t id) {
|
void delete_grid(const map_key_t& key) {
|
||||||
id_grid_map.erase(id);
|
id_grid_map.erase(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Function map */
|
/* Function map */
|
||||||
|
|
||||||
void register_fn(size_t id,
|
void register_fn(const map_key_t& key,
|
||||||
const std::string& src,
|
const std::string& src,
|
||||||
const rt::function::options_space_t& opt,
|
const rt::function::options_space_t& opt,
|
||||||
const std::string &cache_ref) {
|
const std::string &cache_ref) {
|
||||||
id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
|
id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
|
||||||
}
|
}
|
||||||
|
|
||||||
void delete_fn(size_t id) {
|
void delete_fn(const map_key_t& key) {
|
||||||
id_fn_map.erase(id);
|
id_fn_map.erase(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
|
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
||||||
pybind11::buffer_info info = data.request();
|
pybind11::buffer_info info = data.request();
|
||||||
id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
|
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
|
||||||
}
|
}
|
||||||
|
|
||||||
void cleanup() {
|
void cleanup() {
|
||||||
@@ -487,6 +488,7 @@ void gen_torch_signature(std::ostringstream& oss,
|
|||||||
std::string ret_ty = "void";
|
std::string ret_ty = "void";
|
||||||
oss << ret_ty << " " << name << "(";
|
oss << ret_ty << " " << name << "(";
|
||||||
oss << "int64_t id, ";
|
oss << "int64_t id, ";
|
||||||
|
oss << "int64_t dev_id, ";
|
||||||
oss << "int64_t bench, ";
|
oss << "int64_t bench, ";
|
||||||
oss << "int64_t bench_id, ";
|
oss << "int64_t bench_id, ";
|
||||||
for(size_t i = 0; i < args.size(); i++) {
|
for(size_t i = 0; i < args.size(); i++) {
|
||||||
@@ -531,7 +533,7 @@ void gen_torch_make_handles(std::ostream &os,
|
|||||||
void gen_torch_make_launch_function(std::ostream &os,
|
void gen_torch_make_launch_function(std::ostream &os,
|
||||||
const std::vector<rt::arg_type>& args) {
|
const std::vector<rt::arg_type>& args) {
|
||||||
os << " std::function<void()> run = [&](){\n ";
|
os << " std::function<void()> run = [&](){\n ";
|
||||||
os << " (*id_fn_map.at(id))({";
|
os << " (*id_fn_map.at({id, dev_id}))({";
|
||||||
for(unsigned i = 0; i < args.size() ; i++){
|
for(unsigned i = 0; i < args.size() ; i++){
|
||||||
std::string name = "arg_" + std::to_string(i);
|
std::string name = "arg_" + std::to_string(i);
|
||||||
if(args[i] == rt::BUFFER_T)
|
if(args[i] == rt::BUFFER_T)
|
||||||
@@ -540,7 +542,7 @@ void gen_torch_make_launch_function(std::ostream &os,
|
|||||||
os << ", ";
|
os << ", ";
|
||||||
os << name;
|
os << name;
|
||||||
}
|
}
|
||||||
os << "}, *id_grid_map.at(id), &stream);\n";
|
os << "}, *id_grid_map.at({id, dev_id}), &stream);\n";
|
||||||
os << " };\n";
|
os << " };\n";
|
||||||
os << " run();\n";
|
os << " run();\n";
|
||||||
os << " if(bench > 0)\n ";
|
os << " if(bench > 0)\n ";
|
||||||
@@ -580,8 +582,9 @@ std::tuple<std::string,
|
|||||||
namespace rt = triton::runtime;
|
namespace rt = triton::runtime;
|
||||||
namespace drv = triton::driver;
|
namespace drv = triton::driver;
|
||||||
|
|
||||||
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
typedef std::pair<size_t, size_t> map_key_t;
|
||||||
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
|
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||||
|
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||||
extern std::map<size_t, int64_t> i64scalar_map;
|
extern std::map<size_t, int64_t> i64scalar_map;
|
||||||
|
|
||||||
)";
|
)";
|
||||||
|
@@ -185,7 +185,8 @@ class kernel:
|
|||||||
opt.defines = macros
|
opt.defines = macros
|
||||||
opt.num_warps = num_warps
|
opt.num_warps = num_warps
|
||||||
self.op_id = libtriton.make_op_id()
|
self.op_id = libtriton.make_op_id()
|
||||||
libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__))
|
self.opt = opt
|
||||||
|
self.registered = set()
|
||||||
# create pytorch hook
|
# create pytorch hook
|
||||||
arg_types = libtriton.get_fn_signature(self.src, opt)
|
arg_types = libtriton.get_fn_signature(self.src, opt)
|
||||||
self.fw_op = _make_framework_op(arg_types)
|
self.fw_op = _make_framework_op(arg_types)
|
||||||
@@ -194,6 +195,14 @@ class kernel:
|
|||||||
libtriton.register_cst(self.op_id, name, value)
|
libtriton.register_cst(self.op_id, name, value)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
for x in args:
|
||||||
|
if isinstance(x, fw.torch.Tensor):
|
||||||
|
device = x.device.index
|
||||||
|
break
|
||||||
|
# lazily register function for device
|
||||||
|
if device not in self.registered:
|
||||||
|
self.registered.add(device)
|
||||||
|
libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
|
||||||
# launch options
|
# launch options
|
||||||
bench = kwargs['bench'] if 'bench' in kwargs else 0
|
bench = kwargs['bench'] if 'bench' in kwargs else 0
|
||||||
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
|
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
|
||||||
@@ -201,8 +210,8 @@ class kernel:
|
|||||||
if 'grid' not in kwargs:
|
if 'grid' not in kwargs:
|
||||||
raise RuntimeError('Must provide grid for kernel launch')
|
raise RuntimeError('Must provide grid for kernel launch')
|
||||||
grid = kwargs['grid']
|
grid = kwargs['grid']
|
||||||
libtriton.register_grid(self.op_id, grid)
|
libtriton.register_grid((self.op_id, device), grid)
|
||||||
# launch
|
# launch
|
||||||
self.fw_op(self.op_id, bench, bench_id, *args)
|
self.fw_op(self.op_id, device, bench, bench_id, *args)
|
||||||
if bench > 0:
|
if bench > 0:
|
||||||
return libtriton.retrieve_scalar(bench_id)
|
return libtriton.retrieve_scalar(bench_id)
|
Reference in New Issue
Block a user