[python][examples] cleaned up dot example

Philippe Tillet
2019-09-05 12:54:35 -04:00
parent 2d6c8311e8
commit 9ab2880fba


@@ -28,23 +28,33 @@ def run_tf():
     print("dif: %f" % np.max(dif))
 
 def run_torch():
-    import torch as th
-    th.manual_seed(0)
+    import torch
+    torch.manual_seed(0)
     M, N, K = 128, 128, 128
-    a = th.randn(M, K).cuda()
-    b = th.randn(K, N).cuda()
+    a = torch.randn(M, K).cuda()
+    b = torch.randn(K, N).cuda()
     a.requires_grad_(True)
     b.requires_grad_(True)
-    #th_c = th.matmul(a, th.t(b))
-    #th_d = th.matmul(th.t(th_c), b)
-    tr_c = triton.ops.dot(a, b, False, True)
-    #tr_d = triton.ops.dot(tr_c, b, True, False)
-    y = th.sum(tr_c)
-    #print('backprop', y)
-    y.backward()
-    #print('backward done')
-    print(b.grad)
-    #th_d.backward()
-    #print(a.grad)
+    torch_c = torch.matmul(a, torch.t(b))
+    torch_d = torch.matmul(torch.t(torch_c), b)
+    torch_y = torch.mean(torch_d)
+    triton_c = triton.ops.dot(a, b, False, True)
+    triton_d = triton.ops.dot(triton_c, b, True, False)
+    triton_y = torch.mean(triton_d)
+    # torch gradient
+    torch_y.backward()
+    torch_da = a.grad.clone()
+    torch_db = b.grad.clone()
+    # triton gradient
+    a.grad.zero_()
+    b.grad.zero_()
+    triton_y.backward()
+    triton_da = a.grad.clone()
+    triton_db = b.grad.clone()
+    print('Diff DA:', (torch_da - triton_da).max())
+    print('Diff DB:', (torch_db - triton_db).max())
 
 run_torch()
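
A minimal standalone sketch of the same check, for context. It assumes the
triton.ops.dot(a, b, transpose_a, transpose_b) signature that appears in the
diff, where the two booleans appear to select whether each operand is
transposed, so dot(a, b, False, True) is expected to match
torch.matmul(a, torch.t(b)):

    import torch
    import triton

    torch.manual_seed(0)
    M, N, K = 128, 128, 128
    a = torch.randn(M, K, device='cuda', requires_grad=True)
    b = torch.randn(K, N, device='cuda', requires_grad=True)

    # Reference: C = A * B^T (the shapes line up because M = N = K here).
    torch_c = torch.matmul(a, torch.t(b))
    # Same product through the Triton kernel, with b flagged as transposed
    # (assumed (a, b, transpose_a, transpose_b) argument order, as in the diff).
    triton_c = triton.ops.dot(a, b, False, True)
    print('forward diff:', (torch_c - triton_c).abs().max().item())

Comparing with .abs().max() keeps a large negative difference from being
reported as a small number; the committed example prints the signed maximum.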