[PYTHON][OPS][EINSUM] Added support for inner tensor strides
This commit is contained in:
committed by
Philippe Tillet
parent
4181f9f2af
commit
26fd884d96
@@ -70,6 +70,7 @@ class _einsum(triton.function):
|
|||||||
expr_a, expr_b, expr_c,
|
expr_a, expr_b, expr_c,
|
||||||
axes_m, axes_n, axes_k, axes_b,
|
axes_m, axes_n, axes_k, axes_b,
|
||||||
multipleof_a, multipleof_b, multipleof_c,
|
multipleof_a, multipleof_b, multipleof_c,
|
||||||
|
stride_a_last, stride_b_last, stride_c_last,
|
||||||
lut_mode_a, lut_mode_b,
|
lut_mode_a, lut_mode_b,
|
||||||
delta_a, delta_b,
|
delta_a, delta_b,
|
||||||
subscripted):
|
subscripted):
|
||||||
@@ -163,7 +164,7 @@ __global__ void {name}(
|
|||||||
int offa[TM, TK, TB] = """
|
int offa[TM, TK, TB] = """
|
||||||
for i, sym in enumerate(expr_a):
|
for i, sym in enumerate(expr_a):
|
||||||
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
|
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
|
||||||
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
|
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
|
||||||
if i > 0:
|
if i > 0:
|
||||||
src += ' + '
|
src += ' + '
|
||||||
src += f"({ccode}) * {stride}\n "
|
src += f"({ccode}) * {stride}\n "
|
||||||
@@ -187,7 +188,7 @@ __global__ void {name}(
|
|||||||
int offb[TK, TN, TB] = """
|
int offb[TK, TN, TB] = """
|
||||||
for i, sym in enumerate(expr_b):
|
for i, sym in enumerate(expr_b):
|
||||||
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
|
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
|
||||||
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1'
|
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
|
||||||
if i > 0:
|
if i > 0:
|
||||||
src += ' + '
|
src += ' + '
|
||||||
src += f"({ccode}) * {stride}\n "
|
src += f"({ccode}) * {stride}\n "
|
||||||
@@ -286,7 +287,7 @@ __global__ void {name}(
|
|||||||
// initialize pointers to C
|
// initialize pointers to C
|
||||||
int offc[TM, TN, TB] = """
|
int offc[TM, TN, TB] = """
|
||||||
for i, sym in enumerate(expr_c):
|
for i, sym in enumerate(expr_c):
|
||||||
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1'
|
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
|
||||||
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
|
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
|
||||||
if i > 0:
|
if i > 0:
|
||||||
src += ' + '
|
src += ' + '
|
||||||
@@ -321,7 +322,6 @@ __global__ void {name}(
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# print(src)
|
|
||||||
ret = triton.kernel(src)
|
ret = triton.kernel(src)
|
||||||
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
||||||
ret.set_constant('AD', delta_a)
|
ret.set_constant('AD', delta_a)
|
||||||
@@ -503,8 +503,12 @@ __global__ void {name}(
|
|||||||
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
|
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
|
||||||
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
|
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
|
||||||
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
|
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
|
||||||
|
stride_a_last = stride_a[-1]
|
||||||
|
stride_b_last = stride_b[-1]
|
||||||
|
stride_c_last = stride_c[-1]
|
||||||
name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
|
name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
|
||||||
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
|
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
|
||||||
|
f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
|
||||||
# recompile if necessary
|
# recompile if necessary
|
||||||
cache = _einsum.instance.kernel_cache
|
cache = _einsum.instance.kernel_cache
|
||||||
if name not in cache:
|
if name not in cache:
|
||||||
@@ -513,6 +517,7 @@ __global__ void {name}(
|
|||||||
sym_a, sym_b, sym_c,
|
sym_a, sym_b, sym_c,
|
||||||
axes_m, axes_n, axes_k, axes_b,
|
axes_m, axes_n, axes_k, axes_b,
|
||||||
stride_a_multiple, stride_b_multiple, stride_c_multiple,
|
stride_a_multiple, stride_b_multiple, stride_c_multiple,
|
||||||
|
stride_a_last, stride_b_last, stride_c_last,
|
||||||
lut_mode_a, lut_mode_b,
|
lut_mode_a, lut_mode_b,
|
||||||
delta_a, delta_b,
|
delta_a, delta_b,
|
||||||
subscripted)
|
subscripted)
|
||||||
|
Reference in New Issue
Block a user