[PYTHON][TENSORFLOW] More bugfixes for forward/backward signatures
@@ -1,10 +1,14 @@
 import triton.frameworks as fw
-import triton.utils
+import triton.utils as utils

 class OpContext(object):

+    def __init__(self):
+        self.to_save = []
+
     def save_for_backward(self, *tensors):
-        self.to_save = tensors
+        self.to_save = [x.to_tensor() if isinstance(x, utils.tf_empty_proxy) else x
+                        for x in tensors]

     @property
     def saved_tensors(self):
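Note: OpContext.save_for_backward now unwraps tf_empty_proxy placeholders itself (via to_tensor()), so a forward implementation can hand it the proxies created by triton.empty() directly instead of reaching into their .tensor attribute; the batchnorm hunk further down relies on exactly this. A minimal sketch of the intended call pattern, assuming the base class is exported as triton.function; my_op and the kernel launch are hypothetical placeholders:

    import triton

    class my_op(triton.function):
        @staticmethod
        def forward(ctx, x):
            # triton.empty() returns a tf_empty_proxy on the TensorFlow path
            y = triton.empty(x.shape, x.dtype)
            # ... launch the Triton kernel that fills y ...
            ctx.save_for_backward(x, y)   # proxies are converted here, plain tensors pass through
            return y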
@@ -16,7 +20,7 @@ class function_meta(type):
        cls.registered = False
        return super(function_meta, cls).__init__(name, bases, attrs)

-ctx_registry = triton.utils.id_dict()
+ctx_registry = utils.id_dict()

 class function(metaclass = function_meta):

@@ -43,13 +47,39 @@ class function(metaclass = function_meta):

     @classmethod
     def extract_tf_tensors(cls, lst, err):
-        for x in lst:
-            if x and not isinstance(x, triton.utils.tf_empty_proxy):
-                raise ValueError('Results of ' + err + ' must be created using triton.empty()')
-            if x and x.tensor is None:
-                raise ValueError('Empty tensor never filled during ' + err)
-        return [x.tensor if x else None for x in lst]
+        ret = []
+        for x in lst:
+            if x is None:
+                ret += [None]
+            elif isinstance(x, fw.tensorflow.Tensor):
+                ret += [x]
+            elif isinstance(x, utils.tf_empty_proxy):
+                if x.tensor is None:
+                    raise ValueError('Empty tensor never filled during ' + err)
+                else:
+                    ret += [x.tensor]
+            else:
+                raise ValueError('Unsupported return type', type(x))
+        return ret
+
+    @classmethod
+    def map_in_to_args(cls, op, args):
+        ret = dict()
+        for i, ix in enumerate(op.inputs):
+            for j, jx in enumerate(args):
+                if ix is jx:
+                    ret[j] = i
+        return ret
+
+    @classmethod
+    def map_res_to_out(cls, op, result):
+        ret = []
+        for i, ix in enumerate(result):
+            for j, jx in enumerate(op.outputs):
+                if ix is jx:
+                    ret.append(j)
+        return ret

     @classmethod
     def apply_tensorflow(cls, *args, **kwargs):
         ctx = OpContext()
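Note: extract_tf_tensors is rewritten as an explicit type switch: None entries pass through as None, plain TensorFlow tensors pass through unchanged, tf_empty_proxy objects must have been filled and are unwrapped to their underlying tensor, and anything else raises. The input/output remapping loops previously inlined in apply_tensorflow are factored out into map_in_to_args and map_res_to_out. A framework-free sketch of what the two helpers compute (identity-based matching, stand-in objects):

    a, b = object(), object()
    op_inputs = [b, a]    # stand-in for op.inputs; the op may see arguments in a different order
    args      = [a, b]    # the python-level arguments passed to forward()
    remap_in  = {j: i for i, ix in enumerate(op_inputs) for j, jx in enumerate(args) if ix is jx}
    assert remap_in == {0: 1, 1: 0}   # python argument j was wired to op input i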
@@ -60,30 +90,21 @@ class function(metaclass = function_meta):
         result = result if isinstance(result, tuple) else (result, )
         result = function.extract_tf_tensors(result, 'forward')

-        # Find a mapping between ::forward arguments and tensorflow op arguments
-        op = result[0].op
-        remap_in = dict()
-        for i, ix in enumerate(op.inputs):
-            for j, jx in enumerate(args):
-                if ix is jx:
-                    remap_in[j] = i
-        remap_out = []
-        for i, ix in enumerate(result):
-            for j, jx in enumerate(op.outputs):
-                if ix is jx:
-                    remap_out.append(j)
-
         # Register backward pass
-        ctx_registry[op] = ctx
+        key = result[0]
+        op = result[0].op
+        ctx_registry[key] = ctx
+        remap_in = cls.map_in_to_args(op, args)
+        remap_out = cls.map_res_to_out(op, result)
         name = op.op_def.name
         if not cls.registered:
             @fw.tensorflow.RegisterGradient(name)
             def gradient(op, *dy):
-                dy = dy if len(dy) > 1 else dy[0]
                 # Remap gradient inputs in the right order
-                grads = [dy[i] for i in remap_out]
+                dy = [dy[i] for i in remap_out]
+                dy = dy if len(dy) > 1 else dy[0]
                 # Execute gradient function
-                grad = cls.backward(ctx_registry[op], grads)
+                grad = cls.backward(ctx_registry[key], dy)
                 grad = function.extract_tf_tensors(grad, 'backward')
                 # Remap gradient in the right order
                 ret = [None] * len(op.inputs)
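Note: the backward context is now registered under the first result tensor (key = result[0]) rather than under the op, and the gradient callback first reorders the incoming per-output gradients with remap_out and only then collapses them to a bare tensor for single-output ops; the old code collapsed first, which breaks for multi-output ops. A sketch of the new ordering on plain lists (values hypothetical):

    dy        = ['g_out0', 'g_out1']   # gradients, one per op.outputs entry
    remap_out = [1, 0]                 # forward returned op.outputs[1] first, op.outputs[0] second
    dy = [dy[i] for i in remap_out]    # reorder into forward's return order
    dy = dy if len(dy) > 1 else dy[0]  # single-output ops receive a bare gradient
    assert dy == ['g_out1', 'g_out0']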
@@ -95,7 +116,7 @@ class function(metaclass = function_meta):
             cls.registered = True

         # Return tensor
-        return result
+        return result[0] if len(result)==1 else result

     @classmethod
     def apply(cls, *args, **kwargs):
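Note: apply_tensorflow now mirrors the shape of forward's return value: a single result comes back as a bare tensor, several results come back as a list. For example (op names hypothetical):

    y = my_op.apply(x)        # one result: a tensor, not a 1-element list
    outs = multi_op.apply(x)  # several results: a list of tensors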
@@ -237,14 +237,20 @@ class kernel:
                kwargs['T' + str(i)] = x.dtype
            # launch
            ret = self.fw_op(*operands, **kwargs)
-           # fill empty tensors with corresponding values
-           for j, y in enumerate(ret[0].op.op_def.output_arg):
-               for i, x in enumerate(ret[0].op.op_def.input_arg):
+           ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
+           op_def = ret[0].op.op_def
+           # fill empty tensors with corresponding values
+           for j, y in enumerate(op_def.output_arg):
+               found = False
+               for i, x in enumerate(op_def.input_arg):
                    if y.name + '_shape' == x.name:
-                       empty[i].tensor = ret[j]
+                       args[i].tensor = ret[j]
+                       found = True
+               assert found
            # store timing information
            if bench > 0:
-               bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
+               for y in ret:
+                   bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
        elif fw.has_torch():
            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
            self.fw_op(op_id, bench, bench_id, *args)
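Note: the TensorFlow launch path now normalizes ret to a list (self.fw_op may return a single Tensor), reads op_def once, writes each op output back into the tf_empty_proxy that was passed as the matching '<name>_shape' input (asserting that a match exists), and records one bench_registry entry per output tensor instead of a single entry keyed by the whole return value. A framework-free sketch of the '<name>_shape' matching rule (argument names hypothetical):

    output_args = ['C']
    input_args  = ['A', 'B', 'C_shape']
    matches = {y: next(i for i, x in enumerate(input_args) if y + '_shape' == x)
               for y in output_args}
    assert matches == {'C': 2}   # output C is written back into the proxy passed as input 2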
@@ -104,7 +104,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
                        lambda opt: [1, C],
                        TM = 128)
        # save
-       ctx.save_for_backward(x, gamma, beta, mean.tensor, var.tensor)
+       ctx.save_for_backward(x, gamma, beta, mean, var)
        ctx.eps = eps
        return y

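Note: with save_for_backward now unwrapping proxies itself (first hunk), the batchnorm forward can pass the mean and var proxies directly; the previous mean.tensor / var.tensor accesses are no longer needed.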
@@ -76,10 +76,11 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
                  'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
                  'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
                  'SHAPE_B'     : 'TN, TK' if transpose_b else 'TK, TN'}
-       return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
+       _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
                   grid, bench=bench,
                   AT = transpose_a, BT = transpose_b, TYPE = dtype,
                   TM = [64, 128], TN = [64, 128], TK = [8], **macros)
+       return c

    @staticmethod
    def forward(ctx, a, b, transpose_a = False, transpose_b = False, bench = 0):
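Note: the dot op now launches _dot.kernel purely for its side effect of filling the output proxy c and returns c itself; unwrapping to a concrete tensor, and the 'Empty tensor never filled' check for kernels that never produced it, is handled downstream by extract_tf_tensors.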
@@ -169,14 +169,15 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
        TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
        TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
        TK = [bmnk[2]] if bmnk[2] < 16 else [8, 16]
-       return _einsum.kernel(a, b, c,
+       _einsum.kernel(a, b, c,
                      bmnk[1], bmnk[2], bmnk[3],
                      std0[0], std0[1], std0[2],
                      std1[0], std1[1], std1[2],
                      grid, bench=bench,
                      **macros,
                      TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
+       return c


    @staticmethod
    def forward(ctx, subscripts, a, b, bench = 0):
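Note: einsum gets the same treatment as dot: the kernel is launched for its side effect of filling the output proxy, and c is returned to the caller.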
@@ -13,7 +13,7 @@ class tf_empty_proxy:
        self.tensor = None

    def to_tensor(self):
-       assert self.tensor
+       assert self.tensor is not None
        return self.tensor

 def empty(shape, dtype):
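Note: the assertion is tightened to an explicit None check because the proxy only needs to know whether it has been filled; evaluating the truth value of a TensorFlow graph tensor is not a meaningful emptiness test (and in graph mode it is not even permitted).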