diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d0ad62eb..cef253883 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,8 +33,8 @@ endif() if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") # PyBind11 wrapper source file - file(GLOB_RECURSE PYTHON_SRC python/src/bindings.cc) - include_directories(python/src/ ${PYTHON_INCLUDE_DIRS}) + set(PYTHON_SRC bindings.cc) + include_directories("." ${PYTHON_INCLUDE_DIRS}) endif() diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 000000000..150d74bdc --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1 @@ +graft src \ No newline at end of file diff --git a/python/setup.py b/python/setup.py index 3e694be44..d867dc122 100644 --- a/python/setup.py +++ b/python/setup.py @@ -61,6 +61,8 @@ class CMakeBuild(build_ext): cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, '-DBUILD_TESTS=OFF', '-DBUILD_PYTHON_MODULE=ON', + #'-DPYTHON_EXECUTABLE=' + sys.executable, + #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs, '-DLLVM_CONFIG=' + find_llvm()] # configuration @@ -82,19 +84,19 @@ class CMakeBuild(build_ext): self.distribution.get_version()) if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env) subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) find_llvm() -directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))] +directories = [x[0] for x in os.walk(os.path.join('src', 'include'))] data = [] for d in directories: for htype in ['h', 'hpp']: files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False) - data += [os.path.relpath(f, os.path.pardir) for f in files] + data += [os.path.relpath(f, 'src') for f in files] setup( name='triton', diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt new file mode 120000 index 000000000..8c50e0213 --- /dev/null +++ b/python/src/CMakeLists.txt @@ -0,0 +1 @@ +../../CMakeLists.txt \ No newline at end of file diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 72c159713..696e2e2b6 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -22,38 +22,39 @@ using namespace triton; namespace rt = triton::runtime; -std::map> id_grid_map; -std::map> id_fn_map; +typedef std::pair map_key_t; +std::map> id_grid_map; +std::map> id_fn_map; std::map fp64scalar_map; std::map i64scalar_map; /* Grid map */ -void register_grid(size_t id, +void register_grid(const map_key_t& key, const rt::function::grid_fn_ty& grid_fn) { - id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn)); + id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn)); } -void delete_grid(size_t id) { - id_grid_map.erase(id); +void delete_grid(const map_key_t& key) { + id_grid_map.erase(key); } /* Function map */ -void register_fn(size_t id, +void register_fn(const map_key_t& key, const std::string& src, const rt::function::options_space_t& opt, const std::string &cache_ref) { - id_fn_map[id].reset(new rt::function(src, opt, cache_ref)); + id_fn_map[key].reset(new rt::function(src, opt, cache_ref)); } -void delete_fn(size_t id) { - id_fn_map.erase(id); +void delete_fn(const map_key_t& key) { + id_fn_map.erase(key); } -void register_cst(size_t id, const std::string& name, pybind11::buffer& data) { +void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) { pybind11::buffer_info info = data.request(); - id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize); + id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize); } void cleanup() { @@ -487,6 +488,7 @@ void gen_torch_signature(std::ostringstream& oss, std::string ret_ty = "void"; oss << ret_ty << " " << name << "("; oss << "int64_t id, "; + oss << "int64_t dev_id, "; oss << "int64_t bench, "; oss << "int64_t bench_id, "; for(size_t i = 0; i < args.size(); i++) { @@ -531,7 +533,7 @@ void gen_torch_make_handles(std::ostream &os, void gen_torch_make_launch_function(std::ostream &os, const std::vector& args) { os << " std::function run = [&](){\n "; - os << " (*id_fn_map.at(id))({"; + os << " (*id_fn_map.at({id, dev_id}))({"; for(unsigned i = 0; i < args.size() ; i++){ std::string name = "arg_" + std::to_string(i); if(args[i] == rt::BUFFER_T) @@ -540,7 +542,7 @@ void gen_torch_make_launch_function(std::ostream &os, os << ", "; os << name; } - os << "}, *id_grid_map.at(id), &stream);\n"; + os << "}, *id_grid_map.at({id, dev_id}), &stream);\n"; os << " };\n"; os << " run();\n"; os << " if(bench > 0)\n "; @@ -580,8 +582,9 @@ std::tuple> id_grid_map; -extern std::map> id_fn_map; +typedef std::pair map_key_t; +extern std::map> id_grid_map; +extern std::map> id_fn_map; extern std::map i64scalar_map; )"; diff --git a/python/src/cmake b/python/src/cmake new file mode 120000 index 000000000..c06bb027c --- /dev/null +++ b/python/src/cmake @@ -0,0 +1 @@ +../../cmake/ \ No newline at end of file diff --git a/python/src/include b/python/src/include new file mode 120000 index 000000000..3611dd266 --- /dev/null +++ b/python/src/include @@ -0,0 +1 @@ +../../include/ \ No newline at end of file diff --git a/python/src/lib b/python/src/lib new file mode 120000 index 000000000..bc1a1ee04 --- /dev/null +++ b/python/src/lib @@ -0,0 +1 @@ +../../lib/ \ No newline at end of file diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 594dbc907..b2afb4ef6 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -185,7 +185,8 @@ class kernel: opt.defines = macros opt.num_warps = num_warps self.op_id = libtriton.make_op_id() - libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__)) + self.opt = opt + self.registered = set() # create pytorch hook arg_types = libtriton.get_fn_signature(self.src, opt) self.fw_op = _make_framework_op(arg_types) @@ -194,6 +195,14 @@ class kernel: libtriton.register_cst(self.op_id, name, value) def __call__(self, *args, **kwargs): + for x in args: + if isinstance(x, fw.torch.Tensor): + device = x.device.index + break + # lazily register function for device + if device not in self.registered: + self.registered.add(device) + libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__)) # launch options bench = kwargs['bench'] if 'bench' in kwargs else 0 bench_id = libtriton.make_scalar_id() if bench > 0 else -1 @@ -201,8 +210,8 @@ class kernel: if 'grid' not in kwargs: raise RuntimeError('Must provide grid for kernel launch') grid = kwargs['grid'] - libtriton.register_grid(self.op_id, grid) + libtriton.register_grid((self.op_id, device), grid) # launch - self.fw_op(self.op_id, bench, bench_id, *args) + self.fw_op(self.op_id, device, bench, bench_id, *args) if bench > 0: return libtriton.retrieve_scalar(bench_id) \ No newline at end of file