Files
triton/python/examples/conv.py

11 lines
217 B
Python
Raw Normal View History

2019-10-31 18:08:27 -04:00
import torch
import triton
# Example: run triton.ops.conv on random CUDA tensors and print the result.
#
# Sizes used below: `a` is laid out as (N, C, H, W) and `b` as (C, R, S, K).
# C is the dimension the two tensors share (presumably input channels);
# K is presumably the number of output channels and R x S the filter window.
N, C, K = 32, 32, 32
H, W = 32, 32
R, S = 3, 3

# Inputs are sampled on the CPU and then copied to the GPU.
a = torch.randn(N, C, H, W).cuda()
# NOTE(review): (C, R, S, K) is presumably the filter layout triton.ops.conv
# expects -- confirm against the op's definition.
b = torch.randn(C, R, S, K).cuda()

# NOTE(review): a PyTorch reference via torch.nn.functional.conv2d cannot take
# `b` as-is. conv2d wants its weight as (out_channels, in_channels/groups,
# kH, kW), so `b` would first need reordering, e.g. b.permute(3, 0, 1, 2).
# The padding/stride used by triton.ops.conv is not visible here; confirm it
# before comparing the two outputs.
c = triton.ops.conv(a, b)
print(c)