[python] more generic gradient registration

Philippe Tillet
2019-09-04 03:12:23 -04:00
parent b747959a57
commit cdbc9d4ecd
3 changed files with 61 additions and 30 deletions

View File

@@ -73,14 +73,14 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
}
"""
class dot_op:
class dot_op(triton.op2):
    def __init__(self, transpose_a = False, transpose_b = False):
        self.dot = triton.op(src, ['C'])
        self.transpose_a = transpose_a
        self.transpose_b = transpose_b
    def __call__(self, a, b):
    def forward(self, a, b):
        dtype = a.dtype
        # extract shapes
        shape_a = triton.shape(a)
@@ -104,28 +104,27 @@ class dot_op:
                        AT = self.transpose_a, BT = self.transpose_b, TYPE = dtype,
                        TM = [128], TN = [128], TK = [8])
    def backward(self, op, dy):
        a = op.inputs[0]
        b = op.inputs[1]
        da = dot_op(self.transpose_a, self.transpose_b).forward(dy, b)
        db = dot_op(self.transpose_a, self.transpose_b).forward(a, dy)
        return [da, db, None, None, None, None, None, None, None]
def dot(a, b, transpose_a = False, transpose_b = False):
    if (transpose_a, transpose_b) not in dot.ops:
        dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
    return dot.ops[transpose_a, transpose_b](a, b)
dot.ops = dict()
@tf.RegisterGradient("Dot")
def _dot_grad(op, dy):
    a = op.inputs[0]
    b = op.inputs[1]
    print(op.triton)
    return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
def run_dot():
    M, N, K = 128, 128, 128
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])
    c = dot(a, b, transpose_a = False, transpose_b = False)
    print("LULZ")
    da, db = tf.gradients(c, [a, b])
    print(da, db)
    exit
    da = tf.gradients(c, [a])
    # Reference
    ha = np.random.rand(M, K).astype(np.float32)
    hb = np.random.rand(K, N).astype(np.float32)
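The example now routes differentiation through the op object itself: calling dot() builds the forward op and, via triton.op2, hooks dot_op.backward into TensorFlow's gradient registry, instead of relying on a hand-written @tf.RegisterGradient hook. Below is a minimal sketch of how the updated example could be driven in a TensorFlow 1.x session; the session setup, feed values, and the consistent [M, K] x [K, N] shapes are illustrative assumptions, not code from the commit.

import numpy as np
import tensorflow as tf

M, N, K = 128, 128, 128
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[K, N])
# dot() caches one dot_op instance per (transpose_a, transpose_b) pair
c = dot(a, b, transpose_a = False, transpose_b = False)
# gradients are resolved through dot_op.backward, which op2.__call__
# registered with TensorFlow when the op instance was first called
da, db = tf.gradients(c, [a, b])

ha = np.random.rand(M, K).astype(np.float32)
hb = np.random.rand(K, N).astype(np.float32)
with tf.Session() as sess:
    hc, hda, hdb = sess.run([c, da, db], feed_dict={a: ha, b: hb})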

View File

@@ -33,12 +33,28 @@ void register_grid(size_t id,
  id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
}
void delete_grid(size_t id) {
  id_grid_map.erase(id);
  std::cout << "deleted " << id_grid_map.size() << std::endl;
}
void register_fn(size_t id,
                 const std::string& src,
                 const rt::function::options_space_t& opt) {
  id_fn_map[id].reset(new rt::function(src, opt));
}
void delete_fn(size_t id) {
  id_fn_map.erase(id);
  std::cout << "deleted " << id_fn_map.size() << std::endl;
}
void cleanup() {
  id_grid_map.clear();
  id_fn_map.clear();
  i64scalar_map.clear();
}
size_t make_op_id() {
  return id_fn_map.size();
}
@@ -453,9 +469,12 @@ PYBIND11_MODULE(libtriton, m) {
  // hooks into triton constructs since frameworks may not use pybind11
  m.def("register_grid", &register_grid);
  m.def("delete_grid", &delete_grid);
  m.def("register_fn", &register_fn);
  m.def("delete_fn", &delete_fn);
  m.def("make_op_id", &make_op_id);
  m.def("make_scalar_id", &make_scalar_id);
  m.def("retrieve_scalar", &retrieve_scalar)
  m.def("retrieve_scalar", &retrieve_scalar);
  m.def("cleanup", &cleanup);
  ;
}
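Taken together, the new bindings let the Python layer manage per-op state explicitly: one id per op, registration of the compiled function and launch grid under that id, deletion when they are no longer needed, and a global cleanup at interpreter exit. A rough sketch of that call sequence follows, using only the bindings exported above; the kernel source, tuning space, and grid callback are placeholders, and the grid callback's exact Python signature is whatever pybind11 maps rt::function::grid_fn_ty to, so the lambda below is only a stand-in.

import atexit
import libtriton

src = "..."                                   # placeholder Triton-C source string
opt = libtriton.options_space()
opt.defines = [('TYPE', ['float'])]           # placeholder compile-time definitions
opt.num_warps = [4]

op_id = libtriton.make_op_id()                # unique id for this op instance
libtriton.register_fn(op_id, src, opt)        # compiled function, keyed by id
grid_callback = lambda options: (1, 1, 1)     # placeholder launch-grid callback
libtriton.register_grid(op_id, grid_callback)

# ... the framework op looks the function and grid up by id when it runs ...

libtriton.delete_grid(op_id)                  # release per-op state
libtriton.delete_fn(op_id)
atexit.register(libtriton.cleanup)            # clear every map at exit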

View File

@@ -13,6 +13,13 @@ import setuptools
import libtriton
# clean-up libtriton resources
import atexit
@atexit.register
def cleanup():
    libtriton.cleanup()
torch_id = 'torch'
tensorflow_id = 'tensorflow'
@@ -20,6 +27,9 @@ torch = None
tensorflow = None
tf_extra_ops = None
def _import_torch():
    global torch
    if torch is None:
@@ -211,19 +221,25 @@ def _make_grid(args) :
    return grid
class op2:
    def __init__(self):
        pass
    def __call__(self, *args, **kwargs):
        result = self.forward(*args, **kwargs)
        # backprop is defined
        if(callable(getattr(self, 'backward', None))):
            _import_tensorflow()
            @tensorflow.RegisterGradient('Dot')
            def gradient(op, dy):
                return self.backward(op, dy)
        return result
class op:
    class _definitions_descriptor:
        def __init__(self):
            self.values = dict()
        def __set__(self, instance, value):
            self.values[value[0]] = value[1]
        def __get__(self, instance, owner):
            return self.values
    def __init__(self, src, outputs, framework = None):
        self.fw_id = dict()
        self.fw_ops = dict()
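The op2 class above is what makes the registration generic: any op that defines forward is invoked through __call__, and if it also defines backward, that method is hooked into TensorFlow's gradient registry on the first call. Spelled out for the dot example from the first file, as a hedged paraphrase of the control flow rather than code from the commit:

d = dot_op(transpose_a = False, transpose_b = False)
c = d(a, b)          # op2.__call__: runs d.forward(a, b) and then, because
                     # d.backward exists, executes the equivalent of:
                     #
                     #   @tensorflow.RegisterGradient('Dot')
                     #   def gradient(tf_op, dy):
                     #       return d.backward(tf_op, dy)
                     #
                     # note that the op-type name 'Dot' is still hard-coded
                     # in op2 at this point in the commit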
@@ -233,9 +249,8 @@ class op:
        self.framework = _find_framework(framework)
        if self.framework == tensorflow_id:
            _import_tensorflow()
            tensorflow.Operation.triton = property(op._definitions_descriptor)
    def __call__(self, *args, **kwargs):
        # create a new op when defines are different
        key = zip(kwargs.keys(), kwargs.values())
@@ -251,7 +266,7 @@ class op:
            defines.append((k, values))
        opt = libtriton.options_space()
        opt.defines = defines
        opt.num_warps = [1, 2, 4, 8]
        opt.num_warps = [4]
        # create unique id for this op
        op_id = libtriton.make_op_id()
        self.fw_id[key] = op_id
@@ -269,9 +284,7 @@ class op:
        # create operands
        op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
        # call framework op
        tensor = op(*op_args, id=op_id)
        tensor.op.triton = ('lol', 1)
        return tensor
        return op(*op_args, id=op_id)
class register_gradient:
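The nested _definitions_descriptor in the op class is a plain Python data descriptor: assigning through it stores a (key, value) pair in a dict held on the descriptor itself, and reading it back returns that dict, which is how an attribute like op.triton (see tensorflow.Operation.triton and tensor.op.triton above) can carry Triton metadata on TensorFlow operations. A small self-contained illustration of that protocol follows, with a stand-in owner class; the names are hypothetical and not part of the commit.

class _definitions_descriptor:
    def __init__(self):
        self.values = dict()
    def __set__(self, instance, value):
        self.values[value[0]] = value[1]    # expects a (key, value) pair
    def __get__(self, instance, owner):
        return self.values

class FakeOperation:                        # stand-in for tensorflow.Operation
    triton = _definitions_descriptor()

node = FakeOperation()
node.triton = ('TM', 128)                   # routed through __set__
assert node.triton == {'TM': 128}           # routed through __get__
# note: the dict lives on the descriptor, so it is shared by every instance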