[PYTHON][OPS][EINSUM] Added support for inner tensor strides

This commit is contained in:
Philippe Tillet
2020-02-19 11:50:17 -05:00
committed by Philippe Tillet
parent 4181f9f2af
commit 26fd884d96

View File

@@ -70,6 +70,7 @@ class _einsum(triton.function):
expr_a, expr_b, expr_c,
axes_m, axes_n, axes_k, axes_b,
multipleof_a, multipleof_b, multipleof_c,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted):
@@ -163,7 +164,7 @@ __global__ void {name}(
int offa[TM, TK, TB] = """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
@@ -187,7 +188,7 @@ __global__ void {name}(
int offb[TK, TN, TB] = """
for i, sym in enumerate(expr_b):
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1'
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
@@ -286,7 +287,7 @@ __global__ void {name}(
// initialize pointers to C
int offc[TM, TN, TB] = """
for i, sym in enumerate(expr_c):
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1'
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
if i > 0:
src += ' + '
@@ -321,7 +322,6 @@ __global__ void {name}(
}
"""
# print(src)
ret = triton.kernel(src)
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)
@@ -503,8 +503,12 @@ __global__ void {name}(
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
stride_a_last = stride_a[-1]
stride_b_last = stride_b[-1]
stride_c_last = stride_c[-1]
name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
# recompile if necessary
cache = _einsum.instance.kernel_cache
if name not in cache:
@@ -513,6 +517,7 @@ __global__ void {name}(
sym_a, sym_b, sym_c,
axes_m, axes_n, axes_k, axes_b,
stride_a_multiple, stride_b_multiple, stride_c_multiple,
stride_a_last, stride_b_last, stride_c_last,
lut_mode_a, lut_mode_b,
delta_a, delta_b,
subscripted)