diff --git a/examples/python/pytorch/batchnorm.cpp b/examples/python/pytorch/batchnorm.cpp index fff9039d7..521137a9e 100644 --- a/examples/python/pytorch/batchnorm.cpp +++ b/examples/python/pytorch/batchnorm.cpp @@ -4,11 +4,18 @@ #include "triton/driver/stream.h" #include "triton/dnn/batchnorm.h" +#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + std::vector batchnorm_ymv(const torch::Tensor fw_x, const torch::Tensor fw_g, const torch::Tensor fw_b, - float eps) { + double eps) { + CHECK_INPUT(fw_x); + CHECK_INPUT(fw_g); + CHECK_INPUT(fw_b); // Wrap CUDA handles c10::DeviceIndex device = fw_x.storage().device().index(); CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream(); @@ -30,8 +37,9 @@ std::vector triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false); triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false); // create template - triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32", eps); + triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32"); batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b}); + stream.synchronize(); return {fw_y, fw_m, fw_v}; } @@ -41,7 +49,12 @@ std::vector const torch::Tensor fw_g, const torch::Tensor fw_m, const torch::Tensor fw_v, - float eps) { + double eps) { + CHECK_INPUT(fw_dy); + CHECK_INPUT(fw_x); + CHECK_INPUT(fw_g); + CHECK_INPUT(fw_m); + CHECK_INPUT(fw_v); // Wrap CUDA handles c10::DeviceIndex device = fw_x.storage().device().index(); CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream(); @@ -68,4 +81,10 @@ std::vector // create config triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps); batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v}); + stream.synchronize(); + return {fw_dx, fw_dg, fw_db}; } + +static auto registry = + torch::jit::RegisterOperators("triton::batchnorm_ymv", &batchnorm_ymv) + .op("triton::batchnorm_dxdgdb", &batchnorm_dxdgdb); diff --git a/examples/python/pytorch/run.py b/examples/python/pytorch/run.py new file mode 100644 index 000000000..86f489ccc --- /dev/null +++ b/examples/python/pytorch/run.py @@ -0,0 +1,120 @@ +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import triton + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = triton.ShiftConv2d(1, 32, 3, 2) + self.bn1 = triton.BatchNorm2d(32) + self.conv2 = triton.ShiftConv2d(32, 64, 3, 2) + self.bn2 = triton.BatchNorm2d(64) + self.fc1 = nn.Linear(64*7*7, 500) + self.fc2 = nn.Linear(500, 10) + + def forward(self, x): + x = x.permute(1, 2, 3, 0).contiguous() + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = x.permute(3, 0, 1, 2).contiguous() + x = x.view(-1, 64*7*7) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + +def test(args, model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--momentum', type=float, default=0.5, metavar='M', + help='SGD momentum (default: 0.5)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.batch_size, shuffle=True, **kwargs) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.test_batch_size, shuffle=True, **kwargs) + + + model = Net().to(device) + optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) + + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(args, model, device, test_loader) + + if (args.save_model): + torch.save(model.state_dict(),"mnist_cnn.pt") + +main() diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp index e9271d37c..bde690b27 100644 --- a/examples/python/pytorch/shift.cpp +++ b/examples/python/pytorch/shift.cpp @@ -35,8 +35,11 @@ torch::Tensor shift_common( triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false); triton::driver::buffer* bias = has_bias ? &cubias : nullptr; // Allocate output - std::vector c_shapes = shift.c_shapes(); - torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda(); + std::vector _c_shapes = shift.c_shapes(); + std::vector c_shapes; + for(auto x: _c_shapes) + c_shapes.push_back(x); + torch::Tensor torchc = torch::empty(c_shapes).cuda(); triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false); // Enqueue shift.enqueue(&stream, {&a, &b, &c}); @@ -47,9 +50,9 @@ torch::Tensor shift_y( const torch::Tensor x, const torch::Tensor w, const torch::Tensor bias, - int32_t R, int32_t S, - int32_t stride_h, int32_t stride_w, - int32_t* shift_h, int32_t* shift_w) { + int64_t R, int64_t S, + int64_t stride_h, int64_t stride_w, + const torch::Tensor shift_h, const torch::Tensor shift_w) { // shapes for a int64_t Ca = x.size(0); int64_t H = x.size(1); @@ -61,16 +64,18 @@ torch::Tensor shift_y( AT_CHECK(Ca == Cb, "operands must have the same number of channels"); int64_t C = Ca; // run - shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::FPROP, x, w, bias); + return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, + (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(), + triton::dnn::shift::FPROP, x, w, bias); } torch::Tensor shift_dx( const torch::Tensor dy, const torch::Tensor w, const torch::Tensor bias, - int32_t R, int32_t S, - int32_t stride_h, int32_t stride_w, - int32_t* shift_h, int32_t* shift_w) { + int64_t R, int64_t S, + int64_t stride_h, int64_t stride_w, + const torch::Tensor shift_h, const torch::Tensor shift_w) { // shapes for a int64_t Ca = dy.size(0); int64_t H = dy.size(1); @@ -87,16 +92,18 @@ torch::Tensor shift_dx( int64_t C = Ca; std::swap(C, F); // run - shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::BPROP, dy, w, bias); + return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, + (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(), + triton::dnn::shift::BPROP, dy, w, bias); } torch::Tensor shift_dw( const torch::Tensor dy, const torch::Tensor x, const torch::Tensor bias, - int32_t R, int32_t S, - int32_t stride_h, int32_t stride_w, - int32_t* shift_h, int32_t* shift_w) { + int64_t R, int64_t S, + int64_t stride_h, int64_t stride_w, + const torch::Tensor shift_h, const torch::Tensor shift_w) { // shapes for a int64_t F = dy.size(0); int64_t Ha = dy.size(1); @@ -115,7 +122,9 @@ torch::Tensor shift_dw( int64_t W = Wb; int64_t B = Bb; // run - shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias); + return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, + (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(), + triton::dnn::shift::WGRAD, dy, x, bias); } static auto registry = diff --git a/examples/python/pytorch/triton.py b/examples/python/pytorch/triton.py index e8a0a7ff2..3b7b38b87 100644 --- a/examples/python/pytorch/triton.py +++ b/examples/python/pytorch/triton.py @@ -1,6 +1,7 @@ import torch -from torch.nn.modules.utils import _single, _pair, _triple import math +from torch.nn.modules.utils import _single, _pair, _triple +from torch.distributions import categorical torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so") @@ -93,13 +94,18 @@ class Conv2d(_ConvNd): class ShiftConvFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias, stride, width): + def forward(ctx, input, weight, bias, stride, width, shift_h, shift_w): if bias is None: bias = torch.empty(0) ctx.save_for_backward(input, weight, bias) ctx.stride = stride ctx.width = width - output = torch.ops.triton.shift_conv_y(input, weight, bias, width[0], width[1], stride[0], stride[1]) + ctx.shift_h = shift_h + ctx.shift_w = shift_w + output = torch.ops.triton.shift_conv_y(input, weight, bias, + width[0], width[1], + stride[0], stride[1], + shift_h, shift_w) return output @staticmethod @@ -107,16 +113,70 @@ class ShiftConvFunction(torch.autograd.Function): input, weight, bias = ctx.saved_tensors stride = ctx.stride width = ctx.width + shift_h = ctx.shift_h + shift_w = ctx.shift_w dx = dw = dbias = None if ctx.needs_input_grad[0]: - dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1]) + dx = torch.ops.triton.shift_conv_dx(dy, weight, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w) if ctx.needs_input_grad[1]: - dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1]) + dw = torch.ops.triton.shift_conv_dw(dy, input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w) if ctx.needs_input_grad[2]: dbias = torch.sum(dy, (1, 2, 3)) - return dx, dw, dbias, None, None + return dx, dw, dbias, None, None, None, None +class _ShiftConvNd(torch.nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, bias): + super(_ShiftConvNd, self).__init__() + # initialize + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.weight = torch.nn.Parameter(torch.Tensor(in_channels, out_channels)) + if bias: + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.shift_h = self.make_shift(kernel_size[0]) + self.shift_w = self.make_shift(kernel_size[1]) + self.reset_parameters() + + def forward(self, input): + return ShiftConvFunction.apply(input, self.weight, self.bias, self.stride, + self.kernel_size, self.shift_h, self.shift_w) + + def make_shift(self, kernel_size): + if kernel_size == 3: + p = torch.Tensor([0.3, 0.4, 0.3]) + elif kernel_size == 5: + p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1]) + elif kernel_size == 7: + p = torch.Tensor([0.075, 0.1, 0.175, 0.3, 0.175, 0.1, 0.075]) + elif kernel_size == 9: + p = torch.Tensor([0.05, 0.075, 0.1, 0.175, 0.2, 0.175, 0.1, 0.075, 0.05]) + else: + raise RuntimeError('Unsupported kernel size') + return categorical.Categorical(p).sample((self.in_channels,)) - (kernel_size // 2) + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + +class ShiftConv2d(_ShiftConvNd): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + super(ShiftConv2d, self).__init__( + in_channels, out_channels, kernel_size, stride, bias) + ################################# ######### BatchNorm ########### ################################# @@ -134,5 +194,27 @@ class BatchNormFunction(torch.autograd.Function): def backward(ctx, dy): eps = ctx.eps x, gamma, beta, mean, var = ctx.saved_tensors - dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy, x, gamma, mean, var, eps) - return dx, dg, db, None, None + dx, dg, db = torch.ops.triton.batchnorm_dxdgdb(dy.contiguous(), x, gamma, mean, var, eps) + return dx, dg, db, None + + +class _BatchNorm(torch.nn.Module): + + def __init__(self, num_features, eps=1e-5): + super(_BatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.weight = torch.nn.Parameter(torch.Tensor(num_features)) + self.bias = torch.nn.Parameter(torch.Tensor(num_features)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.uniform_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, input): + return BatchNormFunction.apply(input, self.weight, self.bias, self.eps) + +class BatchNorm2d(_BatchNorm): + + pass diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 1d5fba379..c15b0e8af 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var): def run_batchnorm(): - C, H, W, B = 1, 4, 4, 4 + C, H, W, B = 32, 14, 14, 64 np.random.seed(0) # Placeholders x = tf.placeholder(tf.float32, shape=[C, H, W, B]) @@ -117,7 +117,8 @@ def run_batchnorm(): sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb}) - #print(result[0], result[1], result[2]) + print(result[1]) + print(np.mean(hx, (1, 2, 3))) grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B), extra_feed_dict = {x: hx, g: hg, b: hb}) dx_t, dx_n = grads[0] diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp index 3085a5b44..54bb9c16e 100644 --- a/lib/dnn/batchnorm.cpp +++ b/lib/dnn/batchnorm.cpp @@ -59,7 +59,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker { driver::buffer *y = args[0], *m = args[1], *v = args[2]; driver::buffer *x = args[3], *g = args[4], *b = args[5]; - std::array grid = {(size_t)C_, 1, 1}; + std::array grid = {1, (size_t)C_, 1}; kernel->setArg(0, y); kernel->setArg(1, m); kernel->setArg(2, v); @@ -86,7 +86,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V, int32 rx[TM] = 0 ... TM; fp32 *px[TM]; fp32 x[TM]; - int32 c = get_range_id(0); + int32 c = get_range_id(1); fp32 g = *(G + c); fp32 b = *(B + c); @@ -112,7 +112,6 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V, fp32 v = __sum(var) * rcpDHWN; fp32 *pv = V + c; *pv = v; - fp32 rstdg = 1 / sqrt(v + eps) * g; px = X + rx + c*DHWN; @@ -186,7 +185,7 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB, restrict read_only fp32 *V, int32 DHWN, fp32 rcpDHWN, fp32 epsilon) { int32 rx[TM] = 0 ... TM; - int32 c = get_range_id(0); + int32 c = get_range_id(1); int32 offset = c*DHWN; fp32 g = *(G + c); fp32 mean = *(M + c);