This commit is contained in:
jack-willturner
2020-05-07 13:29:44 +01:00
committed by Philippe Tillet
9 changed files with 43 additions and 24 deletions

View File

@@ -33,8 +33,8 @@ endif()
if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# PyBind11 wrapper source file
file(GLOB_RECURSE PYTHON_SRC python/src/bindings.cc)
include_directories(python/src/ ${PYTHON_INCLUDE_DIRS})
set(PYTHON_SRC bindings.cc)
include_directories("." ${PYTHON_INCLUDE_DIRS})
endif()

1
python/MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
graft src

View File

@@ -61,6 +61,8 @@ class CMakeBuild(build_ext):
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TESTS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON,
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DLLVM_CONFIG=' + find_llvm()]
# configuration
@@ -82,19 +84,19 @@ class CMakeBuild(build_ext):
self.distribution.get_version())
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
find_llvm()
directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))]
directories = [x[0] for x in os.walk(os.path.join('src', 'include'))]
data = []
for d in directories:
for htype in ['h', 'hpp']:
files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False)
data += [os.path.relpath(f, os.path.pardir) for f in files]
data += [os.path.relpath(f, 'src') for f in files]
setup(
name='triton',

1
python/src/CMakeLists.txt Symbolic link
View File

@@ -0,0 +1 @@
../../CMakeLists.txt

View File

@@ -22,38 +22,39 @@ using namespace triton;
namespace rt = triton::runtime;
std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
typedef std::pair<size_t, size_t> map_key_t;
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
std::map<size_t, double> fp64scalar_map;
std::map<size_t, int64_t> i64scalar_map;
/* Grid map */
void register_grid(size_t id,
void register_grid(const map_key_t& key,
const rt::function::grid_fn_ty& grid_fn) {
id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
id_grid_map[key].reset(new rt::function::grid_fn_ty(grid_fn));
}
void delete_grid(size_t id) {
id_grid_map.erase(id);
void delete_grid(const map_key_t& key) {
id_grid_map.erase(key);
}
/* Function map */
void register_fn(size_t id,
void register_fn(const map_key_t& key,
const std::string& src,
const rt::function::options_space_t& opt,
const std::string &cache_ref) {
id_fn_map[id].reset(new rt::function(src, opt, cache_ref));
id_fn_map[key].reset(new rt::function(src, opt, cache_ref));
}
void delete_fn(size_t id) {
id_fn_map.erase(id);
void delete_fn(const map_key_t& key) {
id_fn_map.erase(key);
}
void register_cst(size_t id, const std::string& name, pybind11::buffer& data) {
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
pybind11::buffer_info info = data.request();
id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize);
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
}
void cleanup() {
@@ -487,6 +488,7 @@ void gen_torch_signature(std::ostringstream& oss,
std::string ret_ty = "void";
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t dev_id, ";
oss << "int64_t bench, ";
oss << "int64_t bench_id, ";
for(size_t i = 0; i < args.size(); i++) {
@@ -531,7 +533,7 @@ void gen_torch_make_handles(std::ostream &os,
void gen_torch_make_launch_function(std::ostream &os,
const std::vector<rt::arg_type>& args) {
os << " std::function<void()> run = [&](){\n ";
os << " (*id_fn_map.at(id))({";
os << " (*id_fn_map.at({id, dev_id}))({";
for(unsigned i = 0; i < args.size() ; i++){
std::string name = "arg_" + std::to_string(i);
if(args[i] == rt::BUFFER_T)
@@ -540,7 +542,7 @@ void gen_torch_make_launch_function(std::ostream &os,
os << ", ";
os << name;
}
os << "}, *id_grid_map.at(id), &stream);\n";
os << "}, *id_grid_map.at({id, dev_id}), &stream);\n";
os << " };\n";
os << " run();\n";
os << " if(bench > 0)\n ";
@@ -580,8 +582,9 @@ std::tuple<std::string,
namespace rt = triton::runtime;
namespace drv = triton::driver;
extern std::map<size_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<size_t, std::shared_ptr<rt::function>> id_fn_map;
typedef std::pair<size_t, size_t> map_key_t;
extern std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
extern std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
extern std::map<size_t, int64_t> i64scalar_map;
)";

1
python/src/cmake Symbolic link
View File

@@ -0,0 +1 @@
../../cmake/

1
python/src/include Symbolic link
View File

@@ -0,0 +1 @@
../../include/

1
python/src/lib Symbolic link
View File

@@ -0,0 +1 @@
../../lib/

View File

@@ -185,7 +185,8 @@ class kernel:
opt.defines = macros
opt.num_warps = num_warps
self.op_id = libtriton.make_op_id()
libtriton.register_fn(self.op_id, self.src, opt, os.path.realpath(libtriton.__file__))
self.opt = opt
self.registered = set()
# create pytorch hook
arg_types = libtriton.get_fn_signature(self.src, opt)
self.fw_op = _make_framework_op(arg_types)
@@ -194,6 +195,14 @@ class kernel:
libtriton.register_cst(self.op_id, name, value)
def __call__(self, *args, **kwargs):
for x in args:
if isinstance(x, fw.torch.Tensor):
device = x.device.index
break
# lazily register function for device
if device not in self.registered:
self.registered.add(device)
libtriton.register_fn((self.op_id, device), self.src, self.opt, os.path.realpath(libtriton.__file__))
# launch options
bench = kwargs['bench'] if 'bench' in kwargs else 0
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
@@ -201,8 +210,8 @@ class kernel:
if 'grid' not in kwargs:
raise RuntimeError('Must provide grid for kernel launch')
grid = kwargs['grid']
libtriton.register_grid(self.op_id, grid)
libtriton.register_grid((self.op_id, device), grid)
# launch
self.fw_op(self.op_id, bench, bench_id, *args)
self.fw_op(self.op_id, device, bench, bench_id, *args)
if bench > 0:
return libtriton.retrieve_scalar(bench_id)