'''LeNet in PyTorch.'''
import torch.nn as nn
import torch.nn.functional as F
import triton


class LeNet(nn.Module):
    '''LeNet-style classifier: two conv/ReLU/max-pool stages, then three linear layers.

    The first convolution is a regular ``nn.Conv2d``; the second one is
    ``triton.Conv2d`` from the ``triton`` module imported by this file.

    ``fc1`` expects ``512 * 7 * 7`` input features, so the activation must be
    7x7 per channel after the second pooling step. With a 32x32 input that is
    32 -> 30 (3x3 conv, no padding) -> 15 (pool) -> 15 (1x1 conv) -> 7 (pool),
    assuming ``triton.Conv2d`` with kernel size 1 keeps the spatial size like
    ``nn.Conv2d`` does -- TODO confirm against the triton module.

    Args:
        num_classes: Number of outputs of the last linear layer (default 10).
        in_channels: Number of channels of the input images (default 3).
    '''

    def __init__(self, num_classes=10, in_channels=3):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, 3)
        # NOTE(review): presumably a drop-in replacement for nn.Conv2d taking
        # (in_channels, out_channels, kernel_size) -- verify in the triton module.
        self.conv2 = triton.Conv2d(512, 512, 1)
        self.fc1 = nn.Linear(512 * 7 * 7, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        '''Run the network on a batch and return the un-normalized class scores.

        Args:
            x: Input batch; the first dimension is treated as the batch
                dimension and is preserved when flattening.

        Returns:
            The output of ``fc3`` (no softmax is applied), one row per sample.
        '''
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # reshape instead of view: view raises on activations whose strides do
        # not allow merging the trailing dimensions (e.g. channels_last inputs),
        # while reshape falls back to a copy and yields the same logical order.
        out = out.reshape(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out