some cleaning
This commit is contained in:
@@ -33,13 +33,16 @@ void register_grid(size_t id,
|
||||
id_grid_map[id] = grid_fn;
|
||||
}
|
||||
|
||||
size_t register_fn(const std::string& src,
|
||||
void register_fn(size_t id,
|
||||
const std::string& src,
|
||||
const rt::function::options_space_t& opt) {
|
||||
size_t id = id_grid_map.size();
|
||||
bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second;
|
||||
if(!is_inserted)
|
||||
assert(false);
|
||||
return id;
|
||||
}
|
||||
|
||||
size_t make_op_id() {
|
||||
return id_fn_map.size();
|
||||
}
|
||||
|
||||
size_t make_scalar_id() {
|
||||
@@ -319,6 +322,8 @@ PYBIND11_MODULE(libtriton, m) {
|
||||
// hooks into triton constructs since frameworks may not use pybind11
|
||||
m.def("register_grid", ®ister_grid);
|
||||
m.def("register_fn", ®ister_fn);
|
||||
m.def("make_op_id", &make_op_id);
|
||||
m.def("make_scalar_id", &make_scalar_id);
|
||||
m.def("retrieve_scalar", &retrieve_scalar);
|
||||
m.def("retrieve_scalar", &retrieve_scalar)
|
||||
;
|
||||
}
|
||||
|
||||
@@ -175,29 +175,37 @@ def shape(A) :
|
||||
return lazy_shape(tf.shape(A))
|
||||
|
||||
def _make_tensorflow_op(src, outputs, options):
|
||||
src, name = make_bindings(src, outputs, options)
|
||||
cache_path = make_cache_path(src)
|
||||
cpp, so = write_bindings(src, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result.__dict__[name]
|
||||
src, name = make_bindings(src, outputs, options)
|
||||
cache_path = make_cache_path(src)
|
||||
cpp, so = write_bindings(src, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result.__dict__[name]
|
||||
|
||||
def _make_grid(args) :
|
||||
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
|
||||
def grid(opt):
|
||||
for x in scalars:
|
||||
x.set_assume_initialized()
|
||||
result = args[-1](opt)
|
||||
for x in scalars:
|
||||
x.unset_assume_initialized()
|
||||
return result
|
||||
return grid
|
||||
|
||||
class op:
|
||||
|
||||
def __init__(self, src, outputs):
|
||||
self.fw_id = dict()
|
||||
self.fw_ops = dict()
|
||||
self.fw_grids = dict()
|
||||
self.src = src
|
||||
self.outputs = outputs
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
libtriton.unregister_grid(self.id)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# recompilation key
|
||||
# create a new op when defines are different
|
||||
key = zip(kwargs.keys(), kwargs.values())
|
||||
# create a new op when non-iterable defines are different
|
||||
if key not in self.fw_ops:
|
||||
# code generation options
|
||||
defines = []
|
||||
@@ -210,35 +218,24 @@ class op:
|
||||
opt = libtriton.options_space()
|
||||
opt.defines = defines
|
||||
opt.num_warps = [1, 2, 4, 8]
|
||||
# register framework op
|
||||
id = libtriton.register_fn(self.src, opt)
|
||||
self.fw_ops[key] = (_make_tensorflow_op(self.src, self.outputs, opt), id)
|
||||
# create unique id for this op
|
||||
op_id = libtriton.make_op_id()
|
||||
self.fw_id[key] = op_id
|
||||
# register function
|
||||
libtriton.register_fn(op_id, self.src, opt)
|
||||
self.fw_ops[key] = _make_tensorflow_op(self.src, self.outputs, opt)
|
||||
|
||||
# retrieve framework op
|
||||
op, id = self.fw_ops[key]
|
||||
# create grid function
|
||||
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
|
||||
def grid(opt):
|
||||
for x in scalars:
|
||||
x.set_assume_initialized()
|
||||
result = args[-1](opt)
|
||||
for x in scalars:
|
||||
x.unset_assume_initialized()
|
||||
return result
|
||||
# register grid function
|
||||
self.grid = grid
|
||||
libtriton.register_grid(id, self.grid)
|
||||
op_id = self.fw_id[key]
|
||||
op = self.fw_ops[key]
|
||||
# register grid
|
||||
grid = _make_grid(args)
|
||||
libtriton.register_grid(op_id, grid)
|
||||
self.fw_grids[key] = grid
|
||||
# create operands
|
||||
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
|
||||
return op(*op_args, id=id)
|
||||
|
||||
|
||||
def make_tensorflow_op(src, outputs, grids):
|
||||
src, name = make_bindings(src, outputs, grids)
|
||||
cache_path = make_cache_path(src)
|
||||
cpp, so = write_bindings(src, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result.__dict__[name]
|
||||
# call framework op
|
||||
return op(*op_args, id=op_id)
|
||||
|
||||
def empty(shapes):
|
||||
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
|
||||
|
||||
Reference in New Issue
Block a user