diff --git a/python/triton/function.py b/python/triton/function.py
index 43c8aa7b9..9b71106de 100644
--- a/python/triton/function.py
+++ b/python/triton/function.py
@@ -1,10 +1,14 @@
 import triton.frameworks as fw
-import triton.utils
+import triton.utils as utils

 class OpContext(object):

+    def __init__(self):
+        self.to_save = []
+
     def save_for_backward(self, *tensors):
-        self.to_save = tensors
+        self.to_save = [x.to_tensor() if isinstance(x, utils.tf_empty_proxy) else x
+                        for x in tensors]

     @property
     def saved_tensors(self):
@@ -16,7 +20,7 @@ class function_meta(type):
         cls.registered = False
         return super(function_meta, cls).__init__(name, bases, attrs)

-ctx_registry = triton.utils.id_dict()
+ctx_registry = utils.id_dict()

 class function(metaclass = function_meta):

@@ -43,13 +47,39 @@ class function(metaclass = function_meta):

     @classmethod
     def extract_tf_tensors(cls, lst, err):
+        ret = []
         for x in lst:
-            if x and not isinstance(x, triton.utils.tf_empty_proxy):
-                raise ValueError('Results of ' + err + ' must be created using triton.empty()')
-            if x and x.tensor is None:
-                raise ValueError('Empty tensor never filled during ' + err)
-        return [x.tensor if x else None for x in lst]
+            if x is None:
+                ret += [None]
+            elif isinstance(x, fw.tensorflow.Tensor):
+                ret += [x]
+            elif isinstance(x, utils.tf_empty_proxy):
+                if x.tensor is None:
+                    raise ValueError('Empty tensor never filled during ' + err)
+                else:
+                    ret += [x.tensor]
+            else:
+                raise ValueError('Unsupported return type', type(x))
+        return ret

+    @classmethod
+    def map_in_to_args(cls, op, args):
+        ret = dict()
+        for i, ix in enumerate(op.inputs):
+            for j, jx in enumerate(args):
+                if ix is jx:
+                    ret[j] = i
+        return ret
+
+    @classmethod
+    def map_res_to_out(cls, op, result):
+        ret = []
+        for i, ix in enumerate(result):
+            for j, jx in enumerate(op.outputs):
+                if ix is jx:
+                    ret.append(j)
+        return ret
+
     @classmethod
     def apply_tensorflow(cls, *args, **kwargs):
         ctx = OpContext()
@@ -60,30 +90,21 @@ class function(metaclass = function_meta):
         result = result if isinstance(result, tuple) else (result, )
         result = function.extract_tf_tensors(result, 'forward')
-        # Find a mapping between ::forward arguments and tensorflow op arguments
-        op = result[0].op
-        remap_in = dict()
-        for i, ix in enumerate(op.inputs):
-            for j, jx in enumerate(args):
-                if ix is jx:
-                    remap_in[j] = i
-        remap_out = []
-        for i, ix in enumerate(result):
-            for j, jx in enumerate(op.outputs):
-                if ix is jx:
-                    remap_out.append(j)
-        # Register backward pass
-        ctx_registry[op] = ctx
+        key = result[0]
+        op = result[0].op
+        ctx_registry[key] = ctx
+        remap_in = cls.map_in_to_args(op, args)
+        remap_out = cls.map_res_to_out(op, result)
         name = op.op_def.name
         if not cls.registered:
             @fw.tensorflow.RegisterGradient(name)
             def gradient(op, *dy):
-                dy = dy if len(dy) > 1 else dy[0]
                 # Remap gradient inputs in the right order
-                grads = [dy[i] for i in remap_out]
+                dy = [dy[i] for i in remap_out]
+                dy = dy if len(dy) > 1 else dy[0]
                 # Execute gradient function
-                grad = cls.backward(ctx_registry[op], grads)
+                grad = cls.backward(ctx_registry[key], dy)
                 grad = function.extract_tf_tensors(grad, 'backward')
                 # Remap gradient in the right order
                 ret = [None] * len(op.inputs)
@@ -95,7 +116,7 @@
             cls.registered = True

         # Return tensor
-        return result
+        return result[0] if len(result)==1 else result

     @classmethod
     def apply(cls, *args, **kwargs):
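Note for readers tracing the patch: the new helpers map_in_to_args and map_res_to_out recover index mappings between the Python-level forward() arguments/results and the TensorFlow op's inputs/outputs purely by object identity, so non-tensor arguments simply get no entry. Below is a minimal, self-contained sketch of that mapping logic; the names op_inputs, args, and the sample objects are illustrative, not part of the patch.

# Hypothetical stand-alone version of function.map_in_to_args: maps the
# index j of a forward() argument to the index i of the op input that is
# the *same object*, skipping everything else.
def map_in_to_args(op_inputs, args):
    ret = dict()
    for i, ix in enumerate(op_inputs):
        for j, jx in enumerate(args):
            if ix is jx:  # object identity, not equality
                ret[j] = i
    return ret

a, b, scale = object(), object(), 2.0
# The op only sees the tensor arguments (a, b); `scale` behaves like a
# plain attribute, so it gets no entry in the mapping.
assert map_in_to_args([a, b], [a, b, scale]) == {0: 0, 1: 1}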
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index ea90e339b..79b3e59ce 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -237,14 +237,20 @@ class kernel:
                 kwargs['T' + str(i)] = x.dtype
             # launch
             ret = self.fw_op(*operands, **kwargs)
-            # fill empty tensors with corresponding values
-            for j, y in enumerate(ret[0].op.op_def.output_arg):
-                for i, x in enumerate(ret[0].op.op_def.input_arg):
+            ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
+            op_def = ret[0].op.op_def
+            # fill empty tensors with corresponding values
+            for j, y in enumerate(op_def.output_arg):
+                found = False
+                for i, x in enumerate(op_def.input_arg):
                     if y.name + '_shape' == x.name:
-                        empty[i].tensor = ret[j]
+                        args[i].tensor = ret[j]
+                        found = True
+                assert found
             # store timing information
             if bench > 0:
-                bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
+                for y in ret:
+                    bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
         elif fw.has_torch():
             args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
             self.fw_op(op_id, bench, bench_id, *args)
diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py
index 9bb3450c0..9409134d9 100644
--- a/python/triton/ops/batchnorm.py
+++ b/python/triton/ops/batchnorm.py
@@ -104,7 +104,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
                   lambda opt: [1, C],
                   TM = 128)
     # save
-    ctx.save_for_backward(x, gamma, beta, mean.tensor, var.tensor)
+    ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
diff --git a/python/triton/ops/dot.py b/python/triton/ops/dot.py
index 140cd82cd..339fba4c6 100644
--- a/python/triton/ops/dot.py
+++ b/python/triton/ops/dot.py
@@ -76,10 +76,11 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
              'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
              'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
              'SHAPE_B'     : 'TN, TK' if transpose_b else 'TK, TN'}
-    return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
-                       grid, bench=bench,
-                       AT = transpose_a, BT = transpose_b, TYPE = dtype,
-                       TM = [64, 128], TN = [64, 128], TK = [8], **macros)
+    _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
+                grid, bench=bench,
+                AT = transpose_a, BT = transpose_b, TYPE = dtype,
+                TM = [64, 128], TN = [64, 128], TK = [8], **macros)
+    return c

   @staticmethod
   def forward(ctx, a, b, transpose_a = False, transpose_b = False, bench = 0):
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index 167b2aacd..4c3409885 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -169,14 +169,15 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
     TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
     TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
     TK = [bmnk[2]] if bmnk[2] < 16 else [8, 16]
-    return _einsum.kernel(a, b, c,
-                          bmnk[1], bmnk[2], bmnk[3],
-                          std0[0], std0[1], std0[2],
-                          std1[0], std1[1], std1[2],
-                          grid, bench=bench,
-                          **macros,
-                          TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
-
+    _einsum.kernel(a, b, c,
+                   bmnk[1], bmnk[2], bmnk[3],
+                   std0[0], std0[1], std0[2],
+                   std1[0], std1[1], std1[2],
+                   grid, bench=bench,
+                   **macros,
+                   TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
+    return c
+
   @staticmethod
   def forward(ctx, subscripts, a, b, bench = 0):
diff --git a/python/triton/utils.py b/python/triton/utils.py
index ddc050d15..a51608508 100644
--- a/python/triton/utils.py
+++ b/python/triton/utils.py
@@ -13,7 +13,7 @@ class tf_empty_proxy:
     self.tensor = None

   def to_tensor(self):
-    assert self.tensor
+    assert self.tensor is not None
     return self.tensor

 def empty(shape, dtype):
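Taken together, these changes standardize on the tf_empty_proxy protocol: outputs are allocated as empty proxies, the kernel launch fills each proxy's `tensor` field, and save_for_backward / extract_tf_tensors unwrap proxies into concrete tensors. The `assert self.tensor is not None` fix in utils.py also matters because truth-testing a graph-mode tf.Tensor is not a reliable "filled" check (TensorFlow rejects implicit bool conversion of symbolic tensors). Below is a minimal sketch of the unwrap protocol, with a hypothetical FakeProxy standing in for triton.utils.tf_empty_proxy.

# FakeProxy mimics triton.utils.tf_empty_proxy: it starts out empty and
# is only valid to unwrap after a kernel launch has filled `tensor`.
class FakeProxy:
    def __init__(self):
        self.tensor = None

    def to_tensor(self):
        assert self.tensor is not None  # the launch must have filled it
        return self.tensor

p = FakeProxy()
p.tensor = 'filled-by-kernel-launch'   # normally done inside kernel.__call__
# Same unwrap rule as the new OpContext.save_for_backward:
saved = [x.to_tensor() if isinstance(x, FakeProxy) else x
         for x in ('plain-arg', p)]
assert saved == ['plain-arg', 'filled-by-kernel-launch']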