testing a simple shiftnet
@@ -4,11 +4,18 @@
 #include "triton/driver/stream.h"
 #include "triton/dnn/batchnorm.h"
 
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
 std::vector<torch::Tensor>
 batchnorm_ymv(const torch::Tensor fw_x,
               const torch::Tensor fw_g,
               const torch::Tensor fw_b,
-              float eps) {
+              double eps) {
+  CHECK_INPUT(fw_x);
+  CHECK_INPUT(fw_g);
+  CHECK_INPUT(fw_b);
   // Wrap CUDA handles
   c10::DeviceIndex device = fw_x.storage().device().index();
   CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
@@ -30,8 +37,9 @@ std::vector<torch::Tensor>
   triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
   triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
   // create template
-  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32", eps);
+  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
   batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
+  stream.synchronize();
   return {fw_y, fw_m, fw_v};
 }
 
@@ -41,7 +49,12 @@ std::vector<torch::Tensor>
                  const torch::Tensor fw_g,
                  const torch::Tensor fw_m,
                  const torch::Tensor fw_v,
-                 float eps) {
+                 double eps) {
+  CHECK_INPUT(fw_dy);
+  CHECK_INPUT(fw_x);
+  CHECK_INPUT(fw_g);
+  CHECK_INPUT(fw_m);
+  CHECK_INPUT(fw_v);
   // Wrap CUDA handles
   c10::DeviceIndex device = fw_x.storage().device().index();
   CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
@@ -68,4 +81,10 @@ std::vector<torch::Tensor>
   // create config
   triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
   batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
+  stream.synchronize();
+  return {fw_dx, fw_dg, fw_db};
 }
+
+static auto registry =
+  torch::jit::RegisterOperators("triton::batchnorm_ymv", &batchnorm_ymv)
+      .op("triton::batchnorm_dxdgdb", &batchnorm_dxdgdb);
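
The registration above makes both kernels visible through the TorchScript custom-op registry. A minimal sketch of how they become callable from Python once the shared library is loaded (the path and the tensor shapes here are illustrative; the CHWB layout and per-channel gamma/beta follow the tests later in this commit):

    import torch

    # illustrative path; the commit hard-codes its own build directory in triton.py below
    torch.ops.load_library("build/examples/python/pytorch/libtorch_triton.so")

    C, H, W, B = 32, 14, 14, 64
    x = torch.randn(C, H, W, B, device="cuda")
    gamma = torch.ones(C, device="cuda")
    beta = torch.zeros(C, device="cuda")

    y, mean, var = torch.ops.triton.batchnorm_ymv(x, gamma, beta, 1e-5)
    dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(torch.ones_like(y), x, gamma, mean, var, 1e-5)
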
examples/python/pytorch/run.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import triton
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
+        self.bn1 = triton.BatchNorm2d(32)
+        self.conv2 = triton.ShiftConv2d(32, 64, 3, 2)
+        self.bn2 = triton.BatchNorm2d(64)
+        self.fc1 = nn.Linear(64*7*7, 500)
+        self.fc2 = nn.Linear(500, 10)
+
+    def forward(self, x):
+        x = x.permute(1, 2, 3, 0).contiguous()
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = x.permute(3, 0, 1, 2).contiguous()
+        x = x.view(-1, 64*7*7)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+
+def test(args, model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                        help='learning rate (default: 0.01)')
+    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
+                        help='SGD momentum (default: 0.5)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
+                        help='how many batches to wait before logging training status')
+
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('../data', train=True, download=True,
+                       transform=transforms.Compose([
+                           transforms.ToTensor(),
+                           transforms.Normalize((0.1307,), (0.3081,))
+                       ])),
+        batch_size=args.batch_size, shuffle=True, **kwargs)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('../data', train=False, transform=transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.1307,), (0.3081,))
+        ])),
+        batch_size=args.test_batch_size, shuffle=True, **kwargs)
+
+
+    model = Net().to(device)
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
+
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(args, model, device, test_loader)
+
+    if (args.save_model):
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+
+main()
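
The model above feeds the Triton ops in channel-major (C, H, W, B) layout, which is why forward permutes the MNIST batch before the shift convolutions and permutes back before the fully-connected layers. A small standalone check of that round trip (shapes only, no Triton dependency):

    import torch

    n, c, h, w = 64, 1, 28, 28                    # one MNIST batch as produced by the DataLoader
    x = torch.randn(n, c, h, w)

    chwb = x.permute(1, 2, 3, 0).contiguous()     # NCHW -> CHWB, the layout the shift/batchnorm ops expect
    assert chwb.shape == (c, h, w, n)

    nchw = chwb.permute(3, 0, 1, 2).contiguous()  # back to NCHW before view(-1, 64*7*7)
    assert torch.equal(nchw, x)
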
@@ -35,8 +35,11 @@ torch::Tensor shift_common(
   triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
   triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
   // Allocate output
-  std::vector<int32_t> c_shapes = shift.c_shapes();
-  torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda();
+  std::vector<int32_t> _c_shapes = shift.c_shapes();
+  std::vector<long int> c_shapes;
+  for(auto x: _c_shapes)
+    c_shapes.push_back(x);
+  torch::Tensor torchc = torch::empty(c_shapes).cuda();
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   // Enqueue
   shift.enqueue(&stream, {&a, &b, &c});
@@ -47,9 +50,9 @@ torch::Tensor shift_y(
     const torch::Tensor x,
     const torch::Tensor w,
     const torch::Tensor bias,
-    int32_t R, int32_t S,
-    int32_t stride_h, int32_t stride_w,
-    int32_t* shift_h, int32_t* shift_w) {
+    int64_t R, int64_t S,
+    int64_t stride_h, int64_t stride_w,
+    const torch::Tensor shift_h, const torch::Tensor shift_w) {
   // shapes for a
   int64_t Ca = x.size(0);
   int64_t H = x.size(1);
@@ -61,16 +64,18 @@ torch::Tensor shift_y(
   AT_CHECK(Ca == Cb, "operands must have the same number of channels");
   int64_t C = Ca;
   // run
-  shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::FPROP, x, w, bias);
+  return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
+                      (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
+                      triton::dnn::shift::FPROP, x, w, bias);
 }
 
 torch::Tensor shift_dx(
     const torch::Tensor dy,
     const torch::Tensor w,
     const torch::Tensor bias,
-    int32_t R, int32_t S,
-    int32_t stride_h, int32_t stride_w,
-    int32_t* shift_h, int32_t* shift_w) {
+    int64_t R, int64_t S,
+    int64_t stride_h, int64_t stride_w,
+    const torch::Tensor shift_h, const torch::Tensor shift_w) {
   // shapes for a
   int64_t Ca = dy.size(0);
   int64_t H = dy.size(1);
@@ -87,16 +92,18 @@ torch::Tensor shift_dx(
   int64_t C = Ca;
   std::swap(C, F);
   // run
-  shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::BPROP, dy, w, bias);
+  return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
+                      (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
+                      triton::dnn::shift::BPROP, dy, w, bias);
 }
 
 torch::Tensor shift_dw(
     const torch::Tensor dy,
     const torch::Tensor x,
     const torch::Tensor bias,
-    int32_t R, int32_t S,
-    int32_t stride_h, int32_t stride_w,
-    int32_t* shift_h, int32_t* shift_w) {
+    int64_t R, int64_t S,
+    int64_t stride_h, int64_t stride_w,
+    const torch::Tensor shift_h, const torch::Tensor shift_w) {
   // shapes for a
   int64_t F = dy.size(0);
   int64_t Ha = dy.size(1);
@@ -115,7 +122,9 @@ torch::Tensor shift_dw(
   int64_t W = Wb;
   int64_t B = Bb;
   // run
-  shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
+  return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
+                      (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
+                      triton::dnn::shift::WGRAD, dy, x, bias);
 }
 
 static auto registry =
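
shift_y, shift_dx and shift_dw now take shift_h and shift_w as tensors and reinterpret their storage as int32_t*, so the Python caller has to hand over per-channel offsets in a matching dtype. A hedged sketch of what such inputs could look like (the names and the torch.int32 choice are assumptions drawn from the casts above, not code from this commit):

    import torch

    in_channels, kernel_size = 32, 3
    half = kernel_size // 2
    # one horizontal and one vertical offset per input channel, each in [-half, half]
    shift_h = torch.randint(-half, half + 1, (in_channels,), dtype=torch.int32)
    shift_w = torch.randint(-half, half + 1, (in_channels,), dtype=torch.int32)
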
@@ -1,6 +1,7 @@
 import torch
-from torch.nn.modules.utils import _single, _pair, _triple
 import math
+from torch.nn.modules.utils import _single, _pair, _triple
+from torch.distributions import categorical
 
 torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")
 
@@ -93,13 +94,18 @@ class Conv2d(_ConvNd):
 class ShiftConvFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input, weight, bias, stride, width):
+    def forward(ctx, input, weight, bias, stride, width, shift_h, shift_w):
         if bias is None:
             bias = torch.empty(0)
         ctx.save_for_backward(input, weight, bias)
         ctx.stride = stride
         ctx.width = width
-        output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1])
+        ctx.shift_h = shift_h
+        ctx.shift_w = shift_w
+        output = torch.ops.triton.shift_conv_y(input, weight, bias,
+                                               width[0], width[1],
+                                               stride[0], stride[1],
+                                               shift_h, shift_w)
         return output
 
     @staticmethod
@@ -107,16 +113,70 @@ class ShiftConvFunction(torch.autograd.Function):
         input, weight, bias = ctx.saved_tensors
         stride = ctx.stride
         width = ctx.width
+        shift_h = ctx.shift_h
+        shift_w = ctx.shift_w
         dx = dw = dbias = None
         if ctx.needs_input_grad[0]:
-            dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1])
+            dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
         if ctx.needs_input_grad[1]:
-            dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1])
+            dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
         if ctx.needs_input_grad[2]:
             dbias = torch.sum(dy, (1, 2, 3))
-        return dx, dw, dbias, None, None
+        return dx, dw, dbias, None, None, None, None
 
 
+class _ShiftConvNd(torch.nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias):
+        super(_ShiftConvNd, self).__init__()
+        # initialize
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.weight = torch.nn.Parameter(torch.Tensor(in_channels, out_channels))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.shift_h = self.make_shift(kernel_size[0])
+        self.shift_w = self.make_shift(kernel_size[1])
+        self.reset_parameters()
+
+    def forward(self, input):
+        return ShiftConvFunction.apply(input, self.weight, self.bias, self.stride,
+                                       self.kernel_size, self.shift_h, self.shift_w)
+
+    def make_shift(self, kernel_size):
+        if kernel_size == 3:
+            p = torch.Tensor([0.3, 0.4, 0.3])
+        elif kernel_size == 5:
+            p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
+        elif kernel_size == 7:
+            p = torch.Tensor([0.075, 0.1, 0.175, 0.3, 0.175, 0.1, 0.075])
+        elif kernel_size == 9:
+            p = torch.Tensor([0.05, 0.075, 0.1, 0.175, 0.2, 0.175, 0.1, 0.075, 0.05])
+        else:
+            raise RuntimeError('Unsupported kernel size')
+        return categorical.Categorical(p).sample((self.in_channels,)) - (kernel_size // 2)
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+
+class ShiftConv2d(_ShiftConvNd):
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        super(ShiftConv2d, self).__init__(
+            in_channels, out_channels, kernel_size, stride, bias)
+
 #################################
 ######### BatchNorm ###########
 #################################
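
make_shift above draws one offset per input channel from a fixed categorical distribution centred on zero, then recentres the sampled class index so that, for example, a 3-wide window yields shifts in {-1, 0, +1}. A standalone sketch of that sampling (the probabilities are copied from the kernel_size == 3 branch; the rest of the setup is illustrative):

    import torch
    from torch.distributions import categorical

    in_channels, kernel_size = 32, 3
    p = torch.Tensor([0.3, 0.4, 0.3])          # biased towards a zero shift
    shifts = categorical.Categorical(p).sample((in_channels,)) - (kernel_size // 2)
    assert int(shifts.min()) >= -1 and int(shifts.max()) <= 1
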
@@ -134,5 +194,27 @@ class BatchNormFunction(torch.autograd.Function):
     def backward(ctx, dy):
         eps = ctx.eps
         x, gamma, beta, mean, var = ctx.saved_tensors
-        dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps)
-        return dx, dg, db, None, None
+        dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy.contiguous(), x, gamma, mean, var, eps)
+        return dx, dg, db, None
+
+
+class _BatchNorm(torch.nn.Module):
+
+    def __init__(self, num_features, eps=1e-5):
+        super(_BatchNorm, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.Tensor(num_features))
+        self.bias = torch.nn.Parameter(torch.Tensor(num_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.uniform_(self.weight)
+        torch.nn.init.zeros_(self.bias)
+
+    def forward(self, input):
+        return BatchNormFunction.apply(input, self.weight, self.bias, self.eps)
+
+
+class BatchNorm2d(_BatchNorm):
+
+    pass
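
The backward return above drops one None because torch.autograd.Function must return exactly one gradient per argument of forward, and BatchNormFunction.forward takes four arguments (input, weight, bias, eps). A toy function with the same four-argument shape, shown only to illustrate that contract (not code from this commit):

    import torch

    class ScaleShift(torch.autograd.Function):
        # same four-argument shape as BatchNormFunction: (input, weight, bias, eps)
        @staticmethod
        def forward(ctx, x, w, b, eps):
            ctx.save_for_backward(x, w)
            return x * w + b

        @staticmethod
        def backward(ctx, dy):
            x, w = ctx.saved_tensors
            # one slot per forward argument: dx, dw, db, and None for the non-tensor eps
            return dy * w, (dy * x).sum(), dy.sum(), None
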
@@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var):
 
 
 def run_batchnorm():
-    C, H, W, B = 1, 4, 4, 4
+    C, H, W, B = 32, 14, 14, 64
     np.random.seed(0)
     # Placeholders
     x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -117,7 +117,8 @@ def run_batchnorm():
     sess = tf.InteractiveSession()
     sess.run(tf.global_variables_initializer())
     result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
-    #print(result[0], result[1], result[2])
+    print(result[1])
+    print(np.mean(hx, (1, 2, 3)))
     grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B),
                                      extra_feed_dict = {x: hx, g: hg, b: hb})
     dx_t, dx_n = grads[0]
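
The two prints added above eyeball the per-channel mean returned by the op (result[1]) against NumPy. The same statistics can be reproduced offline to make the expected axis convention explicit; a small self-contained sketch with shapes mirroring the enlarged test case above:

    import numpy as np

    np.random.seed(0)
    C, H, W, B = 32, 14, 14, 64
    hx = np.random.rand(C, H, W, B).astype(np.float32)

    channel_mean = np.mean(hx, (1, 2, 3))       # what result[1] should approximate
    channel_var = np.var(hx, (1, 2, 3))         # biased variance, one value per channel
    assert channel_mean.shape == channel_var.shape == (C,)
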
@@ -59,7 +59,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 {
   driver::buffer *y = args[0], *m = args[1], *v = args[2];
   driver::buffer *x = args[3], *g = args[4], *b = args[5];
-  std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
+  std::array<size_t, 3> grid = {1, (size_t)C_, 1};
   kernel->setArg(0, y);
   kernel->setArg(1, m);
   kernel->setArg(2, v);
@@ -86,7 +86,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
   int32 rx[TM] = 0 ... TM;
   fp32 *px[TM];
   fp32 x[TM];
-  int32 c = get_range_id(0);
+  int32 c = get_range_id(1);
   fp32 g = *(G + c);
   fp32 b = *(B + c);
 
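
Together with the enqueue_impl change above, this hunk moves the per-channel program index from grid axis 0 to axis 1: the launch grid becomes {1, C_, 1} and the kernel reads the channel with get_range_id(1). Functionally the launch is still one reduction per channel, which can be modelled in plain PyTorch as below (an illustration of the mapping, not the kernel itself):

    import torch

    def batchnorm_forward_model(x, g, b, eps=1e-5):
        # x laid out as (C, H, W, B); one "program" per channel, as the grid {1, C, 1} encodes
        C = x.shape[0]
        y = torch.empty_like(x)
        m = torch.empty(C)
        v = torch.empty(C)
        for c in range(C):                      # c plays the role of get_range_id(1)
            xc = x[c]
            m[c] = xc.mean()
            v[c] = ((xc - m[c])**2).mean()
            y[c] = (xc - m[c]) / torch.sqrt(v[c] + eps) * g[c] + b[c]
        return y, m, v
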
@@ -112,7 +112,6 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
   fp32 v = __sum(var) * rcpDHWN;
   fp32 *pv = V + c;
   *pv = v;
-
   fp32 rstdg = 1 / sqrt(v + eps) * g;
 
   px = X + rx + c*DHWN;
@@ -186,7 +185,7 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
                restrict read_only fp32 *V,
                int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
   int32 rx[TM] = 0 ... TM;
-  int32 c = get_range_id(0);
+  int32 c = get_range_id(1);
   int32 offset = c*DHWN;
   fp32 g = *(G + c);
   fp32 mean = *(M + c);