more work
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .dot import dot
|
||||
from .einsum import einsum
|
||||
from .dot import _dot, dot
|
||||
from .einsum import _einsum, einsum
|
||||
|
@@ -3,7 +3,7 @@ import triton
|
||||
class _einsum(triton.function):
|
||||
|
||||
src = """
|
||||
void einsum(TYPE * A, TYPE * B, TYPE * C,
|
||||
void einsum_(TYPE * A, TYPE * B, TYPE * C,
|
||||
int dim_M, int dim_N, int dim_K,
|
||||
int std_A0, int std_B0, int std_C0,
|
||||
int std_A1, int std_B1, int std_C1) {
|
||||
@@ -30,7 +30,7 @@ class _einsum(triton.function):
|
||||
for(int k = dim_K; k > 0; k -= TK) {
|
||||
TYPE a[SHAPE_A] = *pa;
|
||||
TYPE b[SHAPE_B] = *pb;
|
||||
c += a @ b;
|
||||
c += USE_A @ USE_B;
|
||||
pa += TK;
|
||||
pb += TK;
|
||||
}
|
||||
@@ -157,6 +157,7 @@ class _einsum(triton.function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, subscripts, a, b):
|
||||
ctx.save_for_backward(a, b)
|
||||
if type(subscripts) is str:
|
||||
einsum_a, einsum_bc = subscripts.split(",")
|
||||
einsum_b, einsum_c = einsum_bc.split("->")
|
||||
@@ -165,8 +166,41 @@ class _einsum(triton.function):
|
||||
|
||||
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
|
||||
einsum_a, einsum_b, einsum_c,
|
||||
a.shape, b.shape
|
||||
a.shape.as_list(), b.shape.as_list()
|
||||
)
|
||||
ctx.trans_a = ta
|
||||
ctx.trans_b = tb
|
||||
ctx.einsum_a = einsum_a
|
||||
ctx.einsum_b = einsum_b
|
||||
ctx.einsum_c = einsum_c
|
||||
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dc):
|
||||
a, b = ctx.saved_tensors
|
||||
trans_a = ctx.trans_a
|
||||
trans_b = ctx.trans_b
|
||||
einsum_a = ctx.einsum_a
|
||||
einsum_b = ctx.einsum_b
|
||||
einsum_c = ctx.einsum_c
|
||||
|
||||
if not trans_a and not trans_b: # NN
|
||||
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
|
||||
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
|
||||
|
||||
elif not trans_a and trans_b: # NT
|
||||
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
|
||||
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
|
||||
|
||||
elif trans_a and not trans_b: # TN
|
||||
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
|
||||
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
|
||||
|
||||
elif trans_a and trans_b: # TT (not used)
|
||||
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
|
||||
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
|
||||
|
||||
return da, db, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
einsum = _einsum.apply
|
Reference in New Issue
Block a user