From 26fd884d963b19830a985b8e33187ef1cb9e8a02 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 19 Feb 2020 11:50:17 -0500 Subject: [PATCH] [PYTHON][OPS][EINSUM] Added support for inner tensor strides --- python/triton/ops/einsum.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 3f49ae009..7e7b2189f 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -70,6 +70,7 @@ class _einsum(triton.function): expr_a, expr_b, expr_c, axes_m, axes_n, axes_k, axes_b, multipleof_a, multipleof_b, multipleof_c, + stride_a_last, stride_b_last, stride_c_last, lut_mode_a, lut_mode_b, delta_a, delta_b, subscripted): @@ -163,7 +164,7 @@ __global__ void {name}( int offa[TM, TK, TB] = """ for i, sym in enumerate(expr_a): ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b) - stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1' + stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}' if i > 0: src += ' + ' src += f"({ccode}) * {stride}\n " @@ -187,7 +188,7 @@ __global__ void {name}( int offb[TK, TN, TB] = """ for i, sym in enumerate(expr_b): ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b) - stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1' + stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}' if i > 0: src += ' + ' src += f"({ccode}) * {stride}\n " @@ -286,7 +287,7 @@ __global__ void {name}( // initialize pointers to C int offc[TM, TN, TB] = """ for i, sym in enumerate(expr_c): - stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1' + stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}' ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b) if i > 0: src += ' + ' @@ -321,7 +322,6 @@ __global__ void {name}( } """ - # print(src) ret = triton.kernel(src) if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: ret.set_constant('AD', delta_a) @@ -503,8 +503,12 @@ __global__ void {name}( stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0]) stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0]) stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0]) + stride_a_last = stride_a[-1] + stride_b_last = stride_b[-1] + stride_c_last = stride_c[-1] name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\ - f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}' + f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\ + f'_{stride_a_last}_{stride_b_last}_{stride_c_last}' # recompile if necessary cache = _einsum.instance.kernel_cache if name not in cache: @@ -513,6 +517,7 @@ __global__ void {name}( sym_a, sym_b, sym_c, axes_m, axes_n, axes_k, axes_b, stride_a_multiple, stride_b_multiple, stride_c_multiple, + stride_a_last, stride_b_last, stride_c_last, lut_mode_a, lut_mode_b, delta_a, delta_b, subscripted)