[examples/pytorch] Fixed issues in backward pass of conv
This commit is contained in:
117
examples/python/pytorch/bench.py
Normal file
117
examples/python/pytorch/bench.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import argparse
|
||||
import triton
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 20, (5,5))
|
||||
self.conv2 = nn.Conv2d(20, 50, (5,5))
|
||||
self.fc1 = nn.Linear(4*4*50, 500)
|
||||
self.fc2 = nn.Linear(500, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.max_pool2d(x, 2, 2)
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.max_pool2d(x, 2, 2)
|
||||
x = x.view(-1, 4*4*50)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
def train(args, model, device, train_loader, optimizer, epoch):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
|
||||
def test(args, model, device, test_loader):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for training (default: 64)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
|
||||
help='input batch size for testing (default: 1000)')
|
||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
|
||||
help='SGD momentum (default: 0.5)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
||||
help='random seed (default: 1)')
|
||||
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
|
||||
help='how many batches to wait before logging training status')
|
||||
|
||||
parser.add_argument('--save-model', action='store_true', default=False,
|
||||
help='For Saving the current Model')
|
||||
args = parser.parse_args()
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
||||
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=True, download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=args.batch_size, shuffle=True, **kwargs)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=False, transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=args.test_batch_size, shuffle=True, **kwargs)
|
||||
|
||||
|
||||
model = Net().to(device)
|
||||
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args, model, device, train_loader, optimizer, epoch)
|
||||
test(args, model, device, test_loader)
|
||||
|
||||
if (args.save_model):
|
||||
torch.save(model.state_dict(),"mnist_cnn.pt")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -86,7 +86,8 @@ torch::Tensor conv_common(
|
||||
|
||||
torch::Tensor conv_fprop(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor weight) {
|
||||
const torch::Tensor weight,
|
||||
int64_t pad_h, int64_t pad_w) {
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -104,7 +105,7 @@ torch::Tensor conv_fprop(
|
||||
const int32_t NF = weight.size(3);
|
||||
// Configuration
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t pad_d = 0;
|
||||
// Check
|
||||
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
|
||||
@@ -112,7 +113,8 @@ torch::Tensor conv_fprop(
|
||||
|
||||
torch::Tensor conv_bprop(
|
||||
const torch::Tensor derror,
|
||||
const torch::Tensor weight){
|
||||
const torch::Tensor weight,
|
||||
int64_t pad_h, int64_t pad_w){
|
||||
// Check
|
||||
CHECK_INPUT(derror);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -131,10 +133,12 @@ torch::Tensor conv_bprop(
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d;
|
||||
const int32_t H = P*stride_d + R - 1 - 2*pad_h + stride_h - 1 / upsample_h;
|
||||
const int32_t W = Q*stride_d + S - 1 - 2*pad_w + stride_w - 1 / upsample_w;
|
||||
int32_t pad_d = 0;
|
||||
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
|
||||
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
|
||||
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
|
||||
|
||||
|
||||
// Check
|
||||
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
|
||||
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
|
||||
@@ -142,17 +146,18 @@ torch::Tensor conv_bprop(
|
||||
|
||||
torch::Tensor conv_wgrad(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor derror
|
||||
const torch::Tensor derror,
|
||||
int64_t pad_h, int64_t pad_w
|
||||
){
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(derror);
|
||||
// Unpack data shapes
|
||||
const int32_t Ba = derror.size(0);
|
||||
const int32_t C = derror.size(1);
|
||||
const int32_t Ba = data.size(0);
|
||||
const int32_t C = data.size(1);
|
||||
const int32_t D = 1;
|
||||
const int32_t H = derror.size(2);
|
||||
const int32_t W = derror.size(3);
|
||||
const int32_t H = data.size(2);
|
||||
const int32_t W = data.size(3);
|
||||
// Unpack error shapes
|
||||
const int32_t Bb = derror.size(0);
|
||||
const int32_t K = derror.size(1);
|
||||
@@ -162,10 +167,12 @@ torch::Tensor conv_wgrad(
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
|
||||
const int32_t pad_d = 0;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
|
||||
|
||||
|
||||
// Check
|
||||
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
|
||||
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
|
||||
|
@@ -1,50 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
torch.manual_seed(0)
|
||||
|
||||
class TritonConv(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
ctx.save_for_backward(input, weight)
|
||||
output = torch.ops.triton.conv_fprop(input, weight)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input = grad_weight = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
|
||||
return grad_input, grad_weight
|
||||
|
||||
|
||||
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
||||
|
||||
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
|
||||
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
|
||||
y_target = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||
|
||||
def run(x, w, conv):
|
||||
y = conv(x, w)
|
||||
loss = (y - y_target).norm(2)
|
||||
loss.backward()
|
||||
return loss, y.clone(), x.grad.clone(), w.grad.clone()
|
||||
|
||||
ttyloss, tty, ttdx, ttdw = run(x, w, TritonConv.apply)
|
||||
x.grad.zero_()
|
||||
w.grad.zero_()
|
||||
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
|
||||
|
||||
|
||||
print((tty - cuy).norm(2))
|
||||
print((ttdx - cudx).norm(2))
|
||||
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
||||
#print(ttdx)
|
||||
#print(cudx)
|
||||
#print(ttdw)
|
||||
#print(cudw)
|
||||
#print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
22
examples/python/pytorch/test.py
Normal file
22
examples/python/pytorch/test.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
|
||||
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
|
||||
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
|
||||
|
||||
def run(x, w, conv):
|
||||
y = conv(x, w)
|
||||
loss = (y - y_target).norm(2)
|
||||
loss.backward()
|
||||
return loss, y.clone(), x.grad.clone(), w.grad.clone()
|
||||
|
||||
ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))
|
||||
x.grad.zero_()
|
||||
w.grad.zero_()
|
||||
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))
|
||||
|
||||
print((tty - cuy).norm(2))
|
||||
print((ttdx - cudx).norm(2))
|
||||
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
46
examples/python/pytorch/triton.py
Normal file
46
examples/python/pytorch/triton.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
||||
|
||||
class ConvFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, padding):
|
||||
ctx.save_for_backward(input, weight)
|
||||
ctx.padding = padding
|
||||
output = torch.ops.triton.conv_fprop(input, weight, padding, padding)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
padding = ctx.padding
|
||||
grad_input = grad_weight = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding)
|
||||
return grad_input, grad_weight, None
|
||||
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
|
||||
super(Conv2d, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = padding
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(
|
||||
in_channels, kernel_size[0], kernel_size[1], out_channels))
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, input):
|
||||
return ConvFunction.apply(input, self.weight, self.padding)
|
||||
|
||||
def reset_parameters(self):
|
||||
n = self.in_channels
|
||||
for k in self.kernel_size:
|
||||
n *= k
|
||||
stdv = 1. / math.sqrt(n)
|
||||
self.weight.data.uniform_(-stdv, stdv)
|
Reference in New Issue
Block a user