more work

This commit is contained in:
Philippe Tillet
2019-10-26 15:10:19 -04:00
parent 76adcb755a
commit 655f43fb5b
5 changed files with 46 additions and 8 deletions

View File

@@ -1,2 +1,2 @@
from .dot import dot
from .einsum import einsum
from .dot import _dot, dot
from .einsum import _einsum, einsum

View File

@@ -3,7 +3,7 @@ import triton
class _einsum(triton.function):
src = """
void einsum(TYPE * A, TYPE * B, TYPE * C,
void einsum_(TYPE * A, TYPE * B, TYPE * C,
int dim_M, int dim_N, int dim_K,
int std_A0, int std_B0, int std_C0,
int std_A1, int std_B1, int std_C1) {
@@ -30,7 +30,7 @@ class _einsum(triton.function):
for(int k = dim_K; k > 0; k -= TK) {
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
c += a @ b;
c += USE_A @ USE_B;
pa += TK;
pb += TK;
}
@@ -157,6 +157,7 @@ class _einsum(triton.function):
@staticmethod
def forward(ctx, subscripts, a, b):
ctx.save_for_backward(a, b)
if type(subscripts) is str:
einsum_a, einsum_bc = subscripts.split(",")
einsum_b, einsum_c = einsum_bc.split("->")
@@ -165,8 +166,41 @@ class _einsum(triton.function):
shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
einsum_a, einsum_b, einsum_c,
a.shape, b.shape
a.shape.as_list(), b.shape.as_list()
)
ctx.trans_a = ta
ctx.trans_b = tb
ctx.einsum_a = einsum_a
ctx.einsum_b = einsum_b
ctx.einsum_c = einsum_c
return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c)
@staticmethod
def backward(ctx, dc):
a, b = ctx.saved_tensors
trans_a = ctx.trans_a
trans_b = ctx.trans_b
einsum_a = ctx.einsum_a
einsum_b = ctx.einsum_b
einsum_c = ctx.einsum_c
if not trans_a and not trans_b: # NN
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
elif not trans_a and trans_b: # NT
da = einsum((einsum_c, einsum_b, einsum_a), dc, b)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
elif trans_a and not trans_b: # TN
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
db = einsum((einsum_a, einsum_c, einsum_b), a, dc)
elif trans_a and trans_b: # TT (not used)
da = einsum((einsum_b, einsum_c, einsum_a), b, dc)
db = einsum((einsum_c, einsum_a, einsum_b), dc, a)
return da, db, None, None, None, None, None, None, None, None, None, None
einsum = _einsum.apply