testing a simple shiftnet

This commit is contained in:
Philippe Tillet
2019-07-10 13:33:08 -07:00
parent 3b89bc8463
commit f665c742f9
6 changed files with 261 additions and 31 deletions

View File

@@ -4,11 +4,18 @@
#include "triton/driver/stream.h"
#include "triton/dnn/batchnorm.h"
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
batchnorm_ymv(const torch::Tensor fw_x,
const torch::Tensor fw_g,
const torch::Tensor fw_b,
float eps) {
double eps) {
CHECK_INPUT(fw_x);
CHECK_INPUT(fw_g);
CHECK_INPUT(fw_b);
// Wrap CUDA handles
c10::DeviceIndex device = fw_x.storage().device().index();
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
@@ -30,8 +37,9 @@ std::vector<torch::Tensor>
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
// create template
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32", eps);
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
stream.synchronize();
return {fw_y, fw_m, fw_v};
}
@@ -41,7 +49,12 @@ std::vector<torch::Tensor>
const torch::Tensor fw_g,
const torch::Tensor fw_m,
const torch::Tensor fw_v,
float eps) {
double eps) {
CHECK_INPUT(fw_dy);
CHECK_INPUT(fw_x);
CHECK_INPUT(fw_g);
CHECK_INPUT(fw_m);
CHECK_INPUT(fw_v);
// Wrap CUDA handles
c10::DeviceIndex device = fw_x.storage().device().index();
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
@@ -68,4 +81,10 @@ std::vector<torch::Tensor>
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
stream.synchronize();
return {fw_dx, fw_dg, fw_db};
}
static auto registry =
torch::jit::RegisterOperators("triton::batchnorm_ymv", &batchnorm_ymv)
.op("triton::batchnorm_dxdgdb", &batchnorm_dxdgdb);

View File

@@ -0,0 +1,120 @@
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import triton
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
self.bn1 = triton.BatchNorm2d(32)
self.conv2 = triton.ShiftConv2d(32, 64, 3, 2)
self.bn2 = triton.BatchNorm2d(64)
self.fc1 = nn.Linear(64*7*7, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = x.permute(1, 2, 3, 0).contiguous()
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = x.permute(3, 0, 1, 2).contiguous()
x = x.view(-1, 64*7*7)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(),"mnist_cnn.pt")
main()

View File

@@ -35,8 +35,11 @@ torch::Tensor shift_common(
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
// Allocate output
std::vector<int32_t> c_shapes = shift.c_shapes();
torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda();
std::vector<int32_t> _c_shapes = shift.c_shapes();
std::vector<long int> c_shapes;
for(auto x: _c_shapes)
c_shapes.push_back(x);
torch::Tensor torchc = torch::empty(c_shapes).cuda();
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
// Enqueue
shift.enqueue(&stream, {&a, &b, &c});
@@ -47,9 +50,9 @@ torch::Tensor shift_y(
const torch::Tensor x,
const torch::Tensor w,
const torch::Tensor bias,
int32_t R, int32_t S,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w) {
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
const torch::Tensor shift_h, const torch::Tensor shift_w) {
// shapes for a
int64_t Ca = x.size(0);
int64_t H = x.size(1);
@@ -61,16 +64,18 @@ torch::Tensor shift_y(
AT_CHECK(Ca == Cb, "operands must have the same number of channels");
int64_t C = Ca;
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::FPROP, x, w, bias);
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::FPROP, x, w, bias);
}
torch::Tensor shift_dx(
const torch::Tensor dy,
const torch::Tensor w,
const torch::Tensor bias,
int32_t R, int32_t S,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w) {
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
const torch::Tensor shift_h, const torch::Tensor shift_w) {
// shapes for a
int64_t Ca = dy.size(0);
int64_t H = dy.size(1);
@@ -87,16 +92,18 @@ torch::Tensor shift_dx(
int64_t C = Ca;
std::swap(C, F);
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::BPROP, dy, w, bias);
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::BPROP, dy, w, bias);
}
torch::Tensor shift_dw(
const torch::Tensor dy,
const torch::Tensor x,
const torch::Tensor bias,
int32_t R, int32_t S,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w) {
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
const torch::Tensor shift_h, const torch::Tensor shift_w) {
// shapes for a
int64_t F = dy.size(0);
int64_t Ha = dy.size(1);
@@ -115,7 +122,9 @@ torch::Tensor shift_dw(
int64_t W = Wb;
int64_t B = Bb;
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::WGRAD, dy, x, bias);
}
static auto registry =

View File

@@ -1,6 +1,7 @@
import torch
from torch.nn.modules.utils import _single, _pair, _triple
import math
from torch.nn.modules.utils import _single, _pair, _triple
from torch.distributions import categorical
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
@@ -93,13 +94,18 @@ class Conv2d(_ConvNd):
class ShiftConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, stride, width):
def forward(ctx, input, weight, bias, stride, width, shift_h, shift_w):
if bias is None:
bias = torch.empty(0)
ctx.save_for_backward(input, weight, bias)
ctx.stride = stride
ctx.width = width
output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1])
ctx.shift_h = shift_h
ctx.shift_w = shift_w
output = torch.ops.triton.shift_conv_y(input, weight, bias,
width[0], width[1],
stride[0], stride[1],
shift_h, shift_w)
return output
@staticmethod
@@ -107,16 +113,70 @@ class ShiftConvFunction(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors
stride = ctx.stride
width = ctx.width
shift_h = ctx.shift_h
shift_w = ctx.shift_w
dx = dw = dbias = None
if ctx.needs_input_grad[0]:
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1])
dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[1]:
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1])
dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
if ctx.needs_input_grad[2]:
dbias = torch.sum(dy, (1, 2, 3))
return dx, dw, dbias, None, None
return dx, dw, dbias, None, None, None, None
class _ShiftConvNd(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias):
super(_ShiftConvNd, self).__init__()
# initialize
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.weight = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.shift_h = self.make_shift(kernel_size[0])
self.shift_w = self.make_shift(kernel_size[1])
self.reset_parameters()
def forward(self, input):
return ShiftConvFunction.apply(input, self.weight, self.bias, self.stride,
self.kernel_size, self.shift_h, self.shift_w)
def make_shift(self, kernel_size):
if kernel_size == 3:
p = torch.Tensor([0.3, 0.4, 0.3])
elif kernel_size == 5:
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
elif kernel_size == 7:
p = torch.Tensor([0.075, 0.1, 0.175, 0.3, 0.175, 0.1, 0.075])
elif kernel_size == 9:
p = torch.Tensor([0.05, 0.075, 0.1, 0.175, 0.2, 0.175, 0.1, 0.075, 0.05])
else:
raise RuntimeError('Unsupported kernel size')
return categorical.Categorical(p).sample((self.in_channels,)) - (kernel_size // 2)
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
class ShiftConv2d(_ShiftConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
super(ShiftConv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, bias)
#################################
######### BatchNorm ###########
#################################
@@ -134,5 +194,27 @@ class BatchNormFunction(torch.autograd.Function):
def backward(ctx, dy):
eps = ctx.eps
x, gamma, beta, mean, var = ctx.saved_tensors
dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps)
return dx, dg, db, None, None
dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy.contiguous(), x, gamma, mean, var, eps)
return dx, dg, db, None
class _BatchNorm(torch.nn.Module):
def __init__(self, num_features, eps=1e-5):
super(_BatchNorm, self).__init__()
self.num_features = num_features
self.eps = eps
self.weight = torch.nn.Parameter(torch.Tensor(num_features))
self.bias = torch.nn.Parameter(torch.Tensor(num_features))
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.uniform_(self.weight)
torch.nn.init.zeros_(self.bias)
def forward(self, input):
return BatchNormFunction.apply(input, self.weight, self.bias, self.eps)
class BatchNorm2d(_BatchNorm):
pass

View File

@@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var):
def run_batchnorm():
C, H, W, B = 1, 4, 4, 4
C, H, W, B = 32, 14, 14, 64
np.random.seed(0)
# Placeholders
x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -117,7 +117,8 @@ def run_batchnorm():
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
#print(result[0], result[1], result[2])
print(result[1])
print(np.mean(hx, (1, 2, 3)))
grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B),
extra_feed_dict = {x: hx, g: hg, b: hb})
dx_t, dx_n = grads[0]

View File

@@ -59,7 +59,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
{
driver::buffer *y = args[0], *m = args[1], *v = args[2];
driver::buffer *x = args[3], *g = args[4], *b = args[5];
std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
kernel->setArg(0, y);
kernel->setArg(1, m);
kernel->setArg(2, v);
@@ -86,7 +86,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
int32 rx[TM] = 0 ... TM;
fp32 *px[TM];
fp32 x[TM];
int32 c = get_range_id(0);
int32 c = get_range_id(1);
fp32 g = *(G + c);
fp32 b = *(B + c);
@@ -112,7 +112,6 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
fp32 v = __sum(var) * rcpDHWN;
fp32 *pv = V + c;
*pv = v;
fp32 rstdg = 1 / sqrt(v + eps) * g;
px = X + rx + c*DHWN;
@@ -186,7 +185,7 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
restrict read_only fp32 *V,
int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
int32 rx[TM] = 0 ... TM;
int32 c = get_range_id(0);
int32 c = get_range_id(1);
int32 offset = c*DHWN;
fp32 g = *(G + c);
fp32 mean = *(M + c);