[python] upgraded pybind11; forcing torch tensors to be contiguous()

Philippe Tillet
2019-09-05 12:30:51 -04:00
parent 58544d0523
commit 2d6c8311e8
23 changed files with 960 additions and 531 deletions
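
The title notes that torch tensors are now forced to be contiguous() before their storage is handed to the Triton kernels. A minimal sketch of that pattern, not taken from this commit (call_kernel is a hypothetical stand-in for the real binding):

import torch as th

def call_kernel(a, b):
    # .contiguous() copies only when the tensor is not already densely packed
    # in row-major order, so the kernel can assume unit stride on the last axis.
    a = a.contiguous()
    b = b.contiguous()
    # the raw pointers handed to a C++/CUDA extension are now safe to index
    # with plain row-major arithmetic
    return a.data_ptr(), b.data_ptr()

a = th.randn(128, 128)
b = th.randn(128, 128).t()                     # transposed view: non-contiguous strides
print(a.is_contiguous(), b.is_contiguous())    # True False
call_kernel(a, b)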


@@ -29,11 +29,22 @@ def run_tf():
def run_torch():
    import torch as th
    th.manual_seed(0)
    M, N, K = 128, 128, 128
    a = th.randn(M, K).cuda()
    b = th.randn(K, N).cuda()
    th_c = th.matmul(a, b)
    tr_c = triton.ops.dot(a, b)
    print(tr_c)
    b.requires_grad_(True)
    #th_c = th.matmul(a, th.t(b))
    #th_d = th.matmul(th.t(th_c), b)
    tr_c = triton.ops.dot(a, b, False, True)
    #tr_d = triton.ops.dot(tr_c, b, True, False)
    y = th.sum(tr_c)
    #print('backprop', y)
    y.backward()
    #print('backward done')
    print(b.grad)
    #th_d.backward()
    #print(a.grad)
run_torch()
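
The updated test drives triton.ops.dot with positional transpose flags and backpropagates through the result. A hedged sketch of the same check written against a plain torch reference, assuming the flag semantics suggested by the commented-out th.matmul(a, th.t(b)) line:

import torch as th
import triton

th.manual_seed(0)
M, N, K = 128, 128, 128
a = th.randn(M, K).cuda()
b = th.randn(K, N).cuda()
b.requires_grad_(True)

# Triton path: the flags are assumed to mean (trans_a=False, trans_b=True)
tr_c = triton.ops.dot(a, b, False, True)
th.sum(tr_c).backward()
tr_grad = b.grad.clone()

# Torch reference path for the same product and gradient
b.grad = None
th_c = th.matmul(a, th.t(b))
th.sum(th_c).backward()
print(th.max(th.abs(tr_grad - b.grad)))   # expected ~0 up to float tolerance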