triton/python/triton/function.py
Philippe Tillet 93a86d4fc6 [PYTHON][TENSORFLOW] Signature of function.forward() does not have to match signature of kernel anymore
2019-10-30 20:29:23 -04:00


import triton.frameworks as fw
import triton.utils


class OpContext(object):
    """Framework-agnostic analogue of the autograd context object:
    lets forward() stash tensors that backward() will need."""

    def save_for_backward(self, *tensors):
        self.to_save = tensors

    @property
    def saved_tensors(self):
        return self.to_save


class function_meta(type):
    """Metaclass that gives each subclass its own `registered` flag,
    used to register the TensorFlow gradient only once per op type."""

    def __init__(cls, name, bases, attrs):
        cls.registered = False
        return super(function_meta, cls).__init__(name, bases, attrs)


# Maps outputs (PyTorch) or ops (TensorFlow) to the OpContext created
# for them, keyed by object identity
ctx_registry = triton.utils.id_dict()

class function(metaclass=function_meta):

    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError

    @classmethod
    def apply_torch(cls, *args, **kwargs):
        # Wrap the user-defined forward/backward in a torch.autograd.Function
        # so that PyTorch's autograd engine can invoke them
        class TorchFunction(fw.torch.autograd.Function):

            @staticmethod
            def forward(ctx, *targs, **tkwargs):
                y = cls.forward(ctx, *targs, **tkwargs)
                ctx_registry[y] = ctx
                return y

            @staticmethod
            def backward(ctx, grad_output):
                return cls.backward(ctx, grad_output)

        return TorchFunction.apply(*args, **kwargs)

    @classmethod
    def extract_tf_tensors(cls, lst, err):
        # All results must be tf_empty_proxy objects (i.e., created with
        # triton.empty()) whose underlying tensor was filled by a kernel launch
        for x in lst:
            if x and not isinstance(x, triton.utils.tf_empty_proxy):
                raise ValueError('Results of ' + err + ' must be created using triton.empty()')
            if x and x.tensor is None:
                raise ValueError('Empty tensor never filled during ' + err)
        return [x.tensor if x else None for x in lst]

    @classmethod
    def apply_tensorflow(cls, *args, **kwargs):
        ctx = OpContext()
        result = cls.forward(ctx, *args, **kwargs)
        # Check that all the results stem from triton.empty()
        # and get the corresponding TF tensors if possible
        result = result if isinstance(result, tuple) else (result, )
        result = function.extract_tf_tensors(result, 'forward')
        # Find a mapping between ::forward arguments and TensorFlow op arguments
        op = result[0].op
        # remap_in[j] = i when the j-th argument of ::forward is the i-th input of the op
        remap_in = dict()
        for i, ix in enumerate(op.inputs):
            for j, jx in enumerate(args):
                if ix is jx:
                    remap_in[j] = i
        # remap_out[i] = j when the i-th result of ::forward is the j-th output of the op
        remap_out = []
        for i, ix in enumerate(result):
            for j, jx in enumerate(op.outputs):
                if ix is jx:
                    remap_out.append(j)
        # Register backward pass
        ctx_registry[op] = ctx
        name = op.op_def.name
        if not cls.registered:
            @fw.tensorflow.RegisterGradient(name)
            def gradient(op, *dy):
                # TensorFlow provides one gradient per op output;
                # remap them to the order of the results of ::forward
                grads = [dy[i] for i in remap_out]
                # Execute gradient function
                grad = cls.backward(ctx_registry[op], grads)
                grad = grad if isinstance(grad, tuple) else (grad, )
                grad = function.extract_tf_tensors(grad, 'backward')
                # Remap gradients to the order of the op inputs
                ret = [None] * len(op.inputs)
                for i in range(len(grad)):
                    if i in remap_in:
                        ret[remap_in[i]] = grad[i]
                return ret
            cls.registered = True
        # Return tensors
        return result

    @classmethod
    def apply(cls, *args, **kwargs):
        # Dispatch to whichever framework is available
        if fw.has_tensorflow():
            return cls.apply_tensorflow(*args, **kwargs)
        elif fw.has_torch():
            return cls.apply_torch(*args, **kwargs)
        else:
            raise RuntimeError('Could not find any supported framework (TensorFlow or PyTorch)')
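
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): a custom op
# subclasses `function`, implements forward() and backward(), and is invoked
# through apply(), which dispatches to whichever framework is installed.
# `_add_kernel` below is a hypothetical placeholder for a Triton kernel
# launch; it is not defined in this codebase.
#
# class _add(function):
#
#     @staticmethod
#     def forward(ctx, x, y):
#         ctx.save_for_backward(x, y)
#         return _add_kernel(x, y)  # hypothetical kernel launch
#
#     @staticmethod
#     def backward(ctx, dz):
#         # d(x + y)/dx = d(x + y)/dy = 1, so both gradients are just dz
#         return dz, dz
#
# z = _add.apply(x, y)
# ---------------------------------------------------------------------------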