[examples/pytorch] Fixed issues in backward pass of conv

This commit is contained in:
Philippe Tillet
2019-05-19 01:31:08 -04:00
parent b2b55c52c9
commit f33a1f3fe3
9 changed files with 541 additions and 71 deletions

View File

@@ -0,0 +1,117 @@
import argparse
import triton
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.manual_seed(0)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, (5,5))
self.conv2 = nn.Conv2d(20, 50, (5,5))
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(),"mnist_cnn.pt")
if __name__ == '__main__':
main()

View File

@@ -86,7 +86,8 @@ torch::Tensor conv_common(
torch::Tensor conv_fprop(
const torch::Tensor data,
const torch::Tensor weight) {
const torch::Tensor weight,
int64_t pad_h, int64_t pad_w) {
// Check
CHECK_INPUT(data);
CHECK_INPUT(weight);
@@ -104,7 +105,7 @@ torch::Tensor conv_fprop(
const int32_t NF = weight.size(3);
// Configuration
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t pad_d = 0;
// Check
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
@@ -112,7 +113,8 @@ torch::Tensor conv_fprop(
torch::Tensor conv_bprop(
const torch::Tensor derror,
const torch::Tensor weight){
const torch::Tensor weight,
int64_t pad_h, int64_t pad_w){
// Check
CHECK_INPUT(derror);
CHECK_INPUT(weight);
@@ -131,10 +133,12 @@ torch::Tensor conv_bprop(
// Compute M, P, Q
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d;
const int32_t H = P*stride_d + R - 1 - 2*pad_h + stride_h - 1 / upsample_h;
const int32_t W = Q*stride_d + S - 1 - 2*pad_w + stride_w - 1 / upsample_w;
int32_t pad_d = 0;
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
// Check
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
@@ -142,17 +146,18 @@ torch::Tensor conv_bprop(
torch::Tensor conv_wgrad(
const torch::Tensor data,
const torch::Tensor derror
const torch::Tensor derror,
int64_t pad_h, int64_t pad_w
){
// Check
CHECK_INPUT(data);
CHECK_INPUT(derror);
// Unpack data shapes
const int32_t Ba = derror.size(0);
const int32_t C = derror.size(1);
const int32_t Ba = data.size(0);
const int32_t C = data.size(1);
const int32_t D = 1;
const int32_t H = derror.size(2);
const int32_t W = derror.size(3);
const int32_t H = data.size(2);
const int32_t W = data.size(3);
// Unpack error shapes
const int32_t Bb = derror.size(0);
const int32_t K = derror.size(1);
@@ -162,10 +167,12 @@ torch::Tensor conv_wgrad(
// Compute M, P, Q
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
const int32_t pad_d = 0;
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
// Check
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);

View File

@@ -1,50 +0,0 @@
import torch
import time
torch.manual_seed(0)
class TritonConv(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
output = torch.ops.triton.conv_fprop(input, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_weight = None
if ctx.needs_input_grad[0]:
grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
if ctx.needs_input_grad[1]:
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
return grad_input, grad_weight
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
y_target = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
def run(x, w, conv):
y = conv(x, w)
loss = (y - y_target).norm(2)
loss.backward()
return loss, y.clone(), x.grad.clone(), w.grad.clone()
ttyloss, tty, ttdx, ttdw = run(x, w, TritonConv.apply)
x.grad.zero_()
w.grad.zero_()
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
#print(ttdx)
#print(cudx)
#print(ttdw)
#print(cudw)
#print((ttdw.permute(3,0,1,2) - cudw).norm(2))

View File

@@ -0,0 +1,22 @@
import torch
import triton
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
def run(x, w, conv):
y = conv(x, w)
loss = (y - y_target).norm(2)
loss.backward()
return loss, y.clone(), x.grad.clone(), w.grad.clone()
ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))
x.grad.zero_()
w.grad.zero_()
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))
print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3,0,1,2) - cudw).norm(2))

View File

@@ -0,0 +1,46 @@
import torch
import math
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
class ConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, padding):
ctx.save_for_backward(input, weight)
ctx.padding = padding
output = torch.ops.triton.conv_fprop(input, weight, padding, padding)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
padding = ctx.padding
grad_input = grad_weight = None
if ctx.needs_input_grad[0]:
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding)
if ctx.needs_input_grad[1]:
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding)
return grad_input, grad_weight, None
class Conv2d(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
super(Conv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = torch.nn.Parameter(torch.Tensor(
in_channels, kernel_size[0], kernel_size[1], out_channels))
self.reset_parameters()
def forward(self, input):
return ConvFunction.apply(input, self.weight, self.padding)
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)