[PYTHON][OPS] Bugfix in conv fprop

This commit is contained in:
Philippe Tillet
2019-11-01 00:43:02 -04:00
parent 739a8d9061
commit f4bbbbe5e4
6 changed files with 65 additions and 85 deletions

View File

@@ -1,11 +1,16 @@
import torch
import triton
N, C, K = 32, 32, 32
H, W = 32, 32
N, C, K = 32, 8, 32
H, W = 4, 4
R, S = 3, 3
torch.manual_seed(0)
a = torch.randn(N, C, H, W).cuda()
b = torch.randn(C, R, S, K).cuda()
#c = torch.nn.functional.conv2d(a, b)
c = triton.ops.conv(a, b)
print(c)
b = torch.ones(C, R, S, K).cuda()
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
tc = triton.ops.conv(a, b)
print((rc - tc).abs().max())
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
#print(tc[31, 31,:,:])