[python] Upgraded pybind11; force torch tensors to be contiguous()
This commit is contained in:
@@ -29,11 +29,22 @@ def run_tf():
|
||||
|
||||
def run_torch():
    """GPU smoke test comparing triton.ops.dot against torch matmul.

    Seeds torch, builds two 128x128 CUDA tensors, runs a forward
    triton dot and prints it, then backprops through a second triton
    dot and prints the gradient of ``b``.

    NOTE(review): requires a CUDA device and the ``triton`` package
    (imported at module level, outside this view).
    """
    import torch as th

    th.manual_seed(0)  # deterministic inputs across runs

    # Square problem size for the matmul smoke test.
    rows, cols, inner = 128, 128, 128
    a = th.randn(rows, inner).cuda()
    b = th.randn(inner, cols).cuda()

    # Torch reference result; computed but never compared here —
    # presumably kept for eyeballing against the printed triton output.
    th_c = th.matmul(a, b)

    # Forward pass through triton's dot kernel.
    tr_c = triton.ops.dot(a, b)
    print(tr_c)

    # Second pass with gradients enabled on b. The extra (False, True)
    # arguments presumably are transpose flags for a / b — confirm
    # against the triton.ops.dot signature.
    b.requires_grad_(True)
    tr_c = triton.ops.dot(a, b, False, True)

    # Reduce to a scalar so backward() needs no explicit grad argument.
    y = th.sum(tr_c)
    y.backward()
    print(b.grad)
|
||||
|
||||
|
||||
run_torch()  # module-level entry point: runs the GPU smoke test on import/execution
|
Reference in New Issue
Block a user