[dnn/conv] Added bias and forward stride

This commit is contained in:
Philippe Tillet
2019-05-20 12:20:29 -04:00
parent f33a1f3fe3
commit e8f23bcade
8 changed files with 303 additions and 199 deletions

View File

@@ -1,117 +1,142 @@
import argparse
import triton
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
torch.manual_seed(0)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, (5,5))
self.conv2 = nn.Conv2d(20, 50, (5,5))
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
import torchvision
import torchvision.transforms as transforms
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
import os
import argparse
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
from resnet import *
from utils import progress_bar
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Model
print('==> Building model..')
# net = VGG('VGG19')
net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()
#net = ShuffleNetV2(1)
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True
if args.resume:
# Load checkpoint.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.t7')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
def test(epoch):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss /= len(test_loader.dataset)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
# Save checkpoint.
acc = 100.*correct/total
if acc > best_acc:
print('Saving..')
state = {
'net': net.state_dict(),
'acc': acc,
'epoch': epoch,
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/ckpt.t7')
best_acc = acc
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(),"mnist_cnn.pt")
if __name__ == '__main__':
main()
for epoch in range(start_epoch, start_epoch+200):
train(epoch)
test(epoch)

View File

@@ -14,7 +14,7 @@ typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t,
triton::dnn::conv::type> conv_key_t;
triton::dnn::conv::type, bool> conv_key_t;
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
@@ -26,7 +26,7 @@ torch::Tensor conv_common(
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
triton::dnn::conv::type ty,
torch::Tensor torcha, torch::Tensor torchb
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias
) {
// Wrap CUDA handles
c10::DeviceIndex device = torcha.storage().device().index();
@@ -40,13 +40,16 @@ torch::Tensor conv_common(
// Get context
triton::driver::context* ctx = stream->context();
// Get configuration
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
bool has_bias = torchbias.storage().size() > 0;
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
triton::dnn::conv* configuration;
if(m_config.find(key) == m_config.end())
configuration = m_config.emplace(key, new triton::dnn::conv(
B, C, D, H, W, T, R, S, NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w, ty)).first->second.get();
pad_d, pad_h, pad_w,
1, 1, 1,
ty, has_bias)).first->second.get();
else
configuration = m_config.at(key).get();
// Get JIT
@@ -55,12 +58,16 @@ torch::Tensor conv_common(
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::string src = configuration->src();
jit->add_module("conv", src.c_str(), configuration->default_params());
triton::driver::kernel* kernel = jit->get_function("conv");
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Get memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
// Allocate output
std::vector<int32_t> c_shapes = configuration->c_shapes();
torch::Tensor torchc;
@@ -76,10 +83,9 @@ torch::Tensor conv_common(
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
// launch info
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
configuration->set_arg(kernel, &a, &b, &c);
configuration->set_arg(kernel, &a, &b, &c, bias);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
return torchc;
}
@@ -87,6 +93,8 @@ torch::Tensor conv_common(
torch::Tensor conv_fprop(
const torch::Tensor data,
const torch::Tensor weight,
const torch::Tensor bias,
int64_t stride_h, int64_t stride_w,
int64_t pad_h, int64_t pad_w) {
// Check
CHECK_INPUT(data);
@@ -104,16 +112,19 @@ torch::Tensor conv_fprop(
const int32_t S = weight.size(2);
const int32_t NF = weight.size(3);
// Configuration
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t stride_d = 1;
const int32_t pad_d = 0;
// Check
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight, bias);
}
torch::Tensor conv_bprop(
const torch::Tensor derror,
const torch::Tensor weight,
const torch::Tensor bias,
int64_t H, int64_t W,
int64_t stride_h, int64_t stride_w,
int64_t pad_h, int64_t pad_w){
// Check
CHECK_INPUT(derror);
@@ -131,22 +142,20 @@ torch::Tensor conv_bprop(
const int32_t S = weight.size(2);
const int32_t Kw = weight.size(3);
// Compute M, P, Q
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t stride_d = 1;
int32_t pad_d = 0;
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
int32_t D = 1;
// Check
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight, bias);
}
torch::Tensor conv_wgrad(
const torch::Tensor data,
const torch::Tensor derror,
const torch::Tensor bias,
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
int64_t pad_h, int64_t pad_w
){
// Check
@@ -166,16 +175,12 @@ torch::Tensor conv_wgrad(
const int32_t Q = derror.size(3);
// Compute M, P, Q
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t stride_d = 1;
const int32_t pad_d = 0;
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
const int32_t T = 1;
// Check
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror, bias);
}
static auto registry =

View File

@@ -2,6 +2,7 @@ import torch
import triton
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
bias = torch.autograd.Variable(torch.randn(64).cuda(), requires_grad=True)
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
@@ -10,13 +11,16 @@ def run(x, w, conv):
y = conv(x, w)
loss = (y - y_target).norm(2)
loss.backward()
return loss, y.clone(), x.grad.clone(), w.grad.clone()
return loss, y.clone(), x.grad.clone(), w.grad.clone(), bias.grad.clone()
ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))
ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (0,0)))
x.grad.zero_()
w.grad.zero_()
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))
bias.grad.zero_()
culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=0))
print(ttdx[0,0,:,:], cudx[0,0,:,:])
print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
print((ttbias - cubias).norm(2))

View File

@@ -1,4 +1,5 @@
import torch
from torch.nn.modules.utils import _single, _pair, _triple
import math
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
@@ -6,37 +7,56 @@ torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/
class ConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, padding):
ctx.save_for_backward(input, weight)
def forward(ctx, input, weight, bias, stride, padding):
if bias is None:
bias = torch.empty(0)
ctx.save_for_backward(input, weight, bias)
ctx.stride = stride
ctx.padding = padding
output = torch.ops.triton.conv_fprop(input, weight, padding, padding)
output = torch.ops.triton.conv_fprop(input, weight, bias, stride[0], stride[1], padding[0], padding[1])
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input, weight, bias = ctx.saved_tensors
stride = ctx.stride
padding = ctx.padding
grad_input = grad_weight = None
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding)
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, bias, input.shape[2], input.shape[3], stride[0], stride[1], padding[0], padding[1])
if ctx.needs_input_grad[1]:
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding)
return grad_input, grad_weight, None
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, bias, weight.shape[1], weight.shape[2], stride[0], stride[1], padding[0], padding[1])
if ctx.needs_input_grad[2]:
grad_bias = torch.sum(grad_output, (0, 2, 3))
return grad_input, grad_weight, grad_bias, None, None
class Conv2d(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
super(Conv2d, self).__init__()
class _ConvNd(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding, groups, bias):
super(_ConvNd, self).__init__()
# not everything is supported by Triton
assert all(x==1 for x in stride)
assert all(x==1 for x in dilation)
assert transposed == False
assert all(x==0 for x in output_padding)
assert groups == 1
# initialize
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = torch.nn.Parameter(torch.Tensor(
in_channels, kernel_size[0], kernel_size[1], out_channels))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def forward(self, input):
return ConvFunction.apply(input, self.weight, self.padding)
return ConvFunction.apply(input, self.weight, self.bias, self.padding)
def reset_parameters(self):
n = self.in_channels
@@ -44,3 +64,19 @@ class Conv2d(torch.nn.Module):
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
class Conv2d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super(Conv2d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _pair(0), groups, bias)