[PYTHON][EINSUM] Now handling reduction sizes that are not a multiple of TK
This commit is contained in:
committed by
Philippe Tillet
parent
fa4ec7ea65
commit
3816f2f259
@@ -369,7 +369,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
|||||||
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
|
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
|
||||||
Value *result_false = false_values->get_value(idx);
|
Value *result_false = false_values->get_value(idx);
|
||||||
if(result_then->getType()->isVectorTy())
|
if(result_then->getType()->isVectorTy())
|
||||||
result_false = builder_->CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType()));
|
result_false = builder_->CreateVectorSplat(vector_size, result_false);
|
||||||
((PHINode*)current_result)->addIncoming(result_false, current_bb);
|
((PHINode*)current_result)->addIncoming(result_false, current_bb);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@@ -255,6 +255,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||||
return std::unique_ptr<driver::module>();
|
return std::unique_ptr<driver::module>();
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
isel.visit(module, *llvm);
|
isel.visit(module, *llvm);
|
||||||
// return binary
|
// return binary
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||||
|
@@ -5,6 +5,8 @@ import numpy as np
|
|||||||
#import utils
|
#import utils
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
#torch.backends.cudnn.benchmark = True
|
#torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
configs = []
|
configs = []
|
||||||
@@ -31,15 +33,15 @@ MNK = [
|
|||||||
|
|
||||||
# (127008, 768, 576)
|
# (127008, 768, 576)
|
||||||
]
|
]
|
||||||
for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
matmul = lambda a, b: torch.matmul(a, b)
|
# matmul = lambda a, b: torch.matmul(a, b)
|
||||||
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
|
||||||
for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
matmul = lambda a, b: torch.matmul(a.t(), b)
|
# matmul = lambda a, b: torch.matmul(a.t(), b)
|
||||||
configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
|
||||||
for M, N, K in MNK:
|
#for M, N, K in MNK:
|
||||||
matmul = lambda a, b: torch.matmul(a, b.t())
|
# matmul = lambda a, b: torch.matmul(a, b.t())
|
||||||
configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
|
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
|
||||||
|
|
||||||
# Relative attention
|
# Relative attention
|
||||||
NTHSE = [
|
NTHSE = [
|
||||||
@@ -70,16 +72,16 @@ NTHSE = [
|
|||||||
# (128, 1024, 8, 256, 256),
|
# (128, 1024, 8, 256, 256),
|
||||||
#(128, 1024, 8, 256, 512)
|
#(128, 1024, 8, 256, 512)
|
||||||
]
|
]
|
||||||
for N, T, H, S, E in NTHSE:
|
#for N, T, H, S, E in NTHSE:
|
||||||
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
|
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
|
||||||
for N, T, H, S, E in NTHSE:
|
#for N, T, H, S, E in NTHSE:
|
||||||
configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
|
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
|
||||||
for N, T, H, S, E in NTHSE:
|
#for N, T, H, S, E in NTHSE:
|
||||||
configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
|
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
|
||||||
|
|
||||||
# 1D Dense convolution
|
# 1D Dense convolution
|
||||||
NCHKR = [
|
NCHKR = [
|
||||||
(1, 1152, 12602, 512, 3)
|
#(1, 1152, 12602, 512, 3)
|
||||||
]
|
]
|
||||||
for N, C, H, K, R in NCHKR:
|
for N, C, H, K, R in NCHKR:
|
||||||
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
|
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
|
||||||
@@ -93,17 +95,17 @@ for N, C, H, K, R in NCHKR:
|
|||||||
# 2D Dense convolution
|
# 2D Dense convolution
|
||||||
NCHWKRS = [
|
NCHWKRS = [
|
||||||
#(8, 64, 128, 128, 768, 3, 3),
|
#(8, 64, 128, 128, 768, 3, 3),
|
||||||
(8, 128, 64, 64, 256, 3, 3),
|
(128, 3, 32, 32, 64, 3, 3),
|
||||||
(8, 256, 32, 32, 512, 3, 3),
|
#(8, 256, 32, 32, 512, 3, 3),
|
||||||
#(8, 512, 32, 32, 1024, 3, 3)
|
#(8, 512, 32, 32, 1024, 3, 3)
|
||||||
]
|
]
|
||||||
for N, C, H, W, K, R, S in NCHWKRS:
|
for N, C, H, W, K, R, S in NCHWKRS:
|
||||||
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
|
||||||
configs += [([N, C, H, W],
|
configs += [([N, C, H, W],
|
||||||
[C, R, S, K],
|
[K, C, R, S],
|
||||||
[N, K, H - R + 1, W - R + 1],
|
[N, K, H - R + 1, W - R + 1],
|
||||||
torch_fn,
|
torch_fn,
|
||||||
'nc(h+r)(w+s),crsk->nkhw',
|
'nc(h+r)(w+s),kcrs->nkhw',
|
||||||
dict())]
|
dict())]
|
||||||
|
|
||||||
# 3D Dense Convolution
|
# 3D Dense Convolution
|
||||||
@@ -173,6 +175,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
|||||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||||
# triton output
|
# triton output
|
||||||
|
print(a.size(), b.size())
|
||||||
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
|
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
|
||||||
# reference output
|
# reference output
|
||||||
if torch_fn:
|
if torch_fn:
|
||||||
@@ -182,7 +185,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
|||||||
# performance relative to equivalent matrix multiplication
|
# performance relative to equivalent matrix multiplication
|
||||||
ctx = triton.ctx_registry[tc]
|
ctx = triton.ctx_registry[tc]
|
||||||
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
||||||
cmp_eqbmm = True
|
cmp_eqbmm = False
|
||||||
if cmp_eqbmm:
|
if cmp_eqbmm:
|
||||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||||
|
@@ -210,7 +210,7 @@ class kernel:
|
|||||||
macros.append((k, values))
|
macros.append((k, values))
|
||||||
opt = libtriton.options_space()
|
opt = libtriton.options_space()
|
||||||
opt.defines = macros
|
opt.defines = macros
|
||||||
opt.num_warps = [2, 4, 8]
|
opt.num_warps = num_warps
|
||||||
# create unique id for this op
|
# create unique id for this op
|
||||||
op_id = libtriton.make_op_id()
|
op_id = libtriton.make_op_id()
|
||||||
self.fw_id[key] = op_id
|
self.fw_id[key] = op_id
|
||||||
|
@@ -110,13 +110,16 @@ __global__ void {name}(
|
|||||||
src += "\n "
|
src += "\n "
|
||||||
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
|
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
|
||||||
src += f", int stride_a_inner __multipleof({multipleof_a})"
|
src += f", int stride_a_inner __multipleof({multipleof_a})"
|
||||||
|
src += f", int rem_delta_a __multipleof({multipleof_a})"
|
||||||
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
|
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
|
||||||
src += ", int* AD __noalias __readonly __aligned(16)"
|
src += ", int* AD __noalias __readonly __aligned(16)"
|
||||||
src += "\n "
|
src += "\n "
|
||||||
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
|
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
|
||||||
src += f", int stride_b_inner __multipleof({multipleof_b})"
|
src += f", int stride_b_inner __multipleof({multipleof_b})"
|
||||||
|
src += f", int rem_delta_b __multipleof({multipleof_b})"
|
||||||
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
|
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
|
||||||
src += ", int* BD"
|
src += ", int* BD"
|
||||||
|
src += "\n"
|
||||||
for ptr in subscripted:
|
for ptr in subscripted:
|
||||||
src += f", int* {ptr}"
|
src += f", int* {ptr}"
|
||||||
src += """) {
|
src += """) {
|
||||||
@@ -142,6 +145,7 @@ __global__ void {name}(
|
|||||||
int off_k = pid_z * div_z;
|
int off_k = pid_z * div_z;
|
||||||
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
|
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
|
||||||
#endif
|
#endif
|
||||||
|
int rem_k = matmul_k % TK;
|
||||||
|
|
||||||
// create ranges
|
// create ranges
|
||||||
"""
|
"""
|
||||||
@@ -204,13 +208,33 @@ __global__ void {name}(
|
|||||||
src += f"""
|
src += f"""
|
||||||
|
|
||||||
// prefetch
|
// prefetch
|
||||||
|
int prefetch_k = select(rem_k > 0, rem_k, TK);
|
||||||
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
|
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
|
||||||
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
|
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
|
||||||
bool checkk[TK] = {rk} < matmul_k + off_k;
|
bool checkk[TK] = {rk} < prefetch_k;
|
||||||
bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
|
bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
|
||||||
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
|
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
|
||||||
TYPE a[TM, TK, TB] = checka ? *pa : 0;
|
TYPE a[TM, TK, TB] = checka ? *pa : 0;
|
||||||
TYPE b[TK, TN, TB] = checkb ? *pb : 0;
|
TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
|
||||||
|
|
||||||
|
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
|
||||||
|
src += """
|
||||||
|
pa += rem_delta_a;"""
|
||||||
|
else:
|
||||||
|
src += """
|
||||||
|
pa += incda;
|
||||||
|
padelta += TK;
|
||||||
|
incda = (*padelta)[newaxis, :, newaxis];"""
|
||||||
|
|
||||||
|
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
|
||||||
|
src += """
|
||||||
|
pb += rem_delta_b;"""
|
||||||
|
else:
|
||||||
|
src += """
|
||||||
|
pb += (*pbdelta)[:, newaxis, newaxis];
|
||||||
|
pbdelta += TK;"""
|
||||||
|
|
||||||
|
src += f"""
|
||||||
// accumulate
|
// accumulate
|
||||||
float acc[TM, TN, TB] = 0;
|
float acc[TM, TN, TB] = 0;
|
||||||
for(int k = matmul_k; k > 0; k -= TK) {{
|
for(int k = matmul_k; k > 0; k -= TK) {{
|
||||||
@@ -219,16 +243,13 @@ __global__ void {name}(
|
|||||||
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
|
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
|
||||||
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
|
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
|
||||||
#endif
|
#endif
|
||||||
"""
|
|
||||||
|
|
||||||
if not use_lut_a or not use_lut_b:
|
checkk = k > TK;
|
||||||
src += f"""
|
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
|
||||||
{rk} += TK;
|
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
|
||||||
"""
|
a = *?(checka)pa;
|
||||||
src += _einsum.unpack_cc(tile, axes_k, 'r', True)
|
b = *?(checkb)pb;"""
|
||||||
|
|
||||||
|
|
||||||
if use_lut_a:
|
|
||||||
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
|
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
|
||||||
src += """
|
src += """
|
||||||
pa += stride_a_inner;"""
|
pa += stride_a_inner;"""
|
||||||
@@ -237,18 +258,6 @@ __global__ void {name}(
|
|||||||
pa += incda;
|
pa += incda;
|
||||||
padelta += TK;
|
padelta += TK;
|
||||||
incda = (*padelta)[newaxis, :, newaxis];"""
|
incda = (*padelta)[newaxis, :, newaxis];"""
|
||||||
else:
|
|
||||||
src += """
|
|
||||||
offa = """
|
|
||||||
for i, sym in enumerate(expr_a):
|
|
||||||
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
|
|
||||||
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
|
|
||||||
if i > 0:
|
|
||||||
src += ' + '
|
|
||||||
src += f"({ccode}) * {stride}\n "
|
|
||||||
src += """;
|
|
||||||
TYPE *pa[TM, TK, TB] = A + offa;"""
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
|
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
|
||||||
@@ -260,11 +269,6 @@ __global__ void {name}(
|
|||||||
pbdelta += TK;"""
|
pbdelta += TK;"""
|
||||||
|
|
||||||
src += f"""
|
src += f"""
|
||||||
checkk = k > TK;
|
|
||||||
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
|
|
||||||
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
|
|
||||||
a = *?(checka)pa;
|
|
||||||
b = *?(checkb)pb;
|
|
||||||
}}
|
}}
|
||||||
TYPE c[TM, TN, TB] = acc;
|
TYPE c[TM, TN, TB] = acc;
|
||||||
|
|
||||||
@@ -367,16 +371,27 @@ __global__ void {name}(
|
|||||||
fn = sp.lambdify(args, delta, 'numpy')
|
fn = sp.lambdify(args, delta, 'numpy')
|
||||||
# inner axes values
|
# inner axes values
|
||||||
inner = [dims[d] for d in axes]
|
inner = [dims[d] for d in axes]
|
||||||
k = np.arange(np.prod(inner), dtype=np.int32)
|
inner = np.prod(inner)
|
||||||
|
rem = inner % step
|
||||||
|
rem = rem if rem > 0 else step
|
||||||
|
# k = [0, 1, ..., step,
|
||||||
|
# rem, rem + 1, ... rem + inner]
|
||||||
|
k = np.concatenate((np.arange(step),
|
||||||
|
np.arange(rem, inner))).astype(np.int32)
|
||||||
|
# nextk = [rem, 1 + rem, ..., step + rem,
|
||||||
|
# rem + step, rem + 1 + step, ..., inner + step]
|
||||||
|
nextk = np.concatenate((k[:step] + rem,
|
||||||
|
k[step:] + step))
|
||||||
|
# offsets
|
||||||
off = _einsum.unpack_offset(k, axes, dims)
|
off = _einsum.unpack_offset(k, axes, dims)
|
||||||
nextoff = _einsum.unpack_offset(k + step, axes, dims)
|
nextoff = _einsum.unpack_offset(nextk, axes, dims)
|
||||||
# evaluate deltas
|
# evaluate deltas
|
||||||
args = [s for s in stride]
|
args = [s for s in stride]
|
||||||
args += [off[sk] for sk in axes]
|
args += [off[sk] for sk in axes]
|
||||||
args += [nextoff[sk] for sk in axes]
|
args += [nextoff[sk] for sk in axes]
|
||||||
args += [x for _, x in arrays]
|
args += [x for _, x in arrays]
|
||||||
delta = fn(*args)
|
delta = fn(*args)
|
||||||
return delta, _einsum.lut_mode(delta[:-step])
|
return delta, _einsum.lut_mode(delta[step:-step])
|
||||||
|
|
||||||
############################
|
############################
|
||||||
## Einsum parsing
|
## Einsum parsing
|
||||||
@@ -525,12 +540,17 @@ __global__ void {name}(
|
|||||||
alpha, M, N, K, div_m] +\
|
alpha, M, N, K, div_m] +\
|
||||||
dim_m + dim_n + dim_k + dim_b +\
|
dim_m + dim_n + dim_k + dim_b +\
|
||||||
stride_a + stride_b + stride_c
|
stride_a + stride_b + stride_c
|
||||||
if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
|
# LUT for A
|
||||||
delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
|
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
|
||||||
self.args += [delta_a]
|
self.args += [delta_a[TK], delta_a[0]]
|
||||||
if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
|
if lut_mode_a == _einsum.LUT_MODE.DRAM:
|
||||||
delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
|
self.args += [torch.from_numpy(delta_a).cuda()]
|
||||||
self.args += [delta_b]
|
# LUT for B
|
||||||
|
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
|
||||||
|
self.args += [delta_b[TK], delta_b[0]]
|
||||||
|
if lut_mode_b == _einsum.LUT_MODE.DRAM:
|
||||||
|
self.args += [torch.from_numpy(delta_b).cuda()]
|
||||||
|
# Einsum dependents
|
||||||
self.args += arrays
|
self.args += arrays
|
||||||
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
|
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
|
||||||
triton.cdiv(N, opt.d('TN')),
|
triton.cdiv(N, opt.d('TN')),
|
||||||
@@ -551,6 +571,7 @@ __global__ void {name}(
|
|||||||
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
||||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||||
|
TM, TN, TB, TZ = 64, 64, 1, 1
|
||||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||||
if mask:
|
if mask:
|
||||||
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
||||||
@@ -613,6 +634,7 @@ __global__ void {name}(
|
|||||||
ctx.matmul_K = instance.matmul_K
|
ctx.matmul_K = instance.matmul_K
|
||||||
ctx.bench = bench
|
ctx.bench = bench
|
||||||
ctx.forward_ms = speed
|
ctx.forward_ms = speed
|
||||||
|
ctx.mask = mask
|
||||||
ctx.save_for_backward(a, b)
|
ctx.save_for_backward(a, b)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
@@ -621,7 +643,7 @@ __global__ void {name}(
|
|||||||
############################
|
############################
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
|
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
|
||||||
for i, expr in enumerate(sym_x):
|
for i, expr in enumerate(sym_x):
|
||||||
if expr.is_symbol:
|
if expr.is_symbol:
|
||||||
continue
|
continue
|
||||||
@@ -652,9 +674,9 @@ __global__ void {name}(
|
|||||||
expr_a = _einsum.sym_to_expr(sym_a)
|
expr_a = _einsum.sym_to_expr(sym_a)
|
||||||
expr_b = _einsum.sym_to_expr(sym_b)
|
expr_b = _einsum.sym_to_expr(sym_b)
|
||||||
expr_c = _einsum.sym_to_expr(sym_c)
|
expr_c = _einsum.sym_to_expr(sym_c)
|
||||||
expr = f'{expr_c},{expr_b}->{expr_a}'
|
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
|
||||||
da = einsum(expr, dy, b, a.shape, False)
|
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
|
||||||
return None, da, None, None, None
|
return None, da, db, None, None
|
||||||
|
|
||||||
|
|
||||||
einsum = _einsum.apply
|
einsum = _einsum.apply
|
Reference in New Issue
Block a user