[python] [ops] added skeleton for einsum op

This commit is contained in:
Philippe Tillet
2019-10-21 18:58:02 -04:00
parent 4b0c43bb7b
commit 099918b3c0
3 changed files with 128 additions and 1 deletions

View File

@@ -107,7 +107,7 @@ qk = np.einsum("bhak,hank->bhan", q, kw)
tq = torch.from_numpy(q).contiguous().cuda()
tkw = torch.from_numpy(kw).contiguous().cuda()
tqk = dot(tq, tkw)
tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
diff = qk - tqk.cpu().numpy()
print(np.max(diff))