[PYTHON][OPS][EINSUM] Added support for inner tensor strides
This commit is contained in:
committed by
Philippe Tillet
parent
4181f9f2af
commit
26fd884d96
@@ -70,6 +70,7 @@ class _einsum(triton.function):
|
||||
expr_a, expr_b, expr_c,
|
||||
axes_m, axes_n, axes_k, axes_b,
|
||||
multipleof_a, multipleof_b, multipleof_c,
|
||||
stride_a_last, stride_b_last, stride_c_last,
|
||||
lut_mode_a, lut_mode_b,
|
||||
delta_a, delta_b,
|
||||
subscripted):
|
||||
@@ -163,7 +164,7 @@ __global__ void {name}(
|
||||
int offa[TM, TK, TB] = """
|
||||
for i, sym in enumerate(expr_a):
|
||||
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
|
||||
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
|
||||
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
|
||||
if i > 0:
|
||||
src += ' + '
|
||||
src += f"({ccode}) * {stride}\n "
|
||||
@@ -187,7 +188,7 @@ __global__ void {name}(
|
||||
int offb[TK, TN, TB] = """
|
||||
for i, sym in enumerate(expr_b):
|
||||
ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
|
||||
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1'
|
||||
stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
|
||||
if i > 0:
|
||||
src += ' + '
|
||||
src += f"({ccode}) * {stride}\n "
|
||||
@@ -286,7 +287,7 @@ __global__ void {name}(
|
||||
// initialize pointers to C
|
||||
int offc[TM, TN, TB] = """
|
||||
for i, sym in enumerate(expr_c):
|
||||
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1'
|
||||
stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
|
||||
ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
|
||||
if i > 0:
|
||||
src += ' + '
|
||||
@@ -321,7 +322,6 @@ __global__ void {name}(
|
||||
}
|
||||
"""
|
||||
|
||||
# print(src)
|
||||
ret = triton.kernel(src)
|
||||
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
||||
ret.set_constant('AD', delta_a)
|
||||
@@ -503,8 +503,12 @@ __global__ void {name}(
|
||||
stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
|
||||
stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
|
||||
stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
|
||||
stride_a_last = stride_a[-1]
|
||||
stride_b_last = stride_b[-1]
|
||||
stride_c_last = stride_c[-1]
|
||||
name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
|
||||
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
|
||||
f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\
|
||||
f'_{stride_a_last}_{stride_b_last}_{stride_c_last}'
|
||||
# recompile if necessary
|
||||
cache = _einsum.instance.kernel_cache
|
||||
if name not in cache:
|
||||
@@ -513,6 +517,7 @@ __global__ void {name}(
|
||||
sym_a, sym_b, sym_c,
|
||||
axes_m, axes_n, axes_k, axes_b,
|
||||
stride_a_multiple, stride_b_multiple, stride_c_multiple,
|
||||
stride_a_last, stride_b_last, stride_c_last,
|
||||
lut_mode_a, lut_mode_b,
|
||||
delta_a, delta_b,
|
||||
subscripted)
|
||||
|
Reference in New Issue
Block a user