# Gist snapshot: 11 lines, 217 B, Python.
# Smoke test for Triton's ``triton.ops.conv``: run it once on random CUDA
# tensors and print the result. Needs a CUDA device and a Triton version
# that provides ``triton.ops.conv`` (assumed -- confirm against the
# installed Triton).
import torch
import triton

# Problem size: batch N, input channels C, output channels K,
# spatial size H x W, filter size R x S.
N, C, K = 32, 32, 32
H, W = 32, 32
R, S = 3, 3

# Fixed seed so repeated runs see identical inputs, which makes outputs
# comparable across runs / Triton versions.
torch.manual_seed(0)

# Allocate directly on the GPU rather than building on the CPU and copying
# over with .cuda().
a = torch.randn(N, C, H, W, device="cuda")  # input, shaped (N, C, H, W)
b = torch.randn(C, R, S, K, device="cuda")  # filters, shaped (C, R, S, K)

# Reference check: torch.nn.functional.conv2d expects weights shaped
# (K, C, R, S), so calling conv2d(a, b) with this b is rejected on shape.
# If triton.ops.conv interprets b as (C, R, S, K) -- TODO confirm -- the
# matching reference is torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)).
c = triton.ops.conv(a, b)

print(c)