diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 70555fd0e..2bbec482a 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -10,13 +10,15 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   triton::jit jit(context);
-  triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
+  triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
   int32_t B = 4, NF = 32;
   int32_t D = 1, H = 56, W = 56;
   int32_t NC = 32, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 0, pad_w = 0;
-  triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
+  int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
+  triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
   // convolution configuration
   std::vector<float> hc(configuration.c_size());
   std::vector<float> rc(configuration.c_size());
@@ -47,7 +49,7 @@ int main() {
   std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
   configuration.init(stream, (triton::driver::cu_module*)kernel->module());
   stream->synchronize();
-  configuration.set_arg(kernel, da, db, dc);
+  configuration.set_arg(kernel, da, db, dc, nullptr);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
   stream->synchronize();
   double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
diff --git a/examples/python/pytorch/bench.py b/examples/python/pytorch/bench.py
index 2c8c304b5..98a782099 100644
--- a/examples/python/pytorch/bench.py
+++ b/examples/python/pytorch/bench.py
@@ -1,117 +1,142 @@
-import argparse
-import triton
+'''Train CIFAR10 with PyTorch.'''
+from __future__ import print_function
+
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
 
-torch.manual_seed(0)
-
-class Net(nn.Module):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 20, (5,5))
-        self.conv2 = nn.Conv2d(20, 50, (5,5))
-        self.fc1 = nn.Linear(4*4*50, 500)
-        self.fc2 = nn.Linear(500, 10)
+import torchvision
+import torchvision.transforms as transforms
 
-    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = F.relu(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = x.view(-1, 4*4*50)
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+import os
+import argparse
 
-def train(args, model, device, train_loader, optimizer, epoch):
-    model.train()
-    for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
+from resnet import *
+from utils import progress_bar
+
+
+parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
+parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
+parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
+args = parser.parse_args()
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+best_acc = 0  # best test accuracy
+start_epoch = 0  # start from epoch 0 or last checkpoint epoch
+
+# Data
+print('==> Preparing data..')
+transform_train = transforms.Compose([
+    transforms.RandomCrop(32, padding=4),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+])
+
+transform_test = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+])
+
+trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
+
+testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
+
+classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+# Model
+print('==> Building model..')
+# net = VGG('VGG19')
+net = ResNet18()
+# net = PreActResNet18()
+# net = GoogLeNet()
+# net = DenseNet121()
+# net = ResNeXt29_2x64d()
+# net = MobileNet()
+# net = MobileNetV2()
+# net = DPN92()
+# net = ShuffleNetG2()
+# net = SENet18()
+#net = ShuffleNetV2(1)
+net = net.to(device)
+if device == 'cuda':
+    net = torch.nn.DataParallel(net)
+    cudnn.benchmark = True
+
+if args.resume:
+    # Load checkpoint.
+    print('==> Resuming from checkpoint..')
+    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
+    checkpoint = torch.load('./checkpoint/ckpt.t7')
+    net.load_state_dict(checkpoint['net'])
+    best_acc = checkpoint['acc']
+    start_epoch = checkpoint['epoch']
+
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+
+# Training
+def train(epoch):
+    print('\nEpoch: %d' % epoch)
+    net.train()
+    train_loss = 0
+    correct = 0
+    total = 0
+    for batch_idx, (inputs, targets) in enumerate(trainloader):
+        inputs, targets = inputs.to(device), targets.to(device)
         optimizer.zero_grad()
-        output = model(data)
-        loss = F.nll_loss(output, target)
+        outputs = net(inputs)
+        loss = criterion(outputs, targets)
         loss.backward()
         optimizer.step()
-        if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
 
-def test(args, model, device, test_loader):
-    model.eval()
+        train_loss += loss.item()
+        _, predicted = outputs.max(1)
+        total += targets.size(0)
+        correct += predicted.eq(targets).sum().item()
+
+        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
+
+def test(epoch):
+    global best_acc
+    net.eval()
     test_loss = 0
     correct = 0
+    total = 0
     with torch.no_grad():
-        for data, target in test_loader:
-            data, target = data.to(device), target.to(device)
-            output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
-            correct += pred.eq(target.view_as(pred)).sum().item()
+        for batch_idx, (inputs, targets) in enumerate(testloader):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = net(inputs)
+            loss = criterion(outputs, targets)
 
-    test_loss /= len(test_loader.dataset)
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
 
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
 
-def main():
-    # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=10, metavar='N',
-                        help='number of epochs to train (default: 10)')
-    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                        help='learning rate (default: 0.01)')
-    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
-                        help='SGD momentum (default: 0.5)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
-    args = parser.parse_args()
-    use_cuda = not args.no_cuda and torch.cuda.is_available()
-
-    torch.manual_seed(args.seed)
-
-    device = torch.device("cuda" if use_cuda else "cpu")
-
-    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
-    train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=True, download=True,
-                       transform=transforms.Compose([
-                           transforms.ToTensor(),
-                           transforms.Normalize((0.1307,), (0.3081,))
-                       ])),
-        batch_size=args.batch_size, shuffle=True, **kwargs)
-    test_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=False, transform=transforms.Compose([
-                           transforms.ToTensor(),
-                           transforms.Normalize((0.1307,), (0.3081,))
-                       ])),
-        batch_size=args.test_batch_size, shuffle=True, **kwargs)
+    # Save checkpoint.
+    acc = 100.*correct/total
+    if acc > best_acc:
+        print('Saving..')
+        state = {
+            'net': net.state_dict(),
+            'acc': acc,
+            'epoch': epoch,
+        }
+        if not os.path.isdir('checkpoint'):
+            os.mkdir('checkpoint')
+        torch.save(state, './checkpoint/ckpt.t7')
+        best_acc = acc
 
-    model = Net().to(device)
-    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
-
-    for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
-        test(args, model, device, test_loader)
-
-    if (args.save_model):
-        torch.save(model.state_dict(),"mnist_cnn.pt")
-
-if __name__ == '__main__':
-    main()
+for epoch in range(start_epoch, start_epoch+200):
+    train(epoch)
+    test(epoch)
diff --git a/examples/python/pytorch/conv.cpp b/examples/python/pytorch/conv.cpp
index a3002e28a..4ed9785ed 100644
--- a/examples/python/pytorch/conv.cpp
+++ b/examples/python/pytorch/conv.cpp
@@ -14,7 +14,7 @@ typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t,
                    int32_t, int32_t, int32_t, int32_t, int32_t,
                    int32_t, int32_t, int32_t, int32_t,
-                   triton::dnn::conv::type> conv_key_t;
+                   triton::dnn::conv::type, bool> conv_key_t;
 
 static std::map<CUstream, std::unique_ptr<triton::driver::stream>>  m_stream;
 static std::map<conv_key_t, std::unique_ptr<triton::jit>>           m_jit;
@@ -26,7 +26,7 @@ torch::Tensor conv_common(
     int32_t stride_d, int32_t stride_h, int32_t stride_w,
     int32_t pad_d, int32_t pad_h, int32_t pad_w,
     triton::dnn::conv::type ty,
-    torch::Tensor torcha, torch::Tensor torchb
+    torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias
     ) {
   // Wrap CUDA handles
   c10::DeviceIndex device = torcha.storage().device().index();
@@ -40,13 +40,16 @@ torch::Tensor conv_common(
   // Get context
   triton::driver::context* ctx = stream->context();
   // Get configuration
-  conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
+  bool has_bias = torchbias.storage().size() > 0;
+  conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
   triton::dnn::conv* configuration;
   if(m_config.find(key) == m_config.end())
     configuration = m_config.emplace(key, new triton::dnn::conv(
                                       B, C, D, H, W, T, R, S, NF,
                                       stride_d, stride_h, stride_w,
-                                      pad_d, pad_h, pad_w, ty)).first->second.get();
+                                      pad_d, pad_h, pad_w,
+                                      1, 1, 1,
+                                      ty, has_bias)).first->second.get();
   else
     configuration = m_config.at(key).get();
   // Get JIT
@@ -55,12 +58,16 @@
     jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
     std::string src = configuration->src();
     jit->add_module("conv", src.c_str(), configuration->default_params());
+    triton::driver::kernel* kernel = jit->get_function("conv");
+    configuration->init(stream, (triton::driver::cu_module*)kernel->module());
   }
   else
     jit = m_jit.at(key).get();
   // Get memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
   triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
+  triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
+  triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
   // Allocate output
   std::vector<int32_t> c_shapes = configuration->c_shapes();
   torch::Tensor torchc;
@@ -76,10 +83,9 @@
   unsigned TM = info.global_range_size[0];
   unsigned TN = info.global_range_size[1];
   // launch info
-  configuration->init(stream, (triton::driver::cu_module*)kernel->module());
   unsigned nthreads = info.num_threads;
   std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
-  configuration->set_arg(kernel, &a, &b, &c);
+  configuration->set_arg(kernel, &a, &b, &c, bias);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
   return torchc;
 }
@@ -87,6 +93,8 @@
 torch::Tensor conv_fprop(
     const torch::Tensor data,
     const torch::Tensor weight,
+    const torch::Tensor bias,
+    int64_t stride_h, int64_t stride_w,
     int64_t pad_h, int64_t pad_w) {
   // Check
   CHECK_INPUT(data);
@@ -104,16 +112,19 @@ torch::Tensor conv_fprop(
   const int32_t S  = weight.size(2);
   const int32_t NF = weight.size(3);
   // Configuration
-  const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  const int32_t stride_d = 1;
   const int32_t pad_d = 0;
   // Check
   AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
-  return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
+  return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight, bias);
 }
 
 torch::Tensor conv_bprop(
     const torch::Tensor derror,
     const torch::Tensor weight,
+    const torch::Tensor bias,
+    int64_t H, int64_t W,
+    int64_t stride_h, int64_t stride_w,
     int64_t pad_h, int64_t pad_w){
   // Check
   CHECK_INPUT(derror);
@@ -131,22 +142,20 @@ torch::Tensor conv_bprop(
   const int32_t S  = weight.size(2);
   const int32_t Kw = weight.size(3);
   // Compute M, P, Q
-  const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  const int32_t stride_d = 1;
   int32_t pad_d = 0;
-  const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
-  const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
-  const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
-
-
+  int32_t D = 1;
   // Check
   AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
-  return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
+  return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight, bias);
 }
 
 torch::Tensor conv_wgrad(
     const torch::Tensor data,
     const torch::Tensor derror,
+    const torch::Tensor bias,
+    int64_t R, int64_t S,
+    int64_t stride_h, int64_t stride_w,
     int64_t pad_h, int64_t pad_w
     ){
   // Check
@@ -166,16 +175,12 @@ torch::Tensor conv_wgrad(
   const int32_t Q  = derror.size(3);
   // Compute M, P, Q
   const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  const int32_t stride_d = 1;
   const int32_t pad_d = 0;
-  const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
-  const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
-  const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
-
-
+  const int32_t T = 1;
   // Check
   AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
-  return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
+  return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror, bias);
 }
 
 static auto registry =
diff --git a/examples/python/pytorch/test.py b/examples/python/pytorch/test.py
index 5086f8a5d..787d6634b 100644
--- a/examples/python/pytorch/test.py
+++ b/examples/python/pytorch/test.py
@@ -2,6 +2,7 @@ import torch
 import triton
 
 x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
+bias = torch.autograd.Variable(torch.randn(64).cuda(), requires_grad=True)
 w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
 cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
 y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
@@ -10,13 +11,16 @@
 def run(x, w, conv):
     y = conv(x, w)
     loss = (y - y_target).norm(2)
     loss.backward()
-    return loss, y.clone(), x.grad.clone(), w.grad.clone()
+    return loss, y.clone(), x.grad.clone(), w.grad.clone(), bias.grad.clone()
 
-ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))
+ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (0,0)))
 x.grad.zero_()
 w.grad.zero_()
-culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))
+bias.grad.zero_()
+culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=0))
+print(ttdx[0,0,:,:], cudx[0,0,:,:])
 print((tty - cuy).norm(2))
 print((ttdx - cudx).norm(2))
 print((ttdw.permute(3,0,1,2) - cudw).norm(2))
+print((ttbias - cubias).norm(2))
diff --git a/examples/python/pytorch/triton.py b/examples/python/pytorch/triton.py
index 57e17d515..ec7c86695 100644
--- a/examples/python/pytorch/triton.py
+++ b/examples/python/pytorch/triton.py
@@ -1,4 +1,5 @@
 import torch
+from torch.nn.modules.utils import _single, _pair, _triple
 import math
 
 torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
@@ -6,37 +7,56 @@ torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/
 class ConvFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, input, weight, padding):
-        ctx.save_for_backward(input, weight)
+    def forward(ctx, input, weight, bias, stride, padding):
+        if bias is None:
+            bias = torch.empty(0)
+        ctx.save_for_backward(input, weight, bias)
+        ctx.stride = stride
         ctx.padding = padding
-        output = torch.ops.triton.conv_fprop(input, weight, padding, padding)
+        output = torch.ops.triton.conv_fprop(input, weight, bias, stride[0], stride[1], padding[0], padding[1])
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
+        input, weight, bias = ctx.saved_tensors
+        stride = ctx.stride
         padding = ctx.padding
-        grad_input = grad_weight = None
+        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
-            grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding)
+            grad_input = torch.ops.triton.conv_bprop(grad_output, weight, bias, input.shape[2], input.shape[3], stride[0], stride[1], padding[0], padding[1])
         if ctx.needs_input_grad[1]:
-            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding)
-        return grad_input, grad_weight, None
+            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, bias, weight.shape[1], weight.shape[2], stride[0], stride[1], padding[0], padding[1])
+        if ctx.needs_input_grad[2]:
+            grad_bias = torch.sum(grad_output, (0, 2, 3))
+        return grad_input, grad_weight, grad_bias, None, None
 
-class Conv2d(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
-        super(Conv2d, self).__init__()
+class _ConvNd(torch.nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride,
+                 padding, dilation, transposed, output_padding, groups, bias):
+        super(_ConvNd, self).__init__()
+        # not everything is supported by Triton
+        assert all(x==1 for x in stride)
+        assert all(x==1 for x in dilation)
+        assert transposed == False
+        assert all(x==0 for x in output_padding)
+        assert groups == 1
+        # initialize
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = kernel_size
         self.padding = padding
+        self.stride = stride
         self.weight = torch.nn.Parameter(torch.Tensor(
             in_channels, kernel_size[0], kernel_size[1], out_channels))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()
 
     def forward(self, input):
-        return ConvFunction.apply(input, self.weight, self.padding)
+        return ConvFunction.apply(input, self.weight, self.bias, self.stride, self.padding)
 
     def reset_parameters(self):
         n = self.in_channels
@@ -44,3 +64,19 @@ class Conv2d(torch.nn.Module):
             n *= k
         stdv = 1. / math.sqrt(n)
         self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+
+
+class Conv2d(_ConvNd):
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+        padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(Conv2d, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            False, _pair(0), groups, bias)
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index d01007aa6..313065fc6 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -22,7 +22,8 @@ public:
        int T, int R, int S, int NF,
        int stride_d, int stride_h, int stride_w,
        int pad_d, int pad_h, int pad_w,
-       type ty = FPROP);
+       int upsample_d, int upsample_h, int upsample_w,
+       type ty = FPROP, bool bias = false);
 
   // accessors
   size_t a_size();
@@ -36,7 +37,8 @@ public:
   void init(driver::stream *stream, driver::cu_module *module);
   std::array<size_t, 3> get_grid(size_t TM, size_t TN);
   void set_arg(driver::kernel *kernel,
-               driver::buffer *a, driver::buffer *b, driver::buffer *c);
+               driver::buffer *a, driver::buffer *b, driver::buffer *c,
+               driver::buffer *bias);
 
   // utilities
   size_t get_nflops();
@@ -81,6 +83,7 @@ public:
   void conv(read_only restrict fp32 *a,
             read_only restrict fp32 *b,
             fp32 *c,
+            fp32 *bias,
             int32 M, int32 N, int32 K,
             int32 AH, int32 AW,
             int32 BH, int32 BW,
@@ -88,7 +91,9 @@
             int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
             int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
             int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
-            int32 pad_h, int32 pad_w)";
+            int32 pad_h, int32 pad_w,
+            int32 stride_h, int32 stride_w,
+            int32 upsample_h, int32 upsample_w)";
   if(!is_a_deltas_cst)
     res += ", int32* delta";
   if(is_wgrad && !is_b_deltas_cst_)
@@ -103,9 +108,11 @@
     fp32 C[TM, TN] = 0;
     int32 ldlut = )" + std::to_string(Fs_) + R"(;
     int32 rabh[TM] = rxa / CW;
-    int32 raw[TM] = rxa % CW - pad_w;
+    int32 raw[TM] = rxa % CW;
     int32 rab[TM] = rabh / CH;
-    int32 rah[TM] = rabh % CH - pad_h;
+    int32 rah[TM] = rabh % CH;
+    raw = raw*stride_w - pad_w;
+    rah = rah*stride_h - pad_h;
     int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
     int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
     int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
@@ -173,7 +180,14 @@
     int32 rcn[TM] = rxc / (CH*CW);
     int32 rcpq[TM] = rxc % (CH*CW);
     int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
-    fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
+    fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];)";
+  if(bias_ && ty_==FPROP){
+    res += R"(
+    fp32* pbias[TN] = bias + rc1;
+    fp32 bias[TN] = *pbias;
+    C = C + bias[newaxis, :];)";
+  }
+  res += R"(
     int1 checkc0[TM] = rxc < M;
     int1 checkc1[TN] = rc1 < N;
     int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@@ -208,18 +222,18 @@ private:
   int32_t CD_;
   int32_t CH_;
   int32_t CW_;
-  // upsampling
-  int32_t upsample_d_;
-  int32_t upsample_h_;
-  int32_t upsample_w_;
-  // padding
-  int32_t pad_d_;
-  int32_t pad_h_;
-  int32_t pad_w_;
   // striding
   int32_t stride_d_;
   int32_t stride_h_;
   int32_t stride_w_;
+  // padding
+  int32_t pad_d_;
+  int32_t pad_h_;
+  int32_t pad_w_;
+  // upsampling
+  int32_t upsample_d_;
+  int32_t upsample_h_;
+  int32_t upsample_w_;
   // equivalent matmul
   int32_t M_;
   int32_t N_;
@@ -249,6 +263,7 @@ private:
   bool is_mask_cst_;
   // type
   type ty_;
+  bool bias_;
   bool b_trans_;
   bool b_lut_;
 };
diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp
index 621f4f14a..fc97acbba 100644
--- a/lib/dnn/conv.cpp
+++ b/lib/dnn/conv.cpp
@@ -8,16 +8,18 @@ conv::conv(int B, int NC,
            int T, int R, int S, int NF,
            int stride_d, int stride_h, int stride_w,
            int pad_d, int pad_h, int pad_w,
-           type ty)
+           int upsample_d, int upsample_h, int upsample_w,
+           type ty, bool bias)
   : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
     stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
-    upsample_d_(1), upsample_h_(1), upsample_w_(1),
     pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
-    ty_(ty)
+    upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
+    ty_(ty), bias_(bias)
 {
   CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
   CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
   CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
+
   // shapes
   shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
   shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
@@ -232,65 +234,70 @@ void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
 }
 
 void conv::set_arg(driver::kernel *kernel,
-                   driver::buffer *a, driver::buffer *b, driver::buffer *c)
+                   driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias)
 {
   kernel->setArg(0, a);
   kernel->setArg(1, b);
   kernel->setArg(2, c);
-  kernel->setArg(3, M_);
-  kernel->setArg(4, N_);
-  kernel->setArg(5, K_);
-  kernel->setArg(6, AH_);
-  kernel->setArg(7, AW_);
-  kernel->setArg(8, BH_);
-  kernel->setArg(9, BW_);
-  kernel->setArg(10, CH_);
-  kernel->setArg(11, CW_);
+  kernel->setArg(3, bias);
+  kernel->setArg(4, M_);
+  kernel->setArg(5, N_);
+  kernel->setArg(6, K_);
+  kernel->setArg(7, AH_);
+  kernel->setArg(8, AW_);
+  kernel->setArg(9, BH_);
+  kernel->setArg(10, BW_);
+  kernel->setArg(11, CH_);
+  kernel->setArg(12, CW_);
   // A arguments
   if(ty_ == WGRAD){
-    kernel->setArg(12, ld_a_[1]);
-    kernel->setArg(13, ld_a_[0]);
+    kernel->setArg(13, ld_a_[1]);
+    kernel->setArg(14, ld_a_[0]);
  }
   else{
-    kernel->setArg(12, ld_a_[0]);
-    kernel->setArg(13, ld_a_[1]);
+    kernel->setArg(13, ld_a_[0]);
+    kernel->setArg(14, ld_a_[1]);
   }
-  kernel->setArg(14, ld_a_[2]);
-  kernel->setArg(15, ld_a_[3]);
-  kernel->setArg(16, ld_a_[4]);
+  kernel->setArg(15, ld_a_[2]);
+  kernel->setArg(16, ld_a_[3]);
+  kernel->setArg(17, ld_a_[4]);
   // B arguments
   if(ty_ == WGRAD){
-    kernel->setArg(17, ld_b_[0]);
-    kernel->setArg(18, ld_b_[2]);
-    kernel->setArg(19, ld_b_[3]);
-    kernel->setArg(20, ld_b_[4]);
-    kernel->setArg(21, ld_b_[1]);
-  }
-  else{
-    kernel->setArg(17, ld_b_[0]);
-    kernel->setArg(18, ld_b_[1]);
+    kernel->setArg(18, ld_b_[0]);
     kernel->setArg(19, ld_b_[2]);
     kernel->setArg(20, ld_b_[3]);
     kernel->setArg(21, ld_b_[4]);
+    kernel->setArg(22, ld_b_[1]);
+  }
+  else{
+    kernel->setArg(18, ld_b_[0]);
+    kernel->setArg(19, ld_b_[1]);
+    kernel->setArg(20, ld_b_[2]);
+    kernel->setArg(21, ld_b_[3]);
+    kernel->setArg(22, ld_b_[4]);
   }
   // C arguments
   if(ty_ == WGRAD){
-    kernel->setArg(22, ld_c_[0]);
-    kernel->setArg(23, ld_c_[4]);
+    kernel->setArg(23, ld_c_[0]);
+    kernel->setArg(24, ld_c_[4]);
+    kernel->setArg(25, ld_c_[1]);
+    kernel->setArg(26, ld_c_[2]);
+    kernel->setArg(27, ld_c_[3]);
+  }
+  else{
+    kernel->setArg(23, ld_c_[0]);
     kernel->setArg(24, ld_c_[1]);
     kernel->setArg(25, ld_c_[2]);
     kernel->setArg(26, ld_c_[3]);
+    kernel->setArg(27, ld_c_[4]);
   }
-  else{
-    kernel->setArg(22, ld_c_[0]);
-    kernel->setArg(23, ld_c_[1]);
-    kernel->setArg(24, ld_c_[2]);
-    kernel->setArg(25, ld_c_[3]);
-    kernel->setArg(26, ld_c_[4]);
-  }
-  kernel->setArg(27, pad_h_);
-  kernel->setArg(28, pad_w_);
-  size_t idx = 29;
+  kernel->setArg(28, pad_h_);
+  kernel->setArg(29, pad_w_);
+  kernel->setArg(30, stride_h_);
+  kernel->setArg(31, stride_w_);
+  kernel->setArg(32, upsample_h_);
+  kernel->setArg(33, upsample_w_);
+  size_t idx = 34;
   if(!is_a_deltas_cst)
     kernel->setArg(idx++, d_a_deltas_);
   if(!is_b_deltas_cst_)
diff --git a/lib/driver/kernel.cpp b/lib/driver/kernel.cpp
index 81c797047..a16e3e6f9 100755
--- a/lib/driver/kernel.cpp
+++ b/lib/driver/kernel.cpp
@@ -81,7 +81,10 @@ void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
 }
 
 void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
-  kernel::setArg(index, (void*)buffer->hst()->data);
+  if(buffer)
+    kernel::setArg(index, (void*)buffer->hst()->data);
+  else
+    kernel::setArg(index, (std::ptrdiff_t)0);
 }
 
 const std::vector<void*> &host_kernel::params(){
@@ -106,7 +109,10 @@ void ocl_kernel::setArg(unsigned int index, std::size_t size, void* ptr) {
 }
 
 void ocl_kernel::setArg(unsigned int index, driver::buffer* buffer) {
-  check(dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()));
+  if(buffer)
+    check(dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()));
+  else
+    kernel::setArg(index, (std::ptrdiff_t)0);
 }
 
 
@@ -130,8 +136,12 @@ void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
   cu_params_[index] = cu_params_store_[index].get();
 }
 
-void cu_kernel::setArg(unsigned int index, driver::buffer* data)
-{ return kernel::setArg(index, *data->cu());}
+void cu_kernel::setArg(unsigned int index, driver::buffer* data){
+  if(data)
+    kernel::setArg(index, *data->cu());
+  else
+    kernel::setArg(index, (std::ptrdiff_t)0);
+}
 
 void* const* cu_kernel::cu_params() const
 { return cu_params_.data(); }
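Usage sketch (appended for illustration, not part of the patch): a minimal way to exercise the bias-enabled Conv2d module that examples/python/pytorch/triton.py now exposes, assuming libtorch_triton.so has been built at the path hard-coded in triton.py; the tensor sizes below are illustrative assumptions.

    import torch
    import triton  # examples/python/pytorch/triton.py, which loads libtorch_triton.so

    # NCHW input; the Triton layer stores its weights as (C, R, S, K)
    x = torch.randn(16, 64, 8, 8).cuda()
    conv = triton.Conv2d(64, 64, kernel_size=3, bias=True).cuda()  # stride/dilation must stay at 1
    y = conv(x)          # forward dispatches to torch.ops.triton.conv_fprop with the bias tensor
    y.sum().backward()   # backward uses conv_bprop / conv_wgrad; the bias gradient is a torch.sum reduction
    print(y.shape, conv.bias.grad.shape)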