[dnn/conv] Added bias and forward stride
This commit is contained in:
@@ -10,13 +10,15 @@ int main() {
|
|||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
|
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
|
||||||
// initialization
|
// initialization
|
||||||
int32_t B = 4, NF = 32;
|
int32_t B = 4, NF = 32;
|
||||||
int32_t D = 1, H = 56, W = 56;
|
int32_t D = 1, H = 56, W = 56;
|
||||||
int32_t NC = 32, T = 1, R = 3, S = 3;
|
int32_t NC = 32, T = 1, R = 3, S = 3;
|
||||||
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
|
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
|
||||||
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
|
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||||
|
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||||
|
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
|
||||||
// convolution configuration
|
// convolution configuration
|
||||||
std::vector<float> hc(configuration.c_size());
|
std::vector<float> hc(configuration.c_size());
|
||||||
std::vector<float> rc(configuration.c_size());
|
std::vector<float> rc(configuration.c_size());
|
||||||
@@ -47,7 +49,7 @@ int main() {
|
|||||||
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
|
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
|
||||||
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
|
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
configuration.set_arg(kernel, da, db, dc);
|
configuration.set_arg(kernel, da, db, dc, nullptr);
|
||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||||
|
@@ -1,117 +1,142 @@
|
|||||||
import argparse
|
'''Train CIFAR10 with PyTorch.'''
|
||||||
import triton
|
from __future__ import print_function
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
import torch.nn.functional as F
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
|
||||||
torch.manual_seed(0)
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
class Net(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.conv1 = nn.Conv2d(1, 20, (5,5))
|
|
||||||
self.conv2 = nn.Conv2d(20, 50, (5,5))
|
|
||||||
self.fc1 = nn.Linear(4*4*50, 500)
|
|
||||||
self.fc2 = nn.Linear(500, 10)
|
|
||||||
|
|
||||||
def forward(self, x):
|
import os
|
||||||
x = F.relu(self.conv1(x))
|
import argparse
|
||||||
x = F.max_pool2d(x, 2, 2)
|
|
||||||
x = F.relu(self.conv2(x))
|
|
||||||
x = F.max_pool2d(x, 2, 2)
|
|
||||||
x = x.view(-1, 4*4*50)
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
x = self.fc2(x)
|
|
||||||
return F.log_softmax(x, dim=1)
|
|
||||||
|
|
||||||
def train(args, model, device, train_loader, optimizer, epoch):
|
from resnet import *
|
||||||
model.train()
|
from utils import progress_bar
|
||||||
for batch_idx, (data, target) in enumerate(train_loader):
|
|
||||||
data, target = data.to(device), target.to(device)
|
|
||||||
|
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
|
||||||
|
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
||||||
|
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
best_acc = 0 # best test accuracy
|
||||||
|
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
||||||
|
|
||||||
|
# Data
|
||||||
|
print('==> Preparing data..')
|
||||||
|
transform_train = transforms.Compose([
|
||||||
|
transforms.RandomCrop(32, padding=4),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||||
|
])
|
||||||
|
|
||||||
|
transform_test = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
||||||
|
])
|
||||||
|
|
||||||
|
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
|
||||||
|
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
|
||||||
|
|
||||||
|
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
|
||||||
|
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
|
||||||
|
|
||||||
|
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
||||||
|
|
||||||
|
# Model
|
||||||
|
print('==> Building model..')
|
||||||
|
# net = VGG('VGG19')
|
||||||
|
net = ResNet18()
|
||||||
|
# net = PreActResNet18()
|
||||||
|
# net = GoogLeNet()
|
||||||
|
# net = DenseNet121()
|
||||||
|
# net = ResNeXt29_2x64d()
|
||||||
|
# net = MobileNet()
|
||||||
|
# net = MobileNetV2()
|
||||||
|
# net = DPN92()
|
||||||
|
# net = ShuffleNetG2()
|
||||||
|
# net = SENet18()
|
||||||
|
#net = ShuffleNetV2(1)
|
||||||
|
net = net.to(device)
|
||||||
|
if device == 'cuda':
|
||||||
|
net = torch.nn.DataParallel(net)
|
||||||
|
cudnn.benchmark = True
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
# Load checkpoint.
|
||||||
|
print('==> Resuming from checkpoint..')
|
||||||
|
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
|
||||||
|
checkpoint = torch.load('./checkpoint/ckpt.t7')
|
||||||
|
net.load_state_dict(checkpoint['net'])
|
||||||
|
best_acc = checkpoint['acc']
|
||||||
|
start_epoch = checkpoint['epoch']
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
def train(epoch):
|
||||||
|
print('\nEpoch: %d' % epoch)
|
||||||
|
net.train()
|
||||||
|
train_loss = 0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
for batch_idx, (inputs, targets) in enumerate(trainloader):
|
||||||
|
inputs, targets = inputs.to(device), targets.to(device)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
output = model(data)
|
outputs = net(inputs)
|
||||||
loss = F.nll_loss(output, target)
|
loss = criterion(outputs, targets)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
if batch_idx % args.log_interval == 0:
|
|
||||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
|
||||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
|
||||||
100. * batch_idx / len(train_loader), loss.item()))
|
|
||||||
|
|
||||||
def test(args, model, device, test_loader):
|
train_loss += loss.item()
|
||||||
model.eval()
|
_, predicted = outputs.max(1)
|
||||||
|
total += targets.size(0)
|
||||||
|
correct += predicted.eq(targets).sum().item()
|
||||||
|
|
||||||
|
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
|
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
||||||
|
|
||||||
|
def test(epoch):
|
||||||
|
global best_acc
|
||||||
|
net.eval()
|
||||||
test_loss = 0
|
test_loss = 0
|
||||||
correct = 0
|
correct = 0
|
||||||
|
total = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data, target in test_loader:
|
for batch_idx, (inputs, targets) in enumerate(testloader):
|
||||||
data, target = data.to(device), target.to(device)
|
inputs, targets = inputs.to(device), targets.to(device)
|
||||||
output = model(data)
|
outputs = net(inputs)
|
||||||
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
loss = criterion(outputs, targets)
|
||||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
|
||||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
|
||||||
|
|
||||||
test_loss /= len(test_loader.dataset)
|
test_loss += loss.item()
|
||||||
|
_, predicted = outputs.max(1)
|
||||||
|
total += targets.size(0)
|
||||||
|
correct += predicted.eq(targets).sum().item()
|
||||||
|
|
||||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||||
test_loss, correct, len(test_loader.dataset),
|
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
||||||
100. * correct / len(test_loader.dataset)))
|
|
||||||
|
|
||||||
def main():
|
# Save checkpoint.
|
||||||
# Training settings
|
acc = 100.*correct/total
|
||||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
if acc > best_acc:
|
||||||
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
|
print('Saving..')
|
||||||
help='input batch size for training (default: 64)')
|
state = {
|
||||||
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
|
'net': net.state_dict(),
|
||||||
help='input batch size for testing (default: 1000)')
|
'acc': acc,
|
||||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
'epoch': epoch,
|
||||||
help='number of epochs to train (default: 10)')
|
}
|
||||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
if not os.path.isdir('checkpoint'):
|
||||||
help='learning rate (default: 0.01)')
|
os.mkdir('checkpoint')
|
||||||
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
|
torch.save(state, './checkpoint/ckpt.t7')
|
||||||
help='SGD momentum (default: 0.5)')
|
best_acc = acc
|
||||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
|
||||||
help='disables CUDA training')
|
|
||||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
|
||||||
help='random seed (default: 1)')
|
|
||||||
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
|
|
||||||
help='how many batches to wait before logging training status')
|
|
||||||
|
|
||||||
parser.add_argument('--save-model', action='store_true', default=False,
|
|
||||||
help='For Saving the current Model')
|
|
||||||
args = parser.parse_args()
|
|
||||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
|
||||||
|
|
||||||
torch.manual_seed(args.seed)
|
|
||||||
|
|
||||||
device = torch.device("cuda" if use_cuda else "cpu")
|
|
||||||
|
|
||||||
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
|
||||||
datasets.MNIST('../data', train=True, download=True,
|
|
||||||
transform=transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.1307,), (0.3081,))
|
|
||||||
])),
|
|
||||||
batch_size=args.batch_size, shuffle=True, **kwargs)
|
|
||||||
test_loader = torch.utils.data.DataLoader(
|
|
||||||
datasets.MNIST('../data', train=False, transform=transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.1307,), (0.3081,))
|
|
||||||
])),
|
|
||||||
batch_size=args.test_batch_size, shuffle=True, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
model = Net().to(device)
|
for epoch in range(start_epoch, start_epoch+200):
|
||||||
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
|
train(epoch)
|
||||||
|
test(epoch)
|
||||||
for epoch in range(1, args.epochs + 1):
|
|
||||||
train(args, model, device, train_loader, optimizer, epoch)
|
|
||||||
test(args, model, device, test_loader)
|
|
||||||
|
|
||||||
if (args.save_model):
|
|
||||||
torch.save(model.state_dict(),"mnist_cnn.pt")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
@@ -14,7 +14,7 @@ typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
|
|||||||
int32_t, int32_t, int32_t, int32_t,
|
int32_t, int32_t, int32_t, int32_t,
|
||||||
int32_t, int32_t, int32_t,
|
int32_t, int32_t, int32_t,
|
||||||
int32_t, int32_t, int32_t,
|
int32_t, int32_t, int32_t,
|
||||||
triton::dnn::conv::type> conv_key_t;
|
triton::dnn::conv::type, bool> conv_key_t;
|
||||||
|
|
||||||
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
|
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
|
||||||
static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
|
static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
|
||||||
@@ -26,7 +26,7 @@ torch::Tensor conv_common(
|
|||||||
int32_t stride_d, int32_t stride_h, int32_t stride_w,
|
int32_t stride_d, int32_t stride_h, int32_t stride_w,
|
||||||
int32_t pad_d, int32_t pad_h, int32_t pad_w,
|
int32_t pad_d, int32_t pad_h, int32_t pad_w,
|
||||||
triton::dnn::conv::type ty,
|
triton::dnn::conv::type ty,
|
||||||
torch::Tensor torcha, torch::Tensor torchb
|
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias
|
||||||
) {
|
) {
|
||||||
// Wrap CUDA handles
|
// Wrap CUDA handles
|
||||||
c10::DeviceIndex device = torcha.storage().device().index();
|
c10::DeviceIndex device = torcha.storage().device().index();
|
||||||
@@ -40,13 +40,16 @@ torch::Tensor conv_common(
|
|||||||
// Get context
|
// Get context
|
||||||
triton::driver::context* ctx = stream->context();
|
triton::driver::context* ctx = stream->context();
|
||||||
// Get configuration
|
// Get configuration
|
||||||
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
|
bool has_bias = torchbias.storage().size() > 0;
|
||||||
|
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
|
||||||
triton::dnn::conv* configuration;
|
triton::dnn::conv* configuration;
|
||||||
if(m_config.find(key) == m_config.end())
|
if(m_config.find(key) == m_config.end())
|
||||||
configuration = m_config.emplace(key, new triton::dnn::conv(
|
configuration = m_config.emplace(key, new triton::dnn::conv(
|
||||||
B, C, D, H, W, T, R, S, NF,
|
B, C, D, H, W, T, R, S, NF,
|
||||||
stride_d, stride_h, stride_w,
|
stride_d, stride_h, stride_w,
|
||||||
pad_d, pad_h, pad_w, ty)).first->second.get();
|
pad_d, pad_h, pad_w,
|
||||||
|
1, 1, 1,
|
||||||
|
ty, has_bias)).first->second.get();
|
||||||
else
|
else
|
||||||
configuration = m_config.at(key).get();
|
configuration = m_config.at(key).get();
|
||||||
// Get JIT
|
// Get JIT
|
||||||
@@ -55,12 +58,16 @@ torch::Tensor conv_common(
|
|||||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||||
std::string src = configuration->src();
|
std::string src = configuration->src();
|
||||||
jit->add_module("conv", src.c_str(), configuration->default_params());
|
jit->add_module("conv", src.c_str(), configuration->default_params());
|
||||||
|
triton::driver::kernel* kernel = jit->get_function("conv");
|
||||||
|
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
jit = m_jit.at(key).get();
|
jit = m_jit.at(key).get();
|
||||||
// Get memory
|
// Get memory
|
||||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||||
|
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
|
||||||
|
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
|
||||||
// Allocate output
|
// Allocate output
|
||||||
std::vector<int32_t> c_shapes = configuration->c_shapes();
|
std::vector<int32_t> c_shapes = configuration->c_shapes();
|
||||||
torch::Tensor torchc;
|
torch::Tensor torchc;
|
||||||
@@ -76,10 +83,9 @@ torch::Tensor conv_common(
|
|||||||
unsigned TM = info.global_range_size[0];
|
unsigned TM = info.global_range_size[0];
|
||||||
unsigned TN = info.global_range_size[1];
|
unsigned TN = info.global_range_size[1];
|
||||||
// launch info
|
// launch info
|
||||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
|
||||||
unsigned nthreads = info.num_threads;
|
unsigned nthreads = info.num_threads;
|
||||||
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
|
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
|
||||||
configuration->set_arg(kernel, &a, &b, &c);
|
configuration->set_arg(kernel, &a, &b, &c, bias);
|
||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
return torchc;
|
return torchc;
|
||||||
}
|
}
|
||||||
@@ -87,6 +93,8 @@ torch::Tensor conv_common(
|
|||||||
torch::Tensor conv_fprop(
|
torch::Tensor conv_fprop(
|
||||||
const torch::Tensor data,
|
const torch::Tensor data,
|
||||||
const torch::Tensor weight,
|
const torch::Tensor weight,
|
||||||
|
const torch::Tensor bias,
|
||||||
|
int64_t stride_h, int64_t stride_w,
|
||||||
int64_t pad_h, int64_t pad_w) {
|
int64_t pad_h, int64_t pad_w) {
|
||||||
// Check
|
// Check
|
||||||
CHECK_INPUT(data);
|
CHECK_INPUT(data);
|
||||||
@@ -104,16 +112,19 @@ torch::Tensor conv_fprop(
|
|||||||
const int32_t S = weight.size(2);
|
const int32_t S = weight.size(2);
|
||||||
const int32_t NF = weight.size(3);
|
const int32_t NF = weight.size(3);
|
||||||
// Configuration
|
// Configuration
|
||||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
const int32_t stride_d = 1;
|
||||||
const int32_t pad_d = 0;
|
const int32_t pad_d = 0;
|
||||||
// Check
|
// Check
|
||||||
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||||
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
|
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight, bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor conv_bprop(
|
torch::Tensor conv_bprop(
|
||||||
const torch::Tensor derror,
|
const torch::Tensor derror,
|
||||||
const torch::Tensor weight,
|
const torch::Tensor weight,
|
||||||
|
const torch::Tensor bias,
|
||||||
|
int64_t H, int64_t W,
|
||||||
|
int64_t stride_h, int64_t stride_w,
|
||||||
int64_t pad_h, int64_t pad_w){
|
int64_t pad_h, int64_t pad_w){
|
||||||
// Check
|
// Check
|
||||||
CHECK_INPUT(derror);
|
CHECK_INPUT(derror);
|
||||||
@@ -131,22 +142,20 @@ torch::Tensor conv_bprop(
|
|||||||
const int32_t S = weight.size(2);
|
const int32_t S = weight.size(2);
|
||||||
const int32_t Kw = weight.size(3);
|
const int32_t Kw = weight.size(3);
|
||||||
// Compute M, P, Q
|
// Compute M, P, Q
|
||||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
const int32_t stride_d = 1;
|
||||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
|
||||||
int32_t pad_d = 0;
|
int32_t pad_d = 0;
|
||||||
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
|
int32_t D = 1;
|
||||||
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
|
|
||||||
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
|
|
||||||
|
|
||||||
|
|
||||||
// Check
|
// Check
|
||||||
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
|
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
|
||||||
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
|
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight, bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor conv_wgrad(
|
torch::Tensor conv_wgrad(
|
||||||
const torch::Tensor data,
|
const torch::Tensor data,
|
||||||
const torch::Tensor derror,
|
const torch::Tensor derror,
|
||||||
|
const torch::Tensor bias,
|
||||||
|
int64_t R, int64_t S,
|
||||||
|
int64_t stride_h, int64_t stride_w,
|
||||||
int64_t pad_h, int64_t pad_w
|
int64_t pad_h, int64_t pad_w
|
||||||
){
|
){
|
||||||
// Check
|
// Check
|
||||||
@@ -166,16 +175,12 @@ torch::Tensor conv_wgrad(
|
|||||||
const int32_t Q = derror.size(3);
|
const int32_t Q = derror.size(3);
|
||||||
// Compute M, P, Q
|
// Compute M, P, Q
|
||||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
const int32_t stride_d = 1;
|
||||||
const int32_t pad_d = 0;
|
const int32_t pad_d = 0;
|
||||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
|
const int32_t T = 1;
|
||||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
|
|
||||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
|
|
||||||
|
|
||||||
|
|
||||||
// Check
|
// Check
|
||||||
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
|
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
|
||||||
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
|
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror, bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
static auto registry =
|
static auto registry =
|
||||||
|
@@ -2,6 +2,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
|
|
||||||
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||||
|
bias = torch.autograd.Variable(torch.randn(64).cuda(), requires_grad=True)
|
||||||
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
|
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
|
||||||
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
|
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
|
||||||
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
|
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
|
||||||
@@ -10,13 +11,16 @@ def run(x, w, conv):
|
|||||||
y = conv(x, w)
|
y = conv(x, w)
|
||||||
loss = (y - y_target).norm(2)
|
loss = (y - y_target).norm(2)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
return loss, y.clone(), x.grad.clone(), w.grad.clone()
|
return loss, y.clone(), x.grad.clone(), w.grad.clone(), bias.grad.clone()
|
||||||
|
|
||||||
ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))
|
ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (0,0)))
|
||||||
x.grad.zero_()
|
x.grad.zero_()
|
||||||
w.grad.zero_()
|
w.grad.zero_()
|
||||||
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))
|
bias.grad.zero_()
|
||||||
|
culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=0))
|
||||||
|
|
||||||
|
print(ttdx[0,0,:,:], cudx[0,0,:,:])
|
||||||
print((tty - cuy).norm(2))
|
print((tty - cuy).norm(2))
|
||||||
print((ttdx - cudx).norm(2))
|
print((ttdx - cudx).norm(2))
|
||||||
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
||||||
|
print((ttbias - cubias).norm(2))
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from torch.nn.modules.utils import _single, _pair, _triple
|
||||||
import math
|
import math
|
||||||
|
|
||||||
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
||||||
@@ -6,37 +7,56 @@ torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/
|
|||||||
class ConvFunction(torch.autograd.Function):
|
class ConvFunction(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input, weight, padding):
|
def forward(ctx, input, weight, bias, stride, padding):
|
||||||
ctx.save_for_backward(input, weight)
|
if bias is None:
|
||||||
|
bias = torch.empty(0)
|
||||||
|
ctx.save_for_backward(input, weight, bias)
|
||||||
|
ctx.stride = stride
|
||||||
ctx.padding = padding
|
ctx.padding = padding
|
||||||
output = torch.ops.triton.conv_fprop(input, weight, padding, padding)
|
output = torch.ops.triton.conv_fprop(input, weight, bias, stride[0], stride[1], padding[0], padding[1])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input, weight = ctx.saved_tensors
|
input, weight, bias = ctx.saved_tensors
|
||||||
|
stride = ctx.stride
|
||||||
padding = ctx.padding
|
padding = ctx.padding
|
||||||
grad_input = grad_weight = None
|
grad_input = grad_weight = grad_bias = None
|
||||||
if ctx.needs_input_grad[0]:
|
if ctx.needs_input_grad[0]:
|
||||||
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding)
|
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, bias, input.shape[2], input.shape[3], stride[0], stride[1], padding[0], padding[1])
|
||||||
if ctx.needs_input_grad[1]:
|
if ctx.needs_input_grad[1]:
|
||||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding)
|
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, bias, weight.shape[1], weight.shape[2], stride[0], stride[1], padding[0], padding[1])
|
||||||
return grad_input, grad_weight, None
|
if ctx.needs_input_grad[2]:
|
||||||
|
grad_bias = torch.sum(grad_output, (0, 2, 3))
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None
|
||||||
|
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module):
|
class _ConvNd(torch.nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
|
|
||||||
super(Conv2d, self).__init__()
|
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
||||||
|
padding, dilation, transposed, output_padding, groups, bias):
|
||||||
|
super(_ConvNd, self).__init__()
|
||||||
|
# not everything is supported by Triton
|
||||||
|
assert all(x==1 for x in stride)
|
||||||
|
assert all(x==1 for x in dilation)
|
||||||
|
assert transposed == False
|
||||||
|
assert all(x==0 for x in output_padding)
|
||||||
|
assert groups == 1
|
||||||
|
# initialize
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.weight = torch.nn.Parameter(torch.Tensor(
|
self.weight = torch.nn.Parameter(torch.Tensor(
|
||||||
in_channels, kernel_size[0], kernel_size[1], out_channels))
|
in_channels, kernel_size[0], kernel_size[1], out_channels))
|
||||||
|
if bias:
|
||||||
|
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
self.reset_parameters()
|
self.reset_parameters()
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
return ConvFunction.apply(input, self.weight, self.padding)
|
return ConvFunction.apply(input, self.weight, self.bias, self.padding)
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
n = self.in_channels
|
n = self.in_channels
|
||||||
@@ -44,3 +64,19 @@ class Conv2d(torch.nn.Module):
|
|||||||
n *= k
|
n *= k
|
||||||
stdv = 1. / math.sqrt(n)
|
stdv = 1. / math.sqrt(n)
|
||||||
self.weight.data.uniform_(-stdv, stdv)
|
self.weight.data.uniform_(-stdv, stdv)
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias.data.uniform_(-stdv, stdv)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2d(_ConvNd):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||||
|
padding=0, dilation=1, groups=1, bias=True):
|
||||||
|
kernel_size = _pair(kernel_size)
|
||||||
|
stride = _pair(stride)
|
||||||
|
padding = _pair(padding)
|
||||||
|
dilation = _pair(dilation)
|
||||||
|
super(Conv2d, self).__init__(
|
||||||
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
|
False, _pair(0), groups, bias)
|
||||||
|
@@ -22,7 +22,8 @@ public:
|
|||||||
int T, int R, int S, int NF,
|
int T, int R, int S, int NF,
|
||||||
int stride_d, int stride_h, int stride_w,
|
int stride_d, int stride_h, int stride_w,
|
||||||
int pad_d, int pad_h, int pad_w,
|
int pad_d, int pad_h, int pad_w,
|
||||||
type ty = FPROP);
|
int upsample_d, int upsample_h, int upsample_w,
|
||||||
|
type ty = FPROP, bool bias = false);
|
||||||
|
|
||||||
// accessors
|
// accessors
|
||||||
size_t a_size();
|
size_t a_size();
|
||||||
@@ -36,7 +37,8 @@ public:
|
|||||||
void init(driver::stream *stream, driver::cu_module *module);
|
void init(driver::stream *stream, driver::cu_module *module);
|
||||||
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
|
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
|
||||||
void set_arg(driver::kernel *kernel,
|
void set_arg(driver::kernel *kernel,
|
||||||
driver::buffer *a, driver::buffer *b, driver::buffer *c);
|
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||||
|
driver::buffer *bias);
|
||||||
|
|
||||||
// utilities
|
// utilities
|
||||||
size_t get_nflops();
|
size_t get_nflops();
|
||||||
@@ -81,6 +83,7 @@ public:
|
|||||||
void conv(read_only restrict fp32 *a,
|
void conv(read_only restrict fp32 *a,
|
||||||
read_only restrict fp32 *b,
|
read_only restrict fp32 *b,
|
||||||
fp32 *c,
|
fp32 *c,
|
||||||
|
fp32 *bias,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
int32 AH, int32 AW,
|
int32 AH, int32 AW,
|
||||||
int32 BH, int32 BW,
|
int32 BH, int32 BW,
|
||||||
@@ -88,7 +91,9 @@ public:
|
|||||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||||
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
|
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
|
||||||
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
|
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
|
||||||
int32 pad_h, int32 pad_w)";
|
int32 pad_h, int32 pad_w,
|
||||||
|
int32 stride_h, int32 stride_w,
|
||||||
|
int32 upsample_h, int32 upsample_w)";
|
||||||
if(!is_a_deltas_cst)
|
if(!is_a_deltas_cst)
|
||||||
res += ", int32* delta";
|
res += ", int32* delta";
|
||||||
if(is_wgrad && !is_b_deltas_cst_)
|
if(is_wgrad && !is_b_deltas_cst_)
|
||||||
@@ -103,9 +108,11 @@ public:
|
|||||||
fp32 C[TM, TN] = 0;
|
fp32 C[TM, TN] = 0;
|
||||||
int32 ldlut = )" + std::to_string(Fs_) + R"(;
|
int32 ldlut = )" + std::to_string(Fs_) + R"(;
|
||||||
int32 rabh[TM] = rxa / CW;
|
int32 rabh[TM] = rxa / CW;
|
||||||
int32 raw[TM] = rxa % CW - pad_w;
|
int32 raw[TM] = rxa % CW;
|
||||||
int32 rab[TM] = rabh / CH;
|
int32 rab[TM] = rabh / CH;
|
||||||
int32 rah[TM] = rabh % CH - pad_h;
|
int32 rah[TM] = rabh % CH;
|
||||||
|
raw = raw*stride_w - pad_w;
|
||||||
|
rah = rah*stride_h - pad_h;
|
||||||
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||||
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
|
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
|
||||||
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
|
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
|
||||||
@@ -173,7 +180,14 @@ public:
|
|||||||
int32 rcn[TM] = rxc / (CH*CW);
|
int32 rcn[TM] = rxc / (CH*CW);
|
||||||
int32 rcpq[TM] = rxc % (CH*CW);
|
int32 rcpq[TM] = rxc % (CH*CW);
|
||||||
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
|
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
|
||||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];)";
|
||||||
|
if(bias_ && ty_==FPROP){
|
||||||
|
res += R"(
|
||||||
|
fp32* pbias[TN] = bias + rc1;
|
||||||
|
fp32 bias[TN] = *pbias;
|
||||||
|
C = C + bias[newaxis, :];)";
|
||||||
|
}
|
||||||
|
res += R"(
|
||||||
int1 checkc0[TM] = rxc < M;
|
int1 checkc0[TM] = rxc < M;
|
||||||
int1 checkc1[TN] = rc1 < N;
|
int1 checkc1[TN] = rc1 < N;
|
||||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
@@ -208,18 +222,18 @@ private:
|
|||||||
int32_t CD_;
|
int32_t CD_;
|
||||||
int32_t CH_;
|
int32_t CH_;
|
||||||
int32_t CW_;
|
int32_t CW_;
|
||||||
// upsampling
|
|
||||||
int32_t upsample_d_;
|
|
||||||
int32_t upsample_h_;
|
|
||||||
int32_t upsample_w_;
|
|
||||||
// padding
|
|
||||||
int32_t pad_d_;
|
|
||||||
int32_t pad_h_;
|
|
||||||
int32_t pad_w_;
|
|
||||||
// striding
|
// striding
|
||||||
int32_t stride_d_;
|
int32_t stride_d_;
|
||||||
int32_t stride_h_;
|
int32_t stride_h_;
|
||||||
int32_t stride_w_;
|
int32_t stride_w_;
|
||||||
|
// padding
|
||||||
|
int32_t pad_d_;
|
||||||
|
int32_t pad_h_;
|
||||||
|
int32_t pad_w_;
|
||||||
|
// upsampling
|
||||||
|
int32_t upsample_d_;
|
||||||
|
int32_t upsample_h_;
|
||||||
|
int32_t upsample_w_;
|
||||||
// equivalent matmul
|
// equivalent matmul
|
||||||
int32_t M_;
|
int32_t M_;
|
||||||
int32_t N_;
|
int32_t N_;
|
||||||
@@ -249,6 +263,7 @@ private:
|
|||||||
bool is_mask_cst_;
|
bool is_mask_cst_;
|
||||||
// type
|
// type
|
||||||
type ty_;
|
type ty_;
|
||||||
|
bool bias_;
|
||||||
bool b_trans_;
|
bool b_trans_;
|
||||||
bool b_lut_;
|
bool b_lut_;
|
||||||
};
|
};
|
||||||
|
@@ -8,16 +8,18 @@ conv::conv(int B, int NC,
|
|||||||
int T, int R, int S, int NF,
|
int T, int R, int S, int NF,
|
||||||
int stride_d, int stride_h, int stride_w,
|
int stride_d, int stride_h, int stride_w,
|
||||||
int pad_d, int pad_h, int pad_w,
|
int pad_d, int pad_h, int pad_w,
|
||||||
type ty)
|
int upsample_d, int upsample_h, int upsample_w,
|
||||||
|
type ty, bool bias)
|
||||||
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
|
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
|
||||||
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
|
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
|
||||||
upsample_d_(1), upsample_h_(1), upsample_w_(1),
|
|
||||||
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
|
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
|
||||||
ty_(ty)
|
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
|
||||||
|
ty_(ty), bias_(bias)
|
||||||
{
|
{
|
||||||
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
|
||||||
CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
CH_ = (AH_*upsample_h_ - BH_ + 1 + 2*pad_h_ + stride_h_ - 1)/stride_h_;
|
||||||
CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
CW_ = (AW_*upsample_w_ - BW_ + 1 + 2*pad_w_ + stride_w_ - 1)/stride_w_;
|
||||||
|
|
||||||
// shapes
|
// shapes
|
||||||
shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
|
shapes_a_ = {NB_, NC_, AD_, AH_, AW_};
|
||||||
shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
|
shapes_b_ = {NC_, BD_, BH_, BW_, NF_};
|
||||||
@@ -232,65 +234,70 @@ void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void conv::set_arg(driver::kernel *kernel,
|
void conv::set_arg(driver::kernel *kernel,
|
||||||
driver::buffer *a, driver::buffer *b, driver::buffer *c)
|
driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias)
|
||||||
{
|
{
|
||||||
kernel->setArg(0, a);
|
kernel->setArg(0, a);
|
||||||
kernel->setArg(1, b);
|
kernel->setArg(1, b);
|
||||||
kernel->setArg(2, c);
|
kernel->setArg(2, c);
|
||||||
kernel->setArg(3, M_);
|
kernel->setArg(3, bias);
|
||||||
kernel->setArg(4, N_);
|
kernel->setArg(4, M_);
|
||||||
kernel->setArg(5, K_);
|
kernel->setArg(5, N_);
|
||||||
kernel->setArg(6, AH_);
|
kernel->setArg(6, K_);
|
||||||
kernel->setArg(7, AW_);
|
kernel->setArg(7, AH_);
|
||||||
kernel->setArg(8, BH_);
|
kernel->setArg(8, AW_);
|
||||||
kernel->setArg(9, BW_);
|
kernel->setArg(9, BH_);
|
||||||
kernel->setArg(10, CH_);
|
kernel->setArg(10, BW_);
|
||||||
kernel->setArg(11, CW_);
|
kernel->setArg(11, CH_);
|
||||||
|
kernel->setArg(12, CW_);
|
||||||
// A arguments
|
// A arguments
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD){
|
||||||
kernel->setArg(12, ld_a_[1]);
|
kernel->setArg(13, ld_a_[1]);
|
||||||
kernel->setArg(13, ld_a_[0]);
|
kernel->setArg(14, ld_a_[0]);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
kernel->setArg(12, ld_a_[0]);
|
kernel->setArg(13, ld_a_[0]);
|
||||||
kernel->setArg(13, ld_a_[1]);
|
kernel->setArg(14, ld_a_[1]);
|
||||||
}
|
}
|
||||||
kernel->setArg(14, ld_a_[2]);
|
kernel->setArg(15, ld_a_[2]);
|
||||||
kernel->setArg(15, ld_a_[3]);
|
kernel->setArg(16, ld_a_[3]);
|
||||||
kernel->setArg(16, ld_a_[4]);
|
kernel->setArg(17, ld_a_[4]);
|
||||||
// B arguments
|
// B arguments
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD){
|
||||||
kernel->setArg(17, ld_b_[0]);
|
kernel->setArg(18, ld_b_[0]);
|
||||||
kernel->setArg(18, ld_b_[2]);
|
|
||||||
kernel->setArg(19, ld_b_[3]);
|
|
||||||
kernel->setArg(20, ld_b_[4]);
|
|
||||||
kernel->setArg(21, ld_b_[1]);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
kernel->setArg(17, ld_b_[0]);
|
|
||||||
kernel->setArg(18, ld_b_[1]);
|
|
||||||
kernel->setArg(19, ld_b_[2]);
|
kernel->setArg(19, ld_b_[2]);
|
||||||
kernel->setArg(20, ld_b_[3]);
|
kernel->setArg(20, ld_b_[3]);
|
||||||
kernel->setArg(21, ld_b_[4]);
|
kernel->setArg(21, ld_b_[4]);
|
||||||
|
kernel->setArg(22, ld_b_[1]);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
kernel->setArg(18, ld_b_[0]);
|
||||||
|
kernel->setArg(19, ld_b_[1]);
|
||||||
|
kernel->setArg(20, ld_b_[2]);
|
||||||
|
kernel->setArg(21, ld_b_[3]);
|
||||||
|
kernel->setArg(22, ld_b_[4]);
|
||||||
}
|
}
|
||||||
// C arguments
|
// C arguments
|
||||||
if(ty_ == WGRAD){
|
if(ty_ == WGRAD){
|
||||||
kernel->setArg(22, ld_c_[0]);
|
kernel->setArg(23, ld_c_[0]);
|
||||||
kernel->setArg(23, ld_c_[4]);
|
kernel->setArg(24, ld_c_[4]);
|
||||||
|
kernel->setArg(25, ld_c_[1]);
|
||||||
|
kernel->setArg(26, ld_c_[2]);
|
||||||
|
kernel->setArg(27, ld_c_[3]);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
kernel->setArg(23, ld_c_[0]);
|
||||||
kernel->setArg(24, ld_c_[1]);
|
kernel->setArg(24, ld_c_[1]);
|
||||||
kernel->setArg(25, ld_c_[2]);
|
kernel->setArg(25, ld_c_[2]);
|
||||||
kernel->setArg(26, ld_c_[3]);
|
kernel->setArg(26, ld_c_[3]);
|
||||||
|
kernel->setArg(27, ld_c_[4]);
|
||||||
}
|
}
|
||||||
else{
|
kernel->setArg(28, pad_h_);
|
||||||
kernel->setArg(22, ld_c_[0]);
|
kernel->setArg(29, pad_w_);
|
||||||
kernel->setArg(23, ld_c_[1]);
|
kernel->setArg(30, stride_h_);
|
||||||
kernel->setArg(24, ld_c_[2]);
|
kernel->setArg(31, stride_w_);
|
||||||
kernel->setArg(25, ld_c_[3]);
|
kernel->setArg(32, upsample_h_);
|
||||||
kernel->setArg(26, ld_c_[4]);
|
kernel->setArg(33, upsample_w_);
|
||||||
}
|
size_t idx = 34;
|
||||||
kernel->setArg(27, pad_h_);
|
|
||||||
kernel->setArg(28, pad_w_);
|
|
||||||
size_t idx = 29;
|
|
||||||
if(!is_a_deltas_cst)
|
if(!is_a_deltas_cst)
|
||||||
kernel->setArg(idx++, d_a_deltas_);
|
kernel->setArg(idx++, d_a_deltas_);
|
||||||
if(!is_b_deltas_cst_)
|
if(!is_b_deltas_cst_)
|
||||||
|
@@ -81,7 +81,10 @@ void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
|||||||
}
|
}
|
||||||
|
|
||||||
void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
|
void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
|
||||||
kernel::setArg(index, (void*)buffer->hst()->data);
|
if(buffer)
|
||||||
|
kernel::setArg(index, (void*)buffer->hst()->data);
|
||||||
|
else
|
||||||
|
kernel::setArg(index, (std::ptrdiff_t)0);
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<void *> &host_kernel::params(){
|
const std::vector<void *> &host_kernel::params(){
|
||||||
@@ -106,7 +109,10 @@ void ocl_kernel::setArg(unsigned int index, std::size_t size, void* ptr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ocl_kernel::setArg(unsigned int index, driver::buffer* buffer) {
|
void ocl_kernel::setArg(unsigned int index, driver::buffer* buffer) {
|
||||||
check(dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()));
|
if(buffer)
|
||||||
|
check(dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()));
|
||||||
|
else
|
||||||
|
kernel::setArg(index, (std::ptrdiff_t)0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -130,8 +136,12 @@ void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
|
|||||||
cu_params_[index] = cu_params_store_[index].get();
|
cu_params_[index] = cu_params_store_[index].get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void cu_kernel::setArg(unsigned int index, driver::buffer* data)
|
void cu_kernel::setArg(unsigned int index, driver::buffer* data){
|
||||||
{ return kernel::setArg(index, *data->cu());}
|
if(data)
|
||||||
|
kernel::setArg(index, *data->cu());
|
||||||
|
else
|
||||||
|
kernel::setArg(index, (std::ptrdiff_t)0);
|
||||||
|
}
|
||||||
|
|
||||||
void* const* cu_kernel::cu_params() const
|
void* const* cu_kernel::cu_params() const
|
||||||
{ return cu_params_.data(); }
|
{ return cu_params_.data(); }
|
||||||
|
Reference in New Issue
Block a user