[examples/python/pytorch] added skeleton of wrapper for shift-conv and batch-norm

Philippe Tillet
2019-07-09 21:54:37 -07:00
parent 63b249c1d6
commit 3b89bc8463
8 changed files with 63 additions and 151 deletions

examples/python/pytorch/CMakeLists.txt

@@ -4,7 +4,7 @@ if(${TORCH_FOUND})
     include_directories(${TORCH_INCLUDE_DIRS})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TORCH_LIBRARY_DIRS})
-    add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+    add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1)
     add_library(torch_triton SHARED conv.cpp shift.cpp batchnorm.cpp)
     target_link_libraries(torch_triton torch triton)
 endif()
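Note: `_GLIBCXX_USE_CXX11_ABI` must match the ABI that the installed `libtorch` was compiled with, otherwise the extension fails to load with undefined-symbol errors. PyTorch exposes the flag it was built with, so the right value can be queried rather than hardcoded; a minimal check:

import torch

# prints 1 if libtorch was built with the CXX11 ABI, 0 otherwise;
# this is the value to pass as -D_GLIBCXX_USE_CXX11_ABI
print(int(torch._C._GLIBCXX_USE_CXX11_ABI))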

examples/python/pytorch/batchnorm.cpp

@@ -3,7 +3,6 @@
 #include "ATen/cuda/CUDAContext.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/batchnorm.h"
-#include "triton/tools/bench.hpp"
 
 std::vector<torch::Tensor>
 batchnorm_ymv(const torch::Tensor fw_x,
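The `ymv` suffix describes the return value: the op yields the normalized output `y` together with the batch `mean` and `var`, which the Python wrapper further down saves for the backward pass. A hedged sketch of the call; the layout and shapes are assumptions, not documented API:

import torch

# assumes libtorch_triton.so has already been loaded via torch.ops.load_library
C = 32
x = torch.randn(C, 8, 8, 16, device='cuda')  # assumption: channels-first (CHWN) layout
gamma = torch.ones(C, device='cuda')
beta = torch.zeros(C, device='cuda')
y, mean, var = torch.ops.triton.batchnorm_ymv(x, gamma, beta, 1e-5)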

View File

@@ -1,142 +0,0 @@
-'''Train CIFAR10 with PyTorch.'''
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-import torch.backends.cudnn as cudnn
-
-import torchvision
-import torchvision.transforms as transforms
-
-import os
-import argparse
-
-from resnet import *
-from utils import progress_bar
-
-parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
-parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
-parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
-args = parser.parse_args()
-
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-best_acc = 0  # best test accuracy
-start_epoch = 0  # start from epoch 0 or last checkpoint epoch
-
-# Data
-print('==> Preparing data..')
-transform_train = transforms.Compose([
-    transforms.RandomCrop(32, padding=4),
-    transforms.RandomHorizontalFlip(),
-    transforms.ToTensor(),
-    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-])
-
-transform_test = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-])
-
-trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
-trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
-
-testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
-testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
-
-classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
-
-# Model
-print('==> Building model..')
-# net = VGG('VGG19')
-net = ResNet18()
-# net = PreActResNet18()
-# net = GoogLeNet()
-# net = DenseNet121()
-# net = ResNeXt29_2x64d()
-# net = MobileNet()
-# net = MobileNetV2()
-# net = DPN92()
-# net = ShuffleNetG2()
-# net = SENet18()
-#net = ShuffleNetV2(1)
-net = net.to(device)
-if device == 'cuda':
-    net = torch.nn.DataParallel(net)
-    cudnn.benchmark = True
-
-if args.resume:
-    # Load checkpoint.
-    print('==> Resuming from checkpoint..')
-    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
-    checkpoint = torch.load('./checkpoint/ckpt.t7')
-    net.load_state_dict(checkpoint['net'])
-    best_acc = checkpoint['acc']
-    start_epoch = checkpoint['epoch']
-
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
-
-# Training
-def train(epoch):
-    print('\nEpoch: %d' % epoch)
-    net.train()
-    train_loss = 0
-    correct = 0
-    total = 0
-    for batch_idx, (inputs, targets) in enumerate(trainloader):
-        inputs, targets = inputs.to(device), targets.to(device)
-        optimizer.zero_grad()
-        outputs = net(inputs)
-        loss = criterion(outputs, targets)
-        loss.backward()
-        optimizer.step()
-
-        train_loss += loss.item()
-        _, predicted = outputs.max(1)
-        total += targets.size(0)
-        correct += predicted.eq(targets).sum().item()
-
-        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
-            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
-
-def test(epoch):
-    global best_acc
-    net.eval()
-    test_loss = 0
-    correct = 0
-    total = 0
-    with torch.no_grad():
-        for batch_idx, (inputs, targets) in enumerate(testloader):
-            inputs, targets = inputs.to(device), targets.to(device)
-            outputs = net(inputs)
-            loss = criterion(outputs, targets)
-
-            test_loss += loss.item()
-            _, predicted = outputs.max(1)
-            total += targets.size(0)
-            correct += predicted.eq(targets).sum().item()
-
-            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
-                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
-
-    # Save checkpoint.
-    acc = 100.*correct/total
-    if acc > best_acc:
-        print('Saving..')
-        state = {
-            'net': net.state_dict(),
-            'acc': acc,
-            'epoch': epoch,
-        }
-        if not os.path.isdir('checkpoint'):
-            os.mkdir('checkpoint')
-        torch.save(state, './checkpoint/ckpt.t7')
-        best_acc = acc
-
-for epoch in range(start_epoch, start_epoch+200):
-    train(epoch)
-    test(epoch)

examples/python/pytorch/conv.cpp

@@ -1,12 +1,9 @@
 #include <vector>
-#include <sstream>
 #include <torch/torch.h>
 #include <torch/script.h>
 #include "ATen/cuda/CUDAContext.h"
-#include "triton/runtime/jit.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
-#include "triton/tools/bench.hpp"
 
 #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")

examples/python/pytorch/shift.cpp

@@ -1,12 +1,9 @@
 #include <vector>
-#include <sstream>
 #include <torch/torch.h>
 #include <torch/script.h>
 #include "ATen/cuda/CUDAContext.h"
-#include "triton/runtime/jit.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/shift.h"
-#include "triton/tools/bench.hpp"
 
 #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -120,3 +117,8 @@ torch::Tensor shift_dw(
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
 }
+
+static auto registry =
+  torch::jit::RegisterOperators("triton::shift_conv_y", &shift_y)
+  .op("triton::shift_conv_dx", &shift_dx)
+  .op("triton::shift_conv_dw", &shift_dw);
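The `RegisterOperators` call is what makes the kernels visible from Python under the `torch.ops.triton` namespace once the shared library is loaded; the wrapper module below relies on exactly this. A short sketch (the argument names are descriptive only, taken from the order used by the wrapper):

import torch

# path is build-tree specific; it matches the hardcoded path in the wrapper below
torch.ops.load_library("build/examples/python/pytorch/libtorch_triton.so")

# after loading, the three registered kernels resolve as:
#   torch.ops.triton.shift_conv_y(input, weight, bias, width_h, width_w, stride_h, stride_w)
#   torch.ops.triton.shift_conv_dx(dy, weight, bias, width_h, width_w, stride_h, stride_w)
#   torch.ops.triton.shift_conv_dw(dy, input, bias, width_h, width_w, stride_h, stride_w)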

View File

@@ -4,6 +4,10 @@ import math
 torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
 
+#################################
+####### Convolutions ##########
+#################################
+
 class ConvFunction(torch.autograd.Function):
     @staticmethod
@@ -81,3 +85,54 @@ class Conv2d(_ConvNd):
         super(Conv2d, self).__init__(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias)
+
+#################################
+#### Shift-Convolutions #######
+#################################
+
+class ShiftConvFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input, weight, bias, stride, width):
+        if bias is None:
+            bias = torch.empty(0)
+        ctx.save_for_backward(input, weight, bias)
+        ctx.stride = stride
+        ctx.width = width
+        output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1])
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        input, weight, bias = ctx.saved_tensors
+        stride = ctx.stride
+        width = ctx.width
+        dx = dw = dbias = None
+        if ctx.needs_input_grad[0]:
+            dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1])
+        if ctx.needs_input_grad[1]:
+            dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1])
+        if ctx.needs_input_grad[2]:
+            dbias = torch.sum(dy, (1, 2, 3))
+        return dx, dw, dbias, None, None
+
+#################################
+######### BatchNorm ###########
+#################################
+
+class BatchNormFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, gamma, beta, eps):
+        ctx.eps = eps
+        y, mean, var = torch.ops.triton.batchnorm_ymv(x, gamma, beta, eps)
+        ctx.save_for_backward(x, gamma, beta, mean, var)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        eps = ctx.eps
+        x, gamma, beta, mean, var = ctx.saved_tensors
+        dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps)
+        return dx, dg, db, None, None
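As with any `torch.autograd.Function`, these wrappers are invoked through `.apply`, which wires the custom backward into autograd. A hedged usage sketch; the channels-first CHWN shape convention is inferred from the `dbias` reduction over dims `(1, 2, 3)` above, not documented API:

import torch

C, H, W, N = 32, 16, 16, 16
x = torch.randn(C, H, W, N, device='cuda', requires_grad=True)
gamma = torch.ones(C, device='cuda', requires_grad=True)
beta = torch.zeros(C, device='cuda', requires_grad=True)

y = BatchNormFunction.apply(x, gamma, beta, 1e-5)
y.sum().backward()  # gradients come from triton.batchnorm_dxdgdb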

View File

@@ -31,7 +31,6 @@
 #include "triton/dnn/base.h"
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
-#include "triton/runtime/jit.h"
 
 namespace triton{
 namespace dnn{

include/triton/tools/bench.hpp

@@ -1,6 +1,8 @@
 #ifndef TRITON_TOOLS_BENCH_HPP
 #define TRITON_TOOLS_BENCH_HPP
+
+#include <chrono>
 
 namespace triton{
 namespace tools{
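With `bench.hpp` dropped from the extension sources, a kernel can just as well be timed from Python; a minimal sketch that accounts for asynchronous CUDA launches:

import time
import torch

def bench(fn, *args, reps=100):
    for _ in range(10):       # warm-up iterations
        fn(*args)
    torch.cuda.synchronize()  # kernels launch asynchronously on the CUDA stream
    start = time.perf_counter()
    for _ in range(reps):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / reps

# e.g. bench(torch.ops.triton.shift_conv_y, x, w, bias, 3, 3, 1, 1)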