[dnn] Adding batchnorm
This commit is contained in:
@@ -5,6 +5,6 @@ if(${TORCH_FOUND})
|
||||
include_directories("${CUDA_HOME}/include")
|
||||
link_directories(${TORCH_LIBRARY_DIRS})
|
||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||
add_library(torch_triton SHARED conv.cpp)
|
||||
add_library(torch_triton SHARED conv.cpp shift.cpp)
|
||||
target_link_libraries(torch_triton torch triton)
|
||||
endif()
|
||||
|
@@ -1,75 +0,0 @@
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include "triton/driver/device.h"
|
||||
#include <algorithm>
|
||||
|
||||
class timer{
|
||||
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
||||
typedef std::chrono::nanoseconds nanoseconds;
|
||||
|
||||
public:
|
||||
explicit timer(bool run = false)
|
||||
{ if (run) start(); }
|
||||
|
||||
void start()
|
||||
{ _start = high_resolution_clock::now(); }
|
||||
|
||||
nanoseconds get() const
|
||||
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
|
||||
|
||||
private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
T min(std::vector<T> x)
|
||||
{ return *std::min_element(x.begin(), x.end()); }
|
||||
|
||||
|
||||
template<class OP, class SYNC>
|
||||
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
|
||||
{
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
op();
|
||||
sync();
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
tmr.start();
|
||||
op();
|
||||
sync();
|
||||
times.push_back(norm*tmr.get().count());
|
||||
total_time+=times.back();
|
||||
}
|
||||
return min(times);
|
||||
}
|
||||
|
||||
// helper function to print a tuple of any size
|
||||
template<class Tuple, std::size_t N>
|
||||
struct TuplePrinter {
|
||||
static void print(const Tuple& t)
|
||||
{
|
||||
TuplePrinter<Tuple, N-1>::print(t);
|
||||
std::cout << ", " << std::get<N-1>(t);
|
||||
}
|
||||
};
|
||||
|
||||
template<class Tuple>
|
||||
struct TuplePrinter<Tuple, 1> {
|
||||
static void print(const Tuple& t)
|
||||
{
|
||||
std::cout << std::get<0>(t);
|
||||
}
|
||||
};
|
||||
|
||||
template<class... Args>
|
||||
void print(const std::tuple<Args...>& t)
|
||||
{
|
||||
std::cout << "(";
|
||||
TuplePrinter<decltype(t), sizeof...(Args)>::print(t);
|
||||
std::cout << ")\n";
|
||||
}
|
||||
|
||||
|
@@ -56,7 +56,7 @@ torch::Tensor conv_common(
|
||||
stride_d, stride_h, stride_w,
|
||||
pad_d, pad_h, pad_w,
|
||||
1, 1, 1,
|
||||
ty, has_bias)).first->second.get();
|
||||
"fp32", "fp32", ty, has_bias)).first->second.get();
|
||||
else
|
||||
configuration = m_config.at(key).get();
|
||||
|
||||
|
@@ -1,145 +0,0 @@
|
||||
'''Train CIFAR10 with PyTorch.'''
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from models import *
|
||||
from utils import progress_bar
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
|
||||
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
||||
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
|
||||
args = parser.parse_args()
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
best_acc = 0 # best test accuracy
|
||||
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
||||
|
||||
# Data
|
||||
print('==> Preparing data..')
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
|
||||
|
||||
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
|
||||
|
||||
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
||||
|
||||
# Model
|
||||
print('==> Building model..')
|
||||
net = LeNet()
|
||||
# net = VGG('VGG19')
|
||||
# net = ResNet18()
|
||||
# net = PreActResNet18()
|
||||
# net = GoogLeNet()
|
||||
# net = DenseNet121()
|
||||
# net = ResNeXt29_2x64d()
|
||||
# net = MobileNet()
|
||||
# net = MobileNetV2()
|
||||
# net = DPN92()
|
||||
# net = ShuffleNetG2()
|
||||
# net = SENet18()
|
||||
# net = ShuffleNetV2(1)
|
||||
net = net.to(device)
|
||||
if device == 'cuda':
|
||||
net = torch.nn.DataParallel(net)
|
||||
cudnn.benchmark = False
|
||||
|
||||
if args.resume:
|
||||
# Load checkpoint.
|
||||
print('==> Resuming from checkpoint..')
|
||||
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
|
||||
checkpoint = torch.load('./checkpoint/ckpt.t7')
|
||||
net.load_state_dict(checkpoint['net'])
|
||||
best_acc = checkpoint['acc']
|
||||
start_epoch = checkpoint['epoch']
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
|
||||
|
||||
# Training
|
||||
def train(epoch):
|
||||
print('\nEpoch: %d' % epoch)
|
||||
net.train()
|
||||
train_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
for batch_idx, (inputs, targets) in enumerate(trainloader):
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += targets.size(0)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
|
||||
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
||||
|
||||
def test(epoch):
|
||||
global best_acc
|
||||
net.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (inputs, targets) in enumerate(testloader):
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
|
||||
test_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += targets.size(0)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
|
||||
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
||||
|
||||
# Save checkpoint.
|
||||
acc = 100.*correct/total
|
||||
if acc > best_acc:
|
||||
print('Saving..')
|
||||
state = {
|
||||
'net': net.state_dict(),
|
||||
'acc': acc,
|
||||
'epoch': epoch,
|
||||
}
|
||||
if not os.path.isdir('checkpoint'):
|
||||
os.mkdir('checkpoint')
|
||||
torch.save(state, './checkpoint/ckpt.t7')
|
||||
best_acc = acc
|
||||
|
||||
|
||||
for epoch in range(start_epoch, start_epoch+200):
|
||||
train(epoch)
|
||||
test(epoch)
|
@@ -1,14 +0,0 @@
|
||||
from .vgg import *
|
||||
from .dpn import *
|
||||
from .lenet import *
|
||||
from .senet import *
|
||||
from .pnasnet import *
|
||||
from .densenet import *
|
||||
from .googlenet import *
|
||||
from .shufflenet import *
|
||||
from .shufflenetv2 import *
|
||||
from .resnet import *
|
||||
from .resnext import *
|
||||
from .preact_resnet import *
|
||||
from .mobilenet import *
|
||||
from .mobilenetv2 import *
|
@@ -1,107 +0,0 @@
|
||||
'''DenseNet in PyTorch.'''
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, in_planes, growth_rate):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(4*growth_rate)
|
||||
self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out = torch.cat([out,x], 1)
|
||||
return out
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
def __init__(self, in_planes, out_planes):
|
||||
super(Transition, self).__init__()
|
||||
self.bn = nn.BatchNorm2d(in_planes)
|
||||
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(F.relu(self.bn(x)))
|
||||
out = F.avg_pool2d(out, 2)
|
||||
return out
|
||||
|
||||
|
||||
class DenseNet(nn.Module):
|
||||
def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
|
||||
super(DenseNet, self).__init__()
|
||||
self.growth_rate = growth_rate
|
||||
|
||||
num_planes = 2*growth_rate
|
||||
self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
|
||||
num_planes += nblocks[0]*growth_rate
|
||||
out_planes = int(math.floor(num_planes*reduction))
|
||||
self.trans1 = Transition(num_planes, out_planes)
|
||||
num_planes = out_planes
|
||||
|
||||
self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
|
||||
num_planes += nblocks[1]*growth_rate
|
||||
out_planes = int(math.floor(num_planes*reduction))
|
||||
self.trans2 = Transition(num_planes, out_planes)
|
||||
num_planes = out_planes
|
||||
|
||||
self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
|
||||
num_planes += nblocks[2]*growth_rate
|
||||
out_planes = int(math.floor(num_planes*reduction))
|
||||
self.trans3 = Transition(num_planes, out_planes)
|
||||
num_planes = out_planes
|
||||
|
||||
self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
|
||||
num_planes += nblocks[3]*growth_rate
|
||||
|
||||
self.bn = nn.BatchNorm2d(num_planes)
|
||||
self.linear = nn.Linear(num_planes, num_classes)
|
||||
|
||||
def _make_dense_layers(self, block, in_planes, nblock):
|
||||
layers = []
|
||||
for i in range(nblock):
|
||||
layers.append(block(in_planes, self.growth_rate))
|
||||
in_planes += self.growth_rate
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.trans1(self.dense1(out))
|
||||
out = self.trans2(self.dense2(out))
|
||||
out = self.trans3(self.dense3(out))
|
||||
out = self.dense4(out)
|
||||
out = F.avg_pool2d(F.relu(self.bn(out)), 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
def DenseNet121():
|
||||
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
|
||||
|
||||
def DenseNet169():
|
||||
return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
|
||||
|
||||
def DenseNet201():
|
||||
return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
|
||||
|
||||
def DenseNet161():
|
||||
return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
|
||||
|
||||
def densenet_cifar():
|
||||
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
|
||||
|
||||
def test():
|
||||
net = densenet_cifar()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y)
|
||||
|
||||
# test()
|
@@ -1,98 +0,0 @@
|
||||
'''Dual Path Networks in PyTorch.'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.out_planes = out_planes
|
||||
self.dense_depth = dense_depth
|
||||
|
||||
self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(in_planes)
|
||||
self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if first_layer:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_planes+dense_depth)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
x = self.shortcut(x)
|
||||
d = self.out_planes
|
||||
out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class DPN(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super(DPN, self).__init__()
|
||||
in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
|
||||
num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.last_planes = 64
|
||||
self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
|
||||
self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
|
||||
self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
|
||||
self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
|
||||
self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10)
|
||||
|
||||
def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for i,stride in enumerate(strides):
|
||||
layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
|
||||
self.last_planes = out_planes + (i+2) * dense_depth
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def DPN26():
|
||||
cfg = {
|
||||
'in_planes': (96,192,384,768),
|
||||
'out_planes': (256,512,1024,2048),
|
||||
'num_blocks': (2,2,2,2),
|
||||
'dense_depth': (16,32,24,128)
|
||||
}
|
||||
return DPN(cfg)
|
||||
|
||||
def DPN92():
|
||||
cfg = {
|
||||
'in_planes': (96,192,384,768),
|
||||
'out_planes': (256,512,1024,2048),
|
||||
'num_blocks': (3,4,20,3),
|
||||
'dense_depth': (16,32,24,128)
|
||||
}
|
||||
return DPN(cfg)
|
||||
|
||||
|
||||
def test():
|
||||
net = DPN92()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y)
|
||||
|
||||
# test()
|
@@ -1,107 +0,0 @@
|
||||
'''GoogLeNet with PyTorch.'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Inception(nn.Module):
|
||||
def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
|
||||
super(Inception, self).__init__()
|
||||
# 1x1 conv branch
|
||||
self.b1 = nn.Sequential(
|
||||
nn.Conv2d(in_planes, n1x1, kernel_size=1),
|
||||
nn.BatchNorm2d(n1x1),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
# 1x1 conv -> 3x3 conv branch
|
||||
self.b2 = nn.Sequential(
|
||||
nn.Conv2d(in_planes, n3x3red, kernel_size=1),
|
||||
nn.BatchNorm2d(n3x3red),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(n3x3),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
# 1x1 conv -> 5x5 conv branch
|
||||
self.b3 = nn.Sequential(
|
||||
nn.Conv2d(in_planes, n5x5red, kernel_size=1),
|
||||
nn.BatchNorm2d(n5x5red),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(n5x5),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(n5x5),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
# 3x3 pool -> 1x1 conv branch
|
||||
self.b4 = nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=1, padding=1),
|
||||
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
|
||||
nn.BatchNorm2d(pool_planes),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y1 = self.b1(x)
|
||||
y2 = self.b2(x)
|
||||
y3 = self.b3(x)
|
||||
y4 = self.b4(x)
|
||||
return torch.cat([y1,y2,y3,y4], 1)
|
||||
|
||||
|
||||
class GoogLeNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(GoogLeNet, self).__init__()
|
||||
self.pre_layers = nn.Sequential(
|
||||
nn.Conv2d(3, 192, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(192),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
|
||||
self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
|
||||
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
|
||||
|
||||
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
|
||||
self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
|
||||
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
|
||||
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
|
||||
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
|
||||
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
|
||||
|
||||
self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
|
||||
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.linear = nn.Linear(1024, 10)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.pre_layers(x)
|
||||
out = self.a3(out)
|
||||
out = self.b3(out)
|
||||
out = self.maxpool(out)
|
||||
out = self.a4(out)
|
||||
out = self.b4(out)
|
||||
out = self.c4(out)
|
||||
out = self.d4(out)
|
||||
out = self.e4(out)
|
||||
out = self.maxpool(out)
|
||||
out = self.a5(out)
|
||||
out = self.b5(out)
|
||||
out = self.avgpool(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def test():
|
||||
net = GoogLeNet()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
@@ -1,24 +0,0 @@
|
||||
'''LeNet in PyTorch.'''
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
|
||||
class LeNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 512, 3)
|
||||
self.conv2 = triton.Conv2d(512, 512, 1)
|
||||
self.fc1 = nn.Linear(512*7*7, 120)
|
||||
self.fc2 = nn.Linear(120, 84)
|
||||
self.fc3 = nn.Linear(84, 10)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.conv1(x))
|
||||
out = F.max_pool2d(out, 2)
|
||||
out = F.relu(self.conv2(out))
|
||||
out = F.max_pool2d(out, 2)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = F.relu(self.fc1(out))
|
||||
out = F.relu(self.fc2(out))
|
||||
out = self.fc3(out)
|
||||
return out
|
@@ -1,61 +0,0 @@
|
||||
'''MobileNet in PyTorch.
|
||||
|
||||
See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
|
||||
for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''Depthwise conv + Pointwise conv'''
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(Block, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
return out
|
||||
|
||||
|
||||
class MobileNet(nn.Module):
|
||||
# (128,2) means conv planes=128, conv stride=2, by default conv stride=1
|
||||
cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024]
|
||||
|
||||
def __init__(self, num_classes=10):
|
||||
super(MobileNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(32)
|
||||
self.layers = self._make_layers(in_planes=32)
|
||||
self.linear = nn.Linear(1024, num_classes)
|
||||
|
||||
def _make_layers(self, in_planes):
|
||||
layers = []
|
||||
for x in self.cfg:
|
||||
out_planes = x if isinstance(x, int) else x[0]
|
||||
stride = 1 if isinstance(x, int) else x[1]
|
||||
layers.append(Block(in_planes, out_planes, stride))
|
||||
in_planes = out_planes
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layers(out)
|
||||
out = F.avg_pool2d(out, 2)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def test():
|
||||
net = MobileNet()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
@@ -1,86 +0,0 @@
|
||||
'''MobileNetV2 in PyTorch.
|
||||
|
||||
See the paper "Inverted Residuals and Linear Bottlenecks:
|
||||
Mobile Networks for Classification, Detection and Segmentation" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''expand + depthwise + pointwise'''
|
||||
def __init__(self, in_planes, out_planes, expansion, stride):
|
||||
super(Block, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
planes = expansion * in_planes
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride == 1 and in_planes != out_planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out = out + self.shortcut(x) if self.stride==1 else out
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
# (expansion, out_planes, num_blocks, stride)
|
||||
cfg = [(1, 16, 1, 1),
|
||||
(6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
|
||||
(6, 32, 3, 2),
|
||||
(6, 64, 4, 2),
|
||||
(6, 96, 3, 1),
|
||||
(6, 160, 3, 2),
|
||||
(6, 320, 1, 1)]
|
||||
|
||||
def __init__(self, num_classes=10):
|
||||
super(MobileNetV2, self).__init__()
|
||||
# NOTE: change conv1 stride 2 -> 1 for CIFAR10
|
||||
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(32)
|
||||
self.layers = self._make_layers(in_planes=32)
|
||||
self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(1280)
|
||||
self.linear = nn.Linear(1280, num_classes)
|
||||
|
||||
def _make_layers(self, in_planes):
|
||||
layers = []
|
||||
for expansion, out_planes, num_blocks, stride in self.cfg:
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
for stride in strides:
|
||||
layers.append(Block(in_planes, out_planes, expansion, stride))
|
||||
in_planes = out_planes
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layers(out)
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
# NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def test():
|
||||
net = MobileNetV2()
|
||||
x = torch.randn(2,3,32,32)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
@@ -1,125 +0,0 @@
|
||||
'''PNASNet in PyTorch.
|
||||
|
||||
Paper: Progressive Neural Architecture Search
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
'''Separable Convolution.'''
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride):
|
||||
super(SepConv, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes,
|
||||
kernel_size, stride,
|
||||
padding=(kernel_size-1)//2,
|
||||
bias=False, groups=in_planes)
|
||||
self.bn1 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn1(self.conv1(x))
|
||||
|
||||
|
||||
class CellA(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(CellA, self).__init__()
|
||||
self.stride = stride
|
||||
self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
|
||||
if stride==2:
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
y1 = self.sep_conv1(x)
|
||||
y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
|
||||
if self.stride==2:
|
||||
y2 = self.bn1(self.conv1(y2))
|
||||
return F.relu(y1+y2)
|
||||
|
||||
class CellB(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(CellB, self).__init__()
|
||||
self.stride = stride
|
||||
# Left branch
|
||||
self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
|
||||
self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
|
||||
# Right branch
|
||||
self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
|
||||
if stride==2:
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_planes)
|
||||
# Reduce channels
|
||||
self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
# Left branch
|
||||
y1 = self.sep_conv1(x)
|
||||
y2 = self.sep_conv2(x)
|
||||
# Right branch
|
||||
y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
|
||||
if self.stride==2:
|
||||
y3 = self.bn1(self.conv1(y3))
|
||||
y4 = self.sep_conv3(x)
|
||||
# Concat & reduce channels
|
||||
b1 = F.relu(y1+y2)
|
||||
b2 = F.relu(y3+y4)
|
||||
y = torch.cat([b1,b2], 1)
|
||||
return F.relu(self.bn2(self.conv2(y)))
|
||||
|
||||
class PNASNet(nn.Module):
|
||||
def __init__(self, cell_type, num_cells, num_planes):
|
||||
super(PNASNet, self).__init__()
|
||||
self.in_planes = num_planes
|
||||
self.cell_type = cell_type
|
||||
|
||||
self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(num_planes)
|
||||
|
||||
self.layer1 = self._make_layer(num_planes, num_cells=6)
|
||||
self.layer2 = self._downsample(num_planes*2)
|
||||
self.layer3 = self._make_layer(num_planes*2, num_cells=6)
|
||||
self.layer4 = self._downsample(num_planes*4)
|
||||
self.layer5 = self._make_layer(num_planes*4, num_cells=6)
|
||||
|
||||
self.linear = nn.Linear(num_planes*4, 10)
|
||||
|
||||
def _make_layer(self, planes, num_cells):
|
||||
layers = []
|
||||
for _ in range(num_cells):
|
||||
layers.append(self.cell_type(self.in_planes, planes, stride=1))
|
||||
self.in_planes = planes
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _downsample(self, planes):
|
||||
layer = self.cell_type(self.in_planes, planes, stride=2)
|
||||
self.in_planes = planes
|
||||
return layer
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = self.layer5(out)
|
||||
out = F.avg_pool2d(out, 8)
|
||||
out = self.linear(out.view(out.size(0), -1))
|
||||
return out
|
||||
|
||||
|
||||
def PNASNetA():
|
||||
return PNASNet(CellA, num_cells=6, num_planes=44)
|
||||
|
||||
def PNASNetB():
|
||||
return PNASNet(CellB, num_cells=6, num_planes=32)
|
||||
|
||||
|
||||
def test():
|
||||
net = PNASNetB()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y)
|
||||
|
||||
# test()
|
@@ -1,118 +0,0 @@
|
||||
'''Pre-activation ResNet in PyTorch.
|
||||
|
||||
Reference:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Identity Mappings in Deep Residual Networks. arXiv:1603.05027
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PreActBlock(nn.Module):
|
||||
'''Pre-activation version of the BasicBlock.'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(PreActBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class PreActBottleneck(nn.Module):
|
||||
'''Pre-activation version of the original Bottleneck module.'''
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(PreActBottleneck, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out = self.conv3(F.relu(self.bn3(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class PreActResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super(PreActResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.linear = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def PreActResNet18():
|
||||
return PreActResNet(PreActBlock, [2,2,2,2])
|
||||
|
||||
def PreActResNet34():
|
||||
return PreActResNet(PreActBlock, [3,4,6,3])
|
||||
|
||||
def PreActResNet50():
|
||||
return PreActResNet(PreActBottleneck, [3,4,6,3])
|
||||
|
||||
def PreActResNet101():
|
||||
return PreActResNet(PreActBottleneck, [3,4,23,3])
|
||||
|
||||
def PreActResNet152():
|
||||
return PreActResNet(PreActBottleneck, [3,8,36,3])
|
||||
|
||||
|
||||
def test():
|
||||
net = PreActResNet18()
|
||||
y = net((torch.randn(1,3,32,32)))
|
||||
print(y.size())
|
||||
|
||||
# test()
|
@@ -1,121 +0,0 @@
|
||||
'''ResNet in PyTorch.
|
||||
|
||||
For Pre-activation ResNet, see 'preact_resnet.py'.
|
||||
|
||||
Reference:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = triton.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = triton.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
triton.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = triton.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = triton.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = triton.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
triton.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = triton.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.linear = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def ResNet18():
|
||||
return ResNet(BasicBlock, [2,2,2,2])
|
||||
|
||||
def ResNet34():
|
||||
return ResNet(BasicBlock, [3,4,6,3])
|
||||
|
||||
def ResNet50():
|
||||
return ResNet(Bottleneck, [3,4,6,3])
|
||||
|
||||
def ResNet101():
|
||||
return ResNet(Bottleneck, [3,4,23,3])
|
||||
|
||||
def ResNet152():
|
||||
return ResNet(Bottleneck, [3,8,36,3])
|
||||
|
||||
|
||||
def test():
|
||||
net = ResNet18()
|
||||
y = net(torch.randn(1,3,32,32))
|
||||
print(y.size())
|
||||
|
||||
# test()
|
@@ -1,95 +0,0 @@
|
||||
'''ResNeXt in PyTorch.
|
||||
|
||||
See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
'''Grouped convolution block.'''
|
||||
expansion = 2
|
||||
|
||||
def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
|
||||
super(Block, self).__init__()
|
||||
group_width = cardinality * bottleneck_width
|
||||
self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(group_width)
|
||||
self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(group_width)
|
||||
self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion*group_width)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion*group_width:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*group_width)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNeXt(nn.Module):
|
||||
def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
|
||||
super(ResNeXt, self).__init__()
|
||||
self.cardinality = cardinality
|
||||
self.bottleneck_width = bottleneck_width
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(num_blocks[0], 1)
|
||||
self.layer2 = self._make_layer(num_blocks[1], 2)
|
||||
self.layer3 = self._make_layer(num_blocks[2], 2)
|
||||
# self.layer4 = self._make_layer(num_blocks[3], 2)
|
||||
self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)
|
||||
|
||||
def _make_layer(self, num_blocks, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
|
||||
self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
|
||||
# Increase bottleneck_width by 2 after each stage.
|
||||
self.bottleneck_width *= 2
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
# out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 8)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def ResNeXt29_2x64d():
|
||||
return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)
|
||||
|
||||
def ResNeXt29_4x64d():
|
||||
return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)
|
||||
|
||||
def ResNeXt29_8x64d():
|
||||
return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)
|
||||
|
||||
def ResNeXt29_32x4d():
|
||||
return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)
|
||||
|
||||
def test_resnext():
|
||||
net = ResNeXt29_2x64d()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test_resnext()
|
@@ -1,121 +0,0 @@
|
||||
'''SENet in PyTorch.
|
||||
|
||||
SENet is the winner of ImageNet-2017. The paper is not released yet.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes)
|
||||
)
|
||||
|
||||
# SE layers
|
||||
self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear
|
||||
self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
|
||||
# Squeeze
|
||||
w = F.avg_pool2d(out, out.size(2))
|
||||
w = F.relu(self.fc1(w))
|
||||
w = F.sigmoid(self.fc2(w))
|
||||
# Excitation
|
||||
out = out * w # New broadcasting feature from v0.2!
|
||||
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class PreActBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super(PreActBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
if stride != 1 or in_planes != planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
|
||||
)
|
||||
|
||||
# SE layers
|
||||
self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
|
||||
self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
|
||||
# Squeeze
|
||||
w = F.avg_pool2d(out, out.size(2))
|
||||
w = F.relu(self.fc1(w))
|
||||
w = F.sigmoid(self.fc2(w))
|
||||
# Excitation
|
||||
out = out * w
|
||||
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
|
||||
class SENet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
super(SENet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
||||
self.linear = nn.Linear(512, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def SENet18():
|
||||
return SENet(PreActBlock, [2,2,2,2])
|
||||
|
||||
|
||||
def test():
|
||||
net = SENet18()
|
||||
y = net(torch.randn(1,3,32,32))
|
||||
print(y.size())
|
||||
|
||||
# test()
|
@@ -1,109 +0,0 @@
|
||||
'''ShuffleNet in PyTorch.
|
||||
|
||||
See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ShuffleBlock(nn.Module):
|
||||
def __init__(self, groups):
|
||||
super(ShuffleBlock, self).__init__()
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
'''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
|
||||
N,C,H,W = x.size()
|
||||
g = self.groups
|
||||
return x.view(N,g,C/g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride, groups):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
mid_planes = out_planes/4
|
||||
g = 1 if in_planes==24 else groups
|
||||
self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(mid_planes)
|
||||
self.shuffle1 = ShuffleBlock(groups=g)
|
||||
self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(mid_planes)
|
||||
self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride == 2:
|
||||
self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.shuffle1(out)
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
res = self.shortcut(x)
|
||||
out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
|
||||
return out
|
||||
|
||||
|
||||
class ShuffleNet(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super(ShuffleNet, self).__init__()
|
||||
out_planes = cfg['out_planes']
|
||||
num_blocks = cfg['num_blocks']
|
||||
groups = cfg['groups']
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(24)
|
||||
self.in_planes = 24
|
||||
self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
|
||||
self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
|
||||
self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
|
||||
self.linear = nn.Linear(out_planes[2], 10)
|
||||
|
||||
def _make_layer(self, out_planes, num_blocks, groups):
|
||||
layers = []
|
||||
for i in range(num_blocks):
|
||||
stride = 2 if i == 0 else 1
|
||||
cat_planes = self.in_planes if i == 0 else 0
|
||||
layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
|
||||
self.in_planes = out_planes
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def ShuffleNetG2():
|
||||
cfg = {
|
||||
'out_planes': [200,400,800],
|
||||
'num_blocks': [4,8,4],
|
||||
'groups': 2
|
||||
}
|
||||
return ShuffleNet(cfg)
|
||||
|
||||
def ShuffleNetG3():
|
||||
cfg = {
|
||||
'out_planes': [240,480,960],
|
||||
'num_blocks': [4,8,4],
|
||||
'groups': 3
|
||||
}
|
||||
return ShuffleNet(cfg)
|
||||
|
||||
|
||||
def test():
|
||||
net = ShuffleNetG2()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y)
|
||||
|
||||
# test()
|
@@ -1,162 +0,0 @@
|
||||
'''ShuffleNetV2 in PyTorch.
|
||||
|
||||
See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ShuffleBlock(nn.Module):
|
||||
def __init__(self, groups=2):
|
||||
super(ShuffleBlock, self).__init__()
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
'''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
|
||||
N, C, H, W = x.size()
|
||||
g = self.groups
|
||||
return x.view(N, g, C/g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
|
||||
|
||||
|
||||
class SplitBlock(nn.Module):
|
||||
def __init__(self, ratio):
|
||||
super(SplitBlock, self).__init__()
|
||||
self.ratio = ratio
|
||||
|
||||
def forward(self, x):
|
||||
c = int(x.size(1) * self.ratio)
|
||||
return x[:, :c, :, :], x[:, c:, :, :]
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_channels, split_ratio=0.5):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.split = SplitBlock(split_ratio)
|
||||
in_channels = int(in_channels * split_ratio)
|
||||
self.conv1 = nn.Conv2d(in_channels, in_channels,
|
||||
kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.conv2 = nn.Conv2d(in_channels, in_channels,
|
||||
kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(in_channels)
|
||||
self.conv3 = nn.Conv2d(in_channels, in_channels,
|
||||
kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(in_channels)
|
||||
self.shuffle = ShuffleBlock()
|
||||
|
||||
def forward(self, x):
|
||||
x1, x2 = self.split(x)
|
||||
out = F.relu(self.bn1(self.conv1(x2)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = F.relu(self.bn3(self.conv3(out)))
|
||||
out = torch.cat([x1, out], 1)
|
||||
out = self.shuffle(out)
|
||||
return out
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(DownBlock, self).__init__()
|
||||
mid_channels = out_channels // 2
|
||||
# left
|
||||
self.conv1 = nn.Conv2d(in_channels, in_channels,
|
||||
kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.conv2 = nn.Conv2d(in_channels, mid_channels,
|
||||
kernel_size=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(mid_channels)
|
||||
# right
|
||||
self.conv3 = nn.Conv2d(in_channels, mid_channels,
|
||||
kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(mid_channels)
|
||||
self.conv4 = nn.Conv2d(mid_channels, mid_channels,
|
||||
kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
|
||||
self.bn4 = nn.BatchNorm2d(mid_channels)
|
||||
self.conv5 = nn.Conv2d(mid_channels, mid_channels,
|
||||
kernel_size=1, bias=False)
|
||||
self.bn5 = nn.BatchNorm2d(mid_channels)
|
||||
|
||||
self.shuffle = ShuffleBlock()
|
||||
|
||||
def forward(self, x):
|
||||
# left
|
||||
out1 = self.bn1(self.conv1(x))
|
||||
out1 = F.relu(self.bn2(self.conv2(out1)))
|
||||
# right
|
||||
out2 = F.relu(self.bn3(self.conv3(x)))
|
||||
out2 = self.bn4(self.conv4(out2))
|
||||
out2 = F.relu(self.bn5(self.conv5(out2)))
|
||||
# concat
|
||||
out = torch.cat([out1, out2], 1)
|
||||
out = self.shuffle(out)
|
||||
return out
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Module):
|
||||
def __init__(self, net_size):
|
||||
super(ShuffleNetV2, self).__init__()
|
||||
out_channels = configs[net_size]['out_channels']
|
||||
num_blocks = configs[net_size]['num_blocks']
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
|
||||
stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(24)
|
||||
self.in_channels = 24
|
||||
self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
|
||||
self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
|
||||
self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
|
||||
self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
|
||||
kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels[3])
|
||||
self.linear = nn.Linear(out_channels[3], 10)
|
||||
|
||||
def _make_layer(self, out_channels, num_blocks):
|
||||
layers = [DownBlock(self.in_channels, out_channels)]
|
||||
for i in range(num_blocks):
|
||||
layers.append(BasicBlock(out_channels))
|
||||
self.in_channels = out_channels
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
# out = F.max_pool2d(out, 3, stride=2, padding=1)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
configs = {
|
||||
0.5: {
|
||||
'out_channels': (48, 96, 192, 1024),
|
||||
'num_blocks': (3, 7, 3)
|
||||
},
|
||||
|
||||
1: {
|
||||
'out_channels': (116, 232, 464, 1024),
|
||||
'num_blocks': (3, 7, 3)
|
||||
},
|
||||
1.5: {
|
||||
'out_channels': (176, 352, 704, 1024),
|
||||
'num_blocks': (3, 7, 3)
|
||||
},
|
||||
2: {
|
||||
'out_channels': (224, 488, 976, 2048),
|
||||
'num_blocks': (3, 7, 3)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test():
|
||||
net = ShuffleNetV2(net_size=0.5)
|
||||
x = torch.randn(3, 3, 32, 32)
|
||||
y = net(x)
|
||||
print(y.shape)
|
||||
|
||||
|
||||
# test()
|
@@ -1,47 +0,0 @@
|
||||
'''VGG11/13/16/19 in Pytorch.'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
|
||||
cfg = {
|
||||
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
|
||||
}
|
||||
|
||||
|
||||
class VGG(nn.Module):
|
||||
def __init__(self, vgg_name):
|
||||
super(VGG, self).__init__()
|
||||
self.features = self._make_layers(cfg[vgg_name])
|
||||
self.classifier = nn.Linear(512, 10)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.features(x)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.classifier(out)
|
||||
return out
|
||||
|
||||
def _make_layers(self, cfg):
|
||||
layers = []
|
||||
in_channels = 3
|
||||
for x in cfg:
|
||||
if x == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
layers += [triton.Conv2d(in_channels, x, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(x),
|
||||
nn.ReLU(inplace=True)]
|
||||
in_channels = x
|
||||
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def test():
|
||||
net = VGG('VGG11')
|
||||
x = torch.randn(2,3,32,32)
|
||||
y = net(x)
|
||||
print(y.size())
|
||||
|
||||
# test()
|
114
examples/python/pytorch/shift.cpp
Normal file
114
examples/python/pytorch/shift.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <torch/torch.h>
|
||||
#include <torch/script.h>
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/shift.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t*, int32_t*,
|
||||
triton::dnn::shift::type, bool> shift_key_t;
|
||||
|
||||
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_shift_stream;
|
||||
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_shift_jit;
|
||||
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_shift_config;
|
||||
|
||||
torch::Tensor shift_common(
|
||||
int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
|
||||
int32_t T, int32_t R, int32_t S, int32_t F,
|
||||
std::vector<int32_t> shift_h, std::vector<int32_t> shift_w,
|
||||
triton::dnn::shift::type ty,
|
||||
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
|
||||
bool autotune = false
|
||||
) {
|
||||
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = torcha.storage().device().index();
|
||||
|
||||
// Get stream
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
|
||||
triton::driver::stream* stream;
|
||||
if(m_shift_stream.find(custream) == m_shift_stream.end())
|
||||
stream = m_shift_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
|
||||
else
|
||||
stream = m_shift_stream.at(custream).get();
|
||||
|
||||
// Get context
|
||||
triton::driver::context* ctx = stream->context();
|
||||
|
||||
// Get configuration
|
||||
bool has_bias = torchbias.storage().size() > 0;
|
||||
shift_key_t key = {B, C, D, H, W, T, R, S, F, shift_h.data(), shift_w.data(), ty, has_bias};
|
||||
triton::dnn::shift* configuration;
|
||||
if(m_shift_config.find(key) == m_shift_config.end())
|
||||
configuration = m_shift_config.emplace(key, new triton::dnn::shift(
|
||||
B, C, D, H, W, T, R, S, F,
|
||||
shift_h, shift_w, "fp32", "fp32",
|
||||
ty, has_bias)).first->second.get();
|
||||
else
|
||||
configuration = m_shift_config.at(key).get();
|
||||
|
||||
// Bind memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
|
||||
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
|
||||
|
||||
// Allocate output
|
||||
std::vector<int32_t> c_shapes = configuration->c_shapes();
|
||||
torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda();
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
|
||||
// Get JIT
|
||||
triton::jit* jit;
|
||||
if(m_shift_jit.find(key) == m_shift_jit.end()){
|
||||
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
configuration->src(oss);
|
||||
std::string src = oss.str();
|
||||
// benchmark a given shiftolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, stream->context()->device());
|
||||
return configuration->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
|
||||
jit->add_module("shift", src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_shift_jit.at(key).get();
|
||||
|
||||
// Run
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
triton::jit::launch_information info = jit->get_launch_info("shift");
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, TM, TN, nthreads);
|
||||
return torchc;
|
||||
}
|
@@ -1,124 +0,0 @@
|
||||
'''Some helper functions for PyTorch, including:
|
||||
- get_mean_and_std: calculate the mean and std value of dataset.
|
||||
- msr_init: net parameter initialization.
|
||||
- progress_bar: progress bar mimic xlua.progress.
|
||||
'''
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def get_mean_and_std(dataset):
|
||||
'''Compute the mean and std value of dataset.'''
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
|
||||
mean = torch.zeros(3)
|
||||
std = torch.zeros(3)
|
||||
print('==> Computing mean and std..')
|
||||
for inputs, targets in dataloader:
|
||||
for i in range(3):
|
||||
mean[i] += inputs[:,i,:,:].mean()
|
||||
std[i] += inputs[:,i,:,:].std()
|
||||
mean.div_(len(dataset))
|
||||
std.div_(len(dataset))
|
||||
return mean, std
|
||||
|
||||
def init_params(net):
|
||||
'''Init layer parameters.'''
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal(m.weight, mode='fan_out')
|
||||
if m.bias:
|
||||
init.constant(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant(m.weight, 1)
|
||||
init.constant(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.normal(m.weight, std=1e-3)
|
||||
if m.bias:
|
||||
init.constant(m.bias, 0)
|
||||
|
||||
|
||||
_, term_width = os.popen('stty size', 'r').read().split()
|
||||
term_width = int(term_width)
|
||||
|
||||
TOTAL_BAR_LENGTH = 65.
|
||||
last_time = time.time()
|
||||
begin_time = last_time
|
||||
def progress_bar(current, total, msg=None):
|
||||
global last_time, begin_time
|
||||
if current == 0:
|
||||
begin_time = time.time() # Reset for new bar.
|
||||
|
||||
cur_len = int(TOTAL_BAR_LENGTH*current/total)
|
||||
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
|
||||
|
||||
sys.stdout.write(' [')
|
||||
for i in range(cur_len):
|
||||
sys.stdout.write('=')
|
||||
sys.stdout.write('>')
|
||||
for i in range(rest_len):
|
||||
sys.stdout.write('.')
|
||||
sys.stdout.write(']')
|
||||
|
||||
cur_time = time.time()
|
||||
step_time = cur_time - last_time
|
||||
last_time = cur_time
|
||||
tot_time = cur_time - begin_time
|
||||
|
||||
L = []
|
||||
L.append(' Step: %s' % format_time(step_time))
|
||||
L.append(' | Tot: %s' % format_time(tot_time))
|
||||
if msg:
|
||||
L.append(' | ' + msg)
|
||||
|
||||
msg = ''.join(L)
|
||||
sys.stdout.write(msg)
|
||||
for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
|
||||
sys.stdout.write(' ')
|
||||
|
||||
# Go back to the center of the bar.
|
||||
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
|
||||
sys.stdout.write('\b')
|
||||
sys.stdout.write(' %d/%d ' % (current+1, total))
|
||||
|
||||
if current < total-1:
|
||||
sys.stdout.write('\r')
|
||||
else:
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
def format_time(seconds):
|
||||
days = int(seconds / 3600/24)
|
||||
seconds = seconds - days*3600*24
|
||||
hours = int(seconds / 3600)
|
||||
seconds = seconds - hours*3600
|
||||
minutes = int(seconds / 60)
|
||||
seconds = seconds - minutes*60
|
||||
secondsf = int(seconds)
|
||||
seconds = seconds - secondsf
|
||||
millis = int(seconds*1000)
|
||||
|
||||
f = ''
|
||||
i = 1
|
||||
if days > 0:
|
||||
f += str(days) + 'D'
|
||||
i += 1
|
||||
if hours > 0 and i <= 2:
|
||||
f += str(hours) + 'h'
|
||||
i += 1
|
||||
if minutes > 0 and i <= 2:
|
||||
f += str(minutes) + 'm'
|
||||
i += 1
|
||||
if secondsf > 0 and i <= 2:
|
||||
f += str(secondsf) + 's'
|
||||
i += 1
|
||||
if millis > 0 and i <= 2:
|
||||
f += str(millis) + 'ms'
|
||||
i += 1
|
||||
if f == '':
|
||||
f = '0ms'
|
||||
return f
|
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
|
||||
include_directories("${CUDA_HOME}/include")
|
||||
link_directories(${TF_LIB})
|
||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
|
||||
add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp shift.cpp)
|
||||
add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
|
||||
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
|
||||
${CMAKE_CURRENT_BINARY_DIR}/run.py
|
||||
|
174
examples/python/tensorflow/batchnorm.cpp
Normal file
174
examples/python/tensorflow/batchnorm.cpp
Normal file
@@ -0,0 +1,174 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/driver/buffer.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/dnn/batchnorm.h"
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
|
||||
using namespace tensorflow;
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
class BatchnormForwardOp : public OpKernel {
|
||||
public:
|
||||
explicit BatchnormForwardOp(OpKernelConstruction* context): OpKernel(context) {
|
||||
context->GetAttr("eps", &eps_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context){
|
||||
// get device/stream
|
||||
GPUDevice device = context->eigen_device<GPUDevice>();
|
||||
triton::driver::cu_stream sstream(device.stream(), false);
|
||||
triton::driver::context* ctx = sstream.context();
|
||||
triton::driver::stream* stream = &sstream;
|
||||
// get inputs
|
||||
const Tensor& x = context->input(0);
|
||||
const Tensor& g = context->input(1);
|
||||
const Tensor& b = context->input(2);
|
||||
// get sizes
|
||||
int C = x.dim_size(0);
|
||||
int H = x.dim_size(1);
|
||||
int W = x.dim_size(2);
|
||||
int B = x.dim_size(3);
|
||||
// allocate outputs
|
||||
Tensor* y = nullptr;
|
||||
Tensor* m = nullptr;
|
||||
Tensor* v = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &m));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &v));
|
||||
// triton handles
|
||||
triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tb(ctx, (CUdeviceptr)b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer ty(ctx, (CUdeviceptr)y->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
|
||||
// create config
|
||||
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
|
||||
std::ostringstream oss;
|
||||
batchnorm.src(oss);
|
||||
std::string src = oss.str();
|
||||
triton::jit jit(ctx);
|
||||
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
|
||||
triton::driver::kernel* kernel = jit.get_function("batchnorm");
|
||||
size_t TM = jit.get_int("TM");
|
||||
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
|
||||
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
|
||||
}
|
||||
|
||||
private:
|
||||
float eps_;
|
||||
};
|
||||
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchnormForward").Device(DEVICE_GPU), BatchnormForwardOp);
|
||||
REGISTER_OP("BatchnormForward")
|
||||
.Input("x: T")
|
||||
.Input("g: float")
|
||||
.Input("b: float")
|
||||
.Output("y: T")
|
||||
.Output("m: float")
|
||||
.Output("v: float")
|
||||
.Attr("T: {float}")
|
||||
.Attr("eps: float")
|
||||
.SetShapeFn([](InferenceContext* ctx) {
|
||||
ctx->set_output(0, ctx->input(0));
|
||||
ctx->set_output(1, ctx->input(1));
|
||||
ctx->set_output(2, ctx->input(1));
|
||||
return Status::OK();
|
||||
})
|
||||
;
|
||||
|
||||
|
||||
class BatchnormBackwardOp : public OpKernel {
|
||||
public:
|
||||
explicit BatchnormBackwardOp(OpKernelConstruction* context): OpKernel(context) {
|
||||
context->GetAttr("eps", &eps_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context){
|
||||
// get device/stream
|
||||
GPUDevice device = context->eigen_device<GPUDevice>();
|
||||
triton::driver::cu_stream sstream(device.stream(), false);
|
||||
triton::driver::context* ctx = sstream.context();
|
||||
triton::driver::stream* stream = &sstream;
|
||||
// get inputs
|
||||
const Tensor& dy = context->input(0);
|
||||
const Tensor& x = context->input(1);
|
||||
const Tensor& g = context->input(2);
|
||||
const Tensor& m = context->input(3);
|
||||
const Tensor& v = context->input(4);
|
||||
// get sizes
|
||||
int C = x.dim_size(0);
|
||||
int H = x.dim_size(1);
|
||||
int W = x.dim_size(2);
|
||||
int B = x.dim_size(3);
|
||||
// allocate outputs
|
||||
Tensor* dx = nullptr;
|
||||
Tensor* dg = nullptr;
|
||||
Tensor* db = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &dx));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &dg));
|
||||
OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &db));
|
||||
// triton handles
|
||||
triton::driver::cu_buffer tdy(ctx, (CUdeviceptr)dy.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
|
||||
|
||||
// create config
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
|
||||
std::ostringstream oss;
|
||||
batchnorm.src(oss);
|
||||
std::string src = oss.str();
|
||||
triton::jit jit(ctx);
|
||||
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
|
||||
triton::driver::kernel* kernel = jit.get_function("batchnorm");
|
||||
size_t TM = jit.get_int("TM");
|
||||
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
|
||||
batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
|
||||
}
|
||||
|
||||
private:
|
||||
float eps_;
|
||||
};
|
||||
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchnormBackward").Device(DEVICE_GPU), BatchnormBackwardOp);
|
||||
REGISTER_OP("BatchnormBackward")
|
||||
.Input("dy: TY")
|
||||
.Input("x: TX")
|
||||
.Input("g: float")
|
||||
.Input("m: float")
|
||||
.Input("v: float")
|
||||
.Output("dx: TY")
|
||||
.Output("dg: float")
|
||||
.Output("db: float")
|
||||
.Attr("TX: {float}")
|
||||
.Attr("TY: {float}")
|
||||
.Attr("eps: float")
|
||||
.SetShapeFn([](InferenceContext* ctx) {
|
||||
ctx->set_output(0, ctx->input(1));
|
||||
ctx->set_output(1, ctx->input(2));
|
||||
ctx->set_output(2, ctx->input(2));
|
||||
return Status::OK();
|
||||
})
|
||||
;
|
@@ -65,8 +65,6 @@ public:
|
||||
// Bind memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
|
||||
// triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
|
||||
// triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
|
||||
triton::driver::buffer* bias = nullptr;
|
||||
|
||||
// allocate output
|
@@ -56,8 +56,8 @@ def blocksparse_matmul_grad(op, dy):
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 1, 32, 8, 6
|
||||
R, S, F = 3, 3, 16
|
||||
B, C, H, W = 16, 1024, 8, 8
|
||||
R, S, F = 3, 3, 1024
|
||||
np.random.seed(2)
|
||||
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
b = tf.placeholder(tf.float32, shape=[C, F])
|
||||
@@ -65,8 +65,6 @@ def run_shift():
|
||||
hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
#hshift_h = np.ones(C, dtype=np.int32)
|
||||
#hshift_w = np.ones(C, dtype=np.int32)
|
||||
print(hshift_h)
|
||||
print(hshift_w)
|
||||
c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
|
||||
# Reference
|
||||
ha = np.random.rand(C, H, W, B)
|
||||
@@ -74,16 +72,36 @@ def run_shift():
|
||||
#ha = np.ones((C, H, W, B), dtype=np.int32)
|
||||
#hb = np.ones((C, F), dtype=np.int32)
|
||||
sess = tf.InteractiveSession()
|
||||
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
|
||||
extra_feed_dict={a: ha, b: hb})
|
||||
dw_t, dw_n = grads[1]
|
||||
dx_t, dx_n = grads[0]
|
||||
print(np.max(np.abs(dw_t - dw_n)))
|
||||
print(np.max(np.abs(dx_t - dx_n)))
|
||||
#grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
|
||||
# extra_feed_dict = {a: ha, b: hb})
|
||||
#dw_t, dw_n = grads[1]
|
||||
#dx_t, dx_n = grads[0]
|
||||
#print(np.max(np.abs(dw_t - dw_n)))
|
||||
#print(np.max(np.abs(dx_t - dx_n)))
|
||||
# Run
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
#print(result)
|
||||
|
||||
run_shift()
|
||||
def run_batchnorm():
|
||||
C, H, W, B = 32, 16, 16, 16
|
||||
np.random.seed(0)
|
||||
# Placeholders
|
||||
x = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
g = tf.placeholder(tf.float32, shape=[C])
|
||||
b = tf.placeholder(tf.float32, shape=[C])
|
||||
# Feed values
|
||||
hx = np.random.rand(C, H, W, B)
|
||||
hg = np.random.rand(C)
|
||||
hb = np.random.rand(C)
|
||||
# batchnorm
|
||||
y, m, v = module.batchnorm_forward(x, g, b, eps=1e-5)
|
||||
# Run
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
|
||||
print(hx.sum(axis=(1,2,3)))
|
||||
print(result[1])
|
||||
|
||||
run_batchnorm()
|
||||
|
@@ -125,7 +125,7 @@ public:
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||
// get JIT
|
||||
triton::jit* jit;
|
||||
bool autotune = false;
|
||||
bool autotune = true;
|
||||
if(m_jit.find(key) == m_jit.end()) {
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
|
83
include/triton/dnn/batchnorm.h
Normal file
83
include/triton/dnn/batchnorm.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/* Copyright 2015-2019 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TDL_INCLUDE_DNN_BATCHNORM_H
|
||||
#define TDL_INCLUDE_DNN_BATCHNORM_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class batchnorm_forward {
|
||||
public:
|
||||
// constructor
|
||||
batchnorm_forward(int C, int D, int H, int W, int B, std::string ty = "fp32");
|
||||
// enqueue
|
||||
void enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *y, driver::buffer *m, driver::buffer *v,
|
||||
driver::buffer *x, driver::buffer *g, driver::buffer *b,
|
||||
size_t TM, size_t nthreads);
|
||||
// triton-c source code
|
||||
void src(std::ostream &os);
|
||||
|
||||
private:
|
||||
int32_t C_;
|
||||
int32_t D_;
|
||||
int32_t H_;
|
||||
int32_t W_;
|
||||
int32_t B_;
|
||||
std::string ty_;
|
||||
};
|
||||
|
||||
class batchnorm_backward {
|
||||
public:
|
||||
// constructor
|
||||
batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32");
|
||||
// enqueue
|
||||
void enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
|
||||
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
|
||||
size_t TM, size_t nthreads);
|
||||
// triton-c source code
|
||||
void src(std::ostream &os);
|
||||
|
||||
|
||||
private:
|
||||
int32_t C_;
|
||||
int32_t D_;
|
||||
int32_t H_;
|
||||
int32_t W_;
|
||||
int32_t B_;
|
||||
std::string ty_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -131,6 +131,7 @@ public:
|
||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::string &name = "");
|
||||
value *create_reduce(value *A, const std::string &name = "");
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
// Intrinsics
|
||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||
|
@@ -581,6 +581,15 @@ public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
private:
|
||||
reduce_inst(value* arg, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
|
@@ -167,6 +167,14 @@ private:
|
||||
node* arg_;
|
||||
};
|
||||
|
||||
class reduce_expression: public builtin_expression{
|
||||
public:
|
||||
reduce_expression(node *arg): arg_(arg) {}
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
node* arg_;
|
||||
};
|
||||
|
||||
class indexing_expression: public postfix_expression{
|
||||
public:
|
||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||
%token IF ELSE FOR CONTINUE WHILE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -125,6 +125,7 @@ builtin_expression
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
||||
| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
|
||||
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
|
||||
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
||||
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
||||
|
@@ -49,6 +49,7 @@ using triton::lang::return_void;
|
||||
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
|
||||
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
||||
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
||||
"__sum" { return return_impl(REDUCE_SUM, yytext); }
|
||||
"dot" { return return_impl(DOT, yytext); }
|
||||
"max" { return return_impl(MAX, yytext); }
|
||||
"min" { return return_impl(MIN, yytext); }
|
||||
|
@@ -773,6 +773,62 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)});
|
||||
return;
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
|
||||
Value *partial = nullptr;
|
||||
distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
// reduce within thread
|
||||
op->for_each([&](indices_t idx){
|
||||
Value *current = op->get_value(idx);
|
||||
if(partial == nullptr)
|
||||
partial = current;
|
||||
else
|
||||
partial = builder.CreateFAdd(partial, current);
|
||||
});
|
||||
// reduce within warp
|
||||
Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
|
||||
for (int i = 16; i > 0; i >>= 1){
|
||||
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0x1f), partial,
|
||||
builder.getInt32(i), builder.getInt32(0xffffffff)});
|
||||
partial = builder.CreateFAdd(partial, rhs);
|
||||
}
|
||||
// reduce within block
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
|
||||
BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
|
||||
Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
|
||||
Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
|
||||
|
||||
builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
|
||||
partial_reduce_do, partial_reduce_done);
|
||||
builder.SetInsertPoint(partial_reduce_do);
|
||||
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||
Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
|
||||
Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
|
||||
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, warp_id);
|
||||
builder.CreateStore(partial, write_ptr);
|
||||
builder.CreateBr(partial_reduce_done);
|
||||
builder.SetInsertPoint(partial_reduce_done);
|
||||
// Final reduction with the first warp
|
||||
tgt_->add_barrier(module, builder);
|
||||
BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
|
||||
BasicBlock *final_reduce_done = BasicBlock::Create(ctx, "final_reduce_done", fn);
|
||||
builder.CreateCondBr(builder.CreateICmpEQ(warp_id, builder.getInt32(0)),
|
||||
final_reduce_do, final_reduce_done);
|
||||
builder.SetInsertPoint(final_reduce_do);
|
||||
Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
|
||||
Value *result = builder.CreateLoad(read_ptr);
|
||||
for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
|
||||
Value *rhs = builder.CreateCall(shfl, {result, builder.getInt32(i),
|
||||
builder.getInt32(0x1f), builder.getInt32(0xffffffff)});
|
||||
builder.CreateFAdd(result, rhs);
|
||||
}
|
||||
builder.CreateStore(result, read_ptr);
|
||||
builder.CreateBr(final_reduce_done);
|
||||
builder.SetInsertPoint(final_reduce_done);
|
||||
tgt_->add_barrier(module, builder);
|
||||
vmap_[ins] = builder.CreateLoad(sh_mem_ptr);
|
||||
return;
|
||||
}
|
||||
tile *ti = tmap_[ins];
|
||||
distributed_tile* result = (distributed_tile*)ti;
|
||||
if(!ins->get_type()->is_tile_ty())
|
||||
|
@@ -29,6 +29,8 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
}
|
||||
|
||||
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
if(dynamic_cast<ir::reduce_inst*>(x))
|
||||
return 32;
|
||||
unsigned result = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
unsigned pad = is_ld_padded(x);
|
||||
if(pad > 0){
|
||||
|
@@ -40,6 +40,8 @@ inline bool get_is_shared(ir::value* v) {
|
||||
return true;
|
||||
if(auto x = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
return true;
|
||||
if(auto x = dynamic_cast<ir::reduce_inst*>(v))
|
||||
return true;
|
||||
if(auto x = dynamic_cast<ir::phi_node*>(v)){
|
||||
bool res = true;
|
||||
for(unsigned inc = 0; inc < x->get_num_incoming(); inc++)
|
||||
|
@@ -60,6 +60,8 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
||||
return;
|
||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v))
|
||||
return;
|
||||
else
|
||||
shapes = v->get_type()->get_tile_shapes();
|
||||
// Reshape
|
||||
|
165
lib/dnn/batchnorm.cpp
Normal file
165
lib/dnn/batchnorm.cpp
Normal file
@@ -0,0 +1,165 @@
|
||||
/* Copyright 2015-2019 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include "triton/dnn/batchnorm.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
/* ---------------
|
||||
* Forward
|
||||
* --------------- */
|
||||
|
||||
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
|
||||
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty) { }
|
||||
|
||||
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *y, driver::buffer *m, driver::buffer *v,
|
||||
driver::buffer *x, driver::buffer *g, driver::buffer *b,
|
||||
size_t, size_t nthreads) {
|
||||
|
||||
std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
|
||||
kernel->setArg(0, y);
|
||||
kernel->setArg(1, m);
|
||||
kernel->setArg(2, v);
|
||||
kernel->setArg(3, x);
|
||||
kernel->setArg(4, g);
|
||||
kernel->setArg(5, b);
|
||||
kernel->setArg(6, (int32_t)(D_*H_*W_*B_));
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_forward::src(std::ostream &os) {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {32, 64, 128};
|
||||
|
||||
void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
||||
restrict read_only fp32 *X,
|
||||
restrict read_only fp32 *G,
|
||||
restrict read_only fp32 *B,
|
||||
int32 DHWN) {
|
||||
int32 rx[TM] = get_global_range[TM](0);
|
||||
fp32 *px[TM];
|
||||
fp32 x[TM];
|
||||
int32 c = get_range_id(0);
|
||||
fp32 g = *(G + c);
|
||||
fp32 b = *(B + c);
|
||||
|
||||
fp32 mean[TM] = 0;
|
||||
px = X + rx + c*DHWN;
|
||||
for(int32 i = 0; i < DHWN; i = i + TM){
|
||||
x = *px;
|
||||
mean = mean + x;
|
||||
px = px + TM;
|
||||
}
|
||||
fp32 m = __sum(mean);
|
||||
fp32 *pm = M + c;
|
||||
*pm = m;
|
||||
|
||||
fp32 var[TM] = 0;
|
||||
px = X + rx + c*DHWN;
|
||||
for(int32 i = 0; i < DHWN; i = i + TM){
|
||||
x = *px;
|
||||
x = x - mean;
|
||||
var = var + x*x;
|
||||
}
|
||||
fp32 v = __sum(var);
|
||||
fp32 *pv = V + c;
|
||||
*pv = v;
|
||||
})";
|
||||
}
|
||||
|
||||
/* ---------------
|
||||
* Backward
|
||||
* --------------- */
|
||||
|
||||
batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty)
|
||||
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty)
|
||||
{ }
|
||||
|
||||
void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
|
||||
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
|
||||
size_t, size_t nthreads) {
|
||||
|
||||
std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
|
||||
kernel->setArg(0, dx);
|
||||
kernel->setArg(1, dg);
|
||||
kernel->setArg(2, db);
|
||||
kernel->setArg(3, dy);
|
||||
kernel->setArg(4, x);
|
||||
kernel->setArg(5, g);
|
||||
kernel->setArg(6, m);
|
||||
kernel->setArg(7, v);
|
||||
kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void batchnorm_backward::src(std::ostream &os) {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {32, 64, 128};
|
||||
|
||||
void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
|
||||
restrict read_only fp32 *DY,
|
||||
restrict read_only fp32 *X,
|
||||
restrict read_only fp32 *G,
|
||||
restrict read_only fp32 *M,
|
||||
restrict read_only fp32 *V,
|
||||
int32 DHWN) {
|
||||
int32 rx[TM] = get_global_range[TM](0);
|
||||
int32 c = get_range_id(0);
|
||||
int32 offset = c*DHWN;
|
||||
fp32 g = *(G + c);
|
||||
fp32 mean = *(M + c);
|
||||
fp32 var = *(V + c);
|
||||
fp32 rstd = var;
|
||||
fp32* px[TM];
|
||||
fp32* pdx[TM];
|
||||
fp32* pdy[TM];
|
||||
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
fp32 dg[TM] = 0;
|
||||
fp32 db[TM] = 0;
|
||||
for(int32 i = 0; i < DHWN; i += TM){
|
||||
fp32 x[TM] = *px;
|
||||
fp32 dy[TM] = *pdy;
|
||||
dg = dg + dy*(x - mean)*rstd;
|
||||
db = db + dy;
|
||||
}
|
||||
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
pdx = DX + rx + offset;
|
||||
for(int32 i = 0; i < DHWN; i += TM){
|
||||
fp32 xhat[TM] = (x - mean) * rstd;
|
||||
fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
|
||||
fp32 dx[TM] = (dy - xtmp) * rstd * g;
|
||||
*pdx = dx;
|
||||
}
|
||||
})";
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -220,8 +220,10 @@ if(ty_ == WGRAD){
|
||||
}
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = *pa;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << ";";
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
int1 checka[)" << AS << "] = (rka < K)" << bca0 << R"(;
|
||||
int1 checkb[)" << BS << "] = (rkb < K)" << bcb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = checka ? *pa : 0;)";
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 rbwhc[TK] = rkb / ABS;
|
||||
@@ -233,11 +235,11 @@ if(ty_ == WGRAD){
|
||||
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
int32 inc[TK, TN] = interior ? shift : 0;
|
||||
)" << b_ty_ << R"(* shifted_pb[TK, TN] = pb + inc;
|
||||
)" << b_ty_ << R"( b[TK, TN] = *shifted_pb;)";
|
||||
)" << b_ty_ << R"( b[TK, TN] = checkb ? *shifted_pb : 0;)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
)" << b_ty_ << " b[" << BS << R"(] = *pb;)";
|
||||
)" << b_ty_ << " b[" << BS << R"(] = checkb ? *pb : 0;)";
|
||||
}
|
||||
os << R"(
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -320,6 +320,10 @@ value *builder::create_trans(value *A, const std::string &name) {
|
||||
return insert(trans_inst::create(A, name));
|
||||
}
|
||||
|
||||
value *builder::create_reduce(value *A, const std::string &name) {
|
||||
return insert(reduce_inst::create(A, name));
|
||||
}
|
||||
|
||||
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
|
||||
return insert(select_inst::create(pred, if_value, else_value, name));
|
||||
}
|
||||
|
@@ -566,6 +566,19 @@ instruction* trans_inst::create(value *arg, const std::string &name, instruction
|
||||
return new trans_inst(arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// reduce instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
reduce_inst::reduce_inst(value *arg, const std::string &name, instruction *next)
|
||||
: builtin_inst(arg->get_type()->get_scalar_ty(), 1, 1, name, next) {
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* reduce_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new reduce_inst(arg, name, next);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// select instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -175,11 +175,16 @@ ir::value* select_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_select(pred, if_value, else_value);
|
||||
}
|
||||
|
||||
// Trans
|
||||
// trans
|
||||
ir::value* trans_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_trans(arg_->codegen(mod));
|
||||
}
|
||||
|
||||
// reduce
|
||||
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_reduce(arg_->codegen(mod));
|
||||
}
|
||||
|
||||
/* Postfix expression */
|
||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
ir::value *in = lhs_->codegen(mod);
|
||||
|
Reference in New Issue
Block a user