[PYTHON][OPS] Convolution: Some cleaning of Triton-C kernel
This commit is contained in:
@@ -2,7 +2,7 @@ import torch
|
||||
import triton
|
||||
|
||||
N, C, K = 32, 8, 32
|
||||
H, W = 4, 4
|
||||
H, W = 16, 16
|
||||
R, S = 3, 3
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn(N, C, H, W).cuda()
|
||||
@@ -11,6 +11,5 @@ b = torch.ones(C, R, S, K).cuda()
|
||||
rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
|
||||
tc = triton.ops.conv(a, b)
|
||||
print((rc - tc).abs().max())
|
||||
print((tc[:,:,0,0] - rc[:,:,0,0]).abs())
|
||||
#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max())
|
||||
#print(tc[31, 31,:,:])
|
Reference in New Issue
Block a user