some cleaning

This commit is contained in:
Philippe Tillet
2019-08-26 17:21:09 -07:00
parent 4075949f80
commit 9ece3eccc6
2 changed files with 43 additions and 41 deletions

View File

@@ -33,13 +33,16 @@ void register_grid(size_t id,
id_grid_map[id] = grid_fn;
}
size_t register_fn(const std::string& src,
void register_fn(size_t id,
const std::string& src,
const rt::function::options_space_t& opt) {
size_t id = id_grid_map.size();
bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second;
if(!is_inserted)
assert(false);
return id;
}
size_t make_op_id() {
return id_fn_map.size();
}
size_t make_scalar_id() {
@@ -319,6 +322,8 @@ PYBIND11_MODULE(libtriton, m) {
// hooks into triton constructs since frameworks may not use pybind11
m.def("register_grid", &register_grid);
m.def("register_fn", &register_fn);
m.def("make_op_id", &make_op_id);
m.def("make_scalar_id", &make_scalar_id);
m.def("retrieve_scalar", &retrieve_scalar);
m.def("retrieve_scalar", &retrieve_scalar)
;
}

View File

@@ -175,29 +175,37 @@ def shape(A) :
return lazy_shape(tf.shape(A))
def _make_tensorflow_op(src, outputs, options):
src, name = make_bindings(src, outputs, options)
cache_path = make_cache_path(src)
cpp, so = write_bindings(src, cache_path)
build(cpp, cache_path)
result = tf.load_op_library(so)
return result.__dict__[name]
src, name = make_bindings(src, outputs, options)
cache_path = make_cache_path(src)
cpp, so = write_bindings(src, cache_path)
build(cpp, cache_path)
result = tf.load_op_library(so)
return result.__dict__[name]
def _make_grid(args) :
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
def grid(opt):
for x in scalars:
x.set_assume_initialized()
result = args[-1](opt)
for x in scalars:
x.unset_assume_initialized()
return result
return grid
class op:
def __init__(self, src, outputs):
self.fw_id = dict()
self.fw_ops = dict()
self.fw_grids = dict()
self.src = src
self.outputs = outputs
pass
def __del__(self):
libtriton.unregister_grid(self.id)
def __call__(self, *args, **kwargs):
# recompilation key
# create a new op when defines are different
key = zip(kwargs.keys(), kwargs.values())
# create a new op when non-iterable defines are different
if key not in self.fw_ops:
# code generation options
defines = []
@@ -210,35 +218,24 @@ class op:
opt = libtriton.options_space()
opt.defines = defines
opt.num_warps = [1, 2, 4, 8]
# register framework op
id = libtriton.register_fn(self.src, opt)
self.fw_ops[key] = (_make_tensorflow_op(self.src, self.outputs, opt), id)
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
# register function
libtriton.register_fn(op_id, self.src, opt)
self.fw_ops[key] = _make_tensorflow_op(self.src, self.outputs, opt)
# retrieve framework op
op, id = self.fw_ops[key]
# create grid function
scalars = [x for x in args[:-1] if isinstance(x, scalar)]
def grid(opt):
for x in scalars:
x.set_assume_initialized()
result = args[-1](opt)
for x in scalars:
x.unset_assume_initialized()
return result
# register grid function
self.grid = grid
libtriton.register_grid(id, self.grid)
op_id = self.fw_id[key]
op = self.fw_ops[key]
# register grid
grid = _make_grid(args)
libtriton.register_grid(op_id, grid)
self.fw_grids[key] = grid
# create operands
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
return op(*op_args, id=id)
def make_tensorflow_op(src, outputs, grids):
src, name = make_bindings(src, outputs, grids)
cache_path = make_cache_path(src)
cpp, so = write_bindings(src, cache_path)
build(cpp, cache_path)
result = tf.load_op_library(so)
return result.__dict__[name]
# call framework op
return op(*op_args, id=op_id)
def empty(shapes):
args = [x.handle if isinstance(x, scalar) else x for x in shapes]