[python] more generic gradient registration
@@ -73,14 +73,14 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
 }
 """

-class dot_op:
+class dot_op(triton.op2):

     def __init__(self, transpose_a = False, transpose_b = False):
         self.dot = triton.op(src, ['C'])
         self.transpose_a = transpose_a
         self.transpose_b = transpose_b

-    def __call__(self, a, b):
+    def forward(self, a, b):
         dtype = a.dtype
         # extract shapes
         shape_a = triton.shape(a)
@@ -104,28 +104,27 @@ class dot_op:
                             AT = self.transpose_a, BT = self.transpose_b, TYPE = dtype,
                             TM = [128], TN = [128], TK = [8])

+    def backward(self, op, dy):
+        a = op.inputs[0]
+        b = op.inputs[1]
+        da = dot_op(self.transpose_a, self.transpose_b).forward(dy, b)
+        db = dot_op(self.transpose_a, self.transpose_b).forward(a, dy)
+        return [da, db, None, None, None, None, None, None, None]


 def dot(a, b, transpose_a = False, transpose_b = False):
     if (transpose_a, transpose_b) not in dot.ops:
         dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
     return dot.ops[transpose_a, transpose_b](a, b)
 dot.ops = dict()

-@tf.RegisterGradient("Dot")
-def _dot_grad(op, dy):
-    a = op.inputs[0]
-    b = op.inputs[1]
-    print(op.triton)
-    return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]

 def run_dot():
     M, N, K = 128, 128, 128
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
     c = dot(a, b, transpose_a = False, transpose_b = False)
+    print("LULZ")
+    da, db = tf.gradients(c, [a, b])
+    print(da, db)
+    exit
     da = tf.gradients(c, [a])
     # Reference
     ha = np.random.rand(M, K).astype(np.float32)
     hb = np.random.rand(K, N).astype(np.float32)
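For reference, with C = A·B and upstream gradient dC, the matrix-product gradients are dA = dC·Bᵀ and dB = Aᵀ·dC; the transpose flags shift these formulas accordingly. A minimal numpy sketch of that identity, using illustrative names that are not part of this commit:

import numpy as np

# Gradients of C = A @ B given an upstream gradient dC:
#   dA = dC @ B.T,  dB = A.T @ dC
def matmul_grads(A, B, dC):
    return dC @ B.T, A.T @ dC

A = np.random.rand(128, 8).astype(np.float32)
B = np.random.rand(8, 128).astype(np.float32)
dA, dB = matmul_grads(A, B, np.ones((128, 128), np.float32))
assert dA.shape == A.shape and dB.shape == B.shape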
@@ -33,12 +33,28 @@ void register_grid(size_t id,
   id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn));
 }

 void delete_grid(size_t id) {
   id_grid_map.erase(id);
+  std::cout << "deleted " << id_grid_map.size() << std::endl;
 }

 void register_fn(size_t id,
                  const std::string& src,
                  const rt::function::options_space_t& opt) {
   id_fn_map[id].reset(new rt::function(src, opt));
 }
+
+void delete_fn(size_t id) {
+  id_fn_map.erase(id);
+  std::cout << "deleted " << id_fn_map.size() << std::endl;
+}
+
+void cleanup() {
+  id_grid_map.clear();
+  id_fn_map.clear();
+  i64scalar_map.clear();
+}
+
+size_t make_op_id() {
+  return id_fn_map.size();
+}
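The bindings keep per-op state in id-keyed maps on the C++ side; note that make_op_id derives a fresh id from the current map size, which stays unique only while ids are never erased before new ones are made. A pure-Python stand-in for the same pattern (names illustrative; the real registries live in C++ behind pybind11):

# Illustrative id-keyed registry, mirroring register_fn / delete_fn above.
_id_fn_map = {}

def make_op_id():
    # Mirrors the C++ make_op_id: the next id is the current map size,
    # so ids stay unique only as long as entries are not erased first.
    return len(_id_fn_map)

def register_fn(op_id, src, opt):
    _id_fn_map[op_id] = (src, opt)

def delete_fn(op_id):
    _id_fn_map.pop(op_id, None)

op_id = make_op_id()
register_fn(op_id, "void dot(...) { ... }", {"TM": [128]})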
@@ -453,9 +469,12 @@ PYBIND11_MODULE(libtriton, m) {

   // hooks into triton constructs since frameworks may not use pybind11
   m.def("register_grid", &register_grid);
+  m.def("delete_grid", &delete_grid);
   m.def("register_fn", &register_fn);
+  m.def("delete_fn", &delete_fn);
+  m.def("make_op_id", &make_op_id);
   m.def("make_scalar_id", &make_scalar_id);
-  m.def("retrieve_scalar", &retrieve_scalar)
-  ;
+  m.def("retrieve_scalar", &retrieve_scalar);
+  m.def("cleanup", &cleanup);
 }
@@ -13,6 +13,13 @@ import setuptools
 import libtriton

+
+# clean-up libtriton resources
+import atexit
+@atexit.register
+def cleanup():
+    libtriton.cleanup()
+

 torch_id = 'torch'
 tensorflow_id = 'tensorflow'
@@ -20,6 +27,9 @@ torch = None
 tensorflow = None
 tf_extra_ops = None

+
+
+
 def _import_torch():
     global torch
     if torch is None:
@@ -211,19 +221,25 @@ def _make_grid(args) :
     return grid


+class op2:
+
+    def __init__(self):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        result = self.forward(*args, **kwargs)
+        # backprop is defined
+        if(callable(getattr(self, 'backward', None))):
+            _import_tensorflow()
+            @tensorflow.RegisterGradient('Dot')
+            def gradient(op, dy):
+                return self.backward(op, dy)
+        return result
+
+
 class op:

-    class _definitions_descriptor:
-        def __init__(self):
-            self.values = dict()
-
-        def __set__(self, instance, value):
-            self.values[value[0]] = value[1]
-
-        def __get__(self, instance, owner):
-            return self.values
-
     def __init__(self, src, outputs, framework = None):
         self.fw_id = dict()
         self.fw_ops = dict()
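op2 turns gradient registration into a method lookup: __call__ runs forward() and, when a backward() method exists, registers it as the TensorFlow gradient (the gradient name is still hard-coded to 'Dot' at this point, so only one op type can use the hook). A hypothetical subclass illustrating the contract; scale_op and its body are invented for this sketch and are not part of the commit:

# Hypothetical op2 subclass showing the forward/backward contract.
class scale_op(op2):

    def __init__(self, k):
        self.k = k

    def forward(self, x):
        # stand-in for a real triton kernel launch
        return x * self.k

    def backward(self, op, dy):
        # one gradient per input of the framework op
        return [dy * self.k]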
@@ -233,9 +249,8 @@ class op:
         self.framework = _find_framework(framework)
         if self.framework == tensorflow_id:
             _import_tensorflow()
-            tensorflow.Operation.triton = property(op._definitions_descriptor)


     def __call__(self, *args, **kwargs):
         # create a new op when defines are different
         key = zip(kwargs.keys(), kwargs.values())
@@ -251,7 +266,7 @@ class op:
             defines.append((k, values))
         opt = libtriton.options_space()
         opt.defines = defines
-        opt.num_warps = [1, 2, 4, 8]
+        opt.num_warps = [4]
         # create unique id for this op
         op_id = libtriton.make_op_id()
         self.fw_id[key] = op_id
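One caveat with the key built above: zip() returns a one-shot iterator that hashes by identity, so each call produces a distinct key and the fw_id cache never hits; a frozen tuple of items is the usual hashable, order-independent key. A runnable sketch of that caching pattern, where compile_op is a hypothetical stand-in for the libtriton compile-and-register path:

# Illustrative define-keyed cache; compile_op is hypothetical.
cache = {}

def compile_op(defines):
    # stub standing in for the real libtriton compile-and-register path
    return ('compiled', defines)

def get_op(**defines):
    # freeze list values so the key is hashable and order-independent
    key = tuple(sorted((k, tuple(v)) for k, v in defines.items()))
    if key not in cache:
        cache[key] = compile_op(defines)
    return cache[key]

op_a = get_op(TM=[128], TN=[128])
op_b = get_op(TN=[128], TM=[128])
assert op_a is op_b   # same defines -> same cached op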
@@ -269,9 +284,7 @@ class op:
         # create operands
         op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
         # call framework op
-        tensor = op(*op_args, id=op_id)
-        tensor.op.triton = ('lol', 1)
-        return tensor
+        return op(*op_args, id=op_id)


 class register_gradient: