[python] [ops] added skeleton for einsum op
This commit is contained in:
@@ -107,7 +107,7 @@ qk = np.einsum("bhak,hank->bhan", q, kw)
|
||||
|
||||
tq = torch.from_numpy(q).contiguous().cuda()
|
||||
tkw = torch.from_numpy(kw).contiguous().cuda()
|
||||
tqk = dot(tq, tkw)
|
||||
tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
|
||||
diff = qk - tqk.cpu().numpy()
|
||||
print(np.max(diff))
|
||||
|
||||
|
Reference in New Issue
Block a user