[examples/pytorch] Fixed issues in backward pass of conv
This commit is contained in:
@@ -10,12 +10,12 @@ int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 56, W = 56;
|
||||
int32_t NC = 32, T = 1, R = 3, S = 3;
|
||||
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
|
||||
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
|
||||
// convolution configuration
|
||||
std::vector<float> hc(configuration.c_size());
|
||||
|
117
examples/python/pytorch/bench.py
Normal file
117
examples/python/pytorch/bench.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import argparse
|
||||
import triton
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 20, (5,5))
|
||||
self.conv2 = nn.Conv2d(20, 50, (5,5))
|
||||
self.fc1 = nn.Linear(4*4*50, 500)
|
||||
self.fc2 = nn.Linear(500, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.max_pool2d(x, 2, 2)
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.max_pool2d(x, 2, 2)
|
||||
x = x.view(-1, 4*4*50)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
def train(args, model, device, train_loader, optimizer, epoch):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if batch_idx % args.log_interval == 0:
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
|
||||
def test(args, model, device, test_loader):
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
||||
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
|
||||
test_loss /= len(test_loader.dataset)
|
||||
|
||||
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
||||
test_loss, correct, len(test_loader.dataset),
|
||||
100. * correct / len(test_loader.dataset)))
|
||||
|
||||
def main():
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
|
||||
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
|
||||
help='input batch size for training (default: 64)')
|
||||
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
|
||||
help='input batch size for testing (default: 1000)')
|
||||
parser.add_argument('--epochs', type=int, default=10, metavar='N',
|
||||
help='number of epochs to train (default: 10)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
|
||||
help='SGD momentum (default: 0.5)')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--seed', type=int, default=1, metavar='S',
|
||||
help='random seed (default: 1)')
|
||||
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
|
||||
help='how many batches to wait before logging training status')
|
||||
|
||||
parser.add_argument('--save-model', action='store_true', default=False,
|
||||
help='For Saving the current Model')
|
||||
args = parser.parse_args()
|
||||
use_cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
||||
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=True, download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=args.batch_size, shuffle=True, **kwargs)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST('../data', train=False, transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=args.test_batch_size, shuffle=True, **kwargs)
|
||||
|
||||
|
||||
model = Net().to(device)
|
||||
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(args, model, device, train_loader, optimizer, epoch)
|
||||
test(args, model, device, test_loader)
|
||||
|
||||
if (args.save_model):
|
||||
torch.save(model.state_dict(),"mnist_cnn.pt")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -86,7 +86,8 @@ torch::Tensor conv_common(
|
||||
|
||||
torch::Tensor conv_fprop(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor weight) {
|
||||
const torch::Tensor weight,
|
||||
int64_t pad_h, int64_t pad_w) {
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -104,7 +105,7 @@ torch::Tensor conv_fprop(
|
||||
const int32_t NF = weight.size(3);
|
||||
// Configuration
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t pad_d = 0;
|
||||
// Check
|
||||
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
|
||||
@@ -112,7 +113,8 @@ torch::Tensor conv_fprop(
|
||||
|
||||
torch::Tensor conv_bprop(
|
||||
const torch::Tensor derror,
|
||||
const torch::Tensor weight){
|
||||
const torch::Tensor weight,
|
||||
int64_t pad_h, int64_t pad_w){
|
||||
// Check
|
||||
CHECK_INPUT(derror);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -131,10 +133,12 @@ torch::Tensor conv_bprop(
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d;
|
||||
const int32_t H = P*stride_d + R - 1 - 2*pad_h + stride_h - 1 / upsample_h;
|
||||
const int32_t W = Q*stride_d + S - 1 - 2*pad_w + stride_w - 1 / upsample_w;
|
||||
int32_t pad_d = 0;
|
||||
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
|
||||
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
|
||||
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
|
||||
|
||||
|
||||
// Check
|
||||
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
|
||||
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
|
||||
@@ -142,17 +146,18 @@ torch::Tensor conv_bprop(
|
||||
|
||||
torch::Tensor conv_wgrad(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor derror
|
||||
const torch::Tensor derror,
|
||||
int64_t pad_h, int64_t pad_w
|
||||
){
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
CHECK_INPUT(derror);
|
||||
// Unpack data shapes
|
||||
const int32_t Ba = derror.size(0);
|
||||
const int32_t C = derror.size(1);
|
||||
const int32_t Ba = data.size(0);
|
||||
const int32_t C = data.size(1);
|
||||
const int32_t D = 1;
|
||||
const int32_t H = derror.size(2);
|
||||
const int32_t W = derror.size(3);
|
||||
const int32_t H = data.size(2);
|
||||
const int32_t W = data.size(3);
|
||||
// Unpack error shapes
|
||||
const int32_t Bb = derror.size(0);
|
||||
const int32_t K = derror.size(1);
|
||||
@@ -162,10 +167,12 @@ torch::Tensor conv_wgrad(
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
|
||||
const int32_t pad_d = 0;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
|
||||
|
||||
|
||||
// Check
|
||||
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
|
||||
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
|
||||
|
@@ -1,50 +0,0 @@
|
||||
import torch
|
||||
import time
|
||||
torch.manual_seed(0)
|
||||
|
||||
class TritonConv(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
ctx.save_for_backward(input, weight)
|
||||
output = torch.ops.triton.conv_fprop(input, weight)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input = grad_weight = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
|
||||
return grad_input, grad_weight
|
||||
|
||||
|
||||
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
||||
|
||||
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
|
||||
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
|
||||
y_target = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||
|
||||
def run(x, w, conv):
|
||||
y = conv(x, w)
|
||||
loss = (y - y_target).norm(2)
|
||||
loss.backward()
|
||||
return loss, y.clone(), x.grad.clone(), w.grad.clone()
|
||||
|
||||
ttyloss, tty, ttdx, ttdw = run(x, w, TritonConv.apply)
|
||||
x.grad.zero_()
|
||||
w.grad.zero_()
|
||||
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
|
||||
|
||||
|
||||
print((tty - cuy).norm(2))
|
||||
print((ttdx - cudx).norm(2))
|
||||
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
||||
#print(ttdx)
|
||||
#print(cudx)
|
||||
#print(ttdw)
|
||||
#print(cudw)
|
||||
#print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
22
examples/python/pytorch/test.py
Normal file
22
examples/python/pytorch/test.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
|
||||
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
|
||||
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
|
||||
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
|
||||
|
||||
def run(x, w, conv):
|
||||
y = conv(x, w)
|
||||
loss = (y - y_target).norm(2)
|
||||
loss.backward()
|
||||
return loss, y.clone(), x.grad.clone(), w.grad.clone()
|
||||
|
||||
ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0))
|
||||
x.grad.zero_()
|
||||
w.grad.zero_()
|
||||
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0))
|
||||
|
||||
print((tty - cuy).norm(2))
|
||||
print((ttdx - cudx).norm(2))
|
||||
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
|
46
examples/python/pytorch/triton.py
Normal file
46
examples/python/pytorch/triton.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
|
||||
|
||||
class ConvFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, padding):
|
||||
ctx.save_for_backward(input, weight)
|
||||
ctx.padding = padding
|
||||
output = torch.ops.triton.conv_fprop(input, weight, padding, padding)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
padding = ctx.padding
|
||||
grad_input = grad_weight = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding)
|
||||
return grad_input, grad_weight, None
|
||||
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
|
||||
super(Conv2d, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = padding
|
||||
self.weight = torch.nn.Parameter(torch.Tensor(
|
||||
in_channels, kernel_size[0], kernel_size[1], out_channels))
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(self, input):
|
||||
return ConvFunction.apply(input, self.weight, self.padding)
|
||||
|
||||
def reset_parameters(self):
|
||||
n = self.in_channels
|
||||
for k in self.kernel_size:
|
||||
n *= k
|
||||
stdv = 1. / math.sqrt(n)
|
||||
self.weight.data.uniform_(-stdv, stdv)
|
116
include/triton/runtime/jit.h
Normal file
116
include/triton/runtime/jit.h
Normal file
@@ -0,0 +1,116 @@
|
||||
#ifndef TDL_INCLUDE_JIT_H
|
||||
#define TDL_INCLUDE_JIT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/print.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/optimize_dot.h"
|
||||
#include "triton/codegen/optimize_cse.h"
|
||||
#include "triton/codegen/optimize_trans.h"
|
||||
#include "triton/codegen/shmem_allocation.h"
|
||||
#include "triton/codegen/shmem_liveness.h"
|
||||
#include "triton/codegen/shmem_info.h"
|
||||
#include "triton/codegen/shmem_barriers.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include <functional>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
class tune;
|
||||
}
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class context;
|
||||
class metaparameter;
|
||||
}
|
||||
|
||||
class jit {
|
||||
public:
|
||||
struct launch_information{
|
||||
std::vector<unsigned> global_range_size;
|
||||
unsigned num_threads;
|
||||
};
|
||||
typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
|
||||
|
||||
struct passes_wrapper {
|
||||
passes_wrapper(codegen::target* target)
|
||||
: shmem_liveness(&shmem_info),
|
||||
shmem_allocation(&shmem_liveness, &shmem_info),
|
||||
shmem_barriers(&shmem_allocation, &shmem_info),
|
||||
vectorize(&tune),
|
||||
selection(&shmem_allocation, &tune, &shmem_info, target),
|
||||
optimize_dot(&tune),
|
||||
optimize_cse(),
|
||||
optimize_trans(),
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
shmem_allocation.run();
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
vectorize.run(module);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
codegen::shmem_info shmem_info;
|
||||
codegen::shmem_liveness shmem_liveness;
|
||||
codegen::shmem_allocation shmem_allocation;
|
||||
codegen::shmem_barriers shmem_barriers;
|
||||
codegen::vectorize vectorize;
|
||||
codegen::selection selection;
|
||||
codegen::optimize_dot optimize_dot;
|
||||
codegen::optimize_cse optimize_cse;
|
||||
codegen::optimize_trans optimize_trans;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
private:
|
||||
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
|
||||
std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
|
||||
|
||||
public:
|
||||
jit(driver::context* context);
|
||||
~jit();
|
||||
void autotune(const char* name, const char* src, benchmark_t benchmark);
|
||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
|
||||
driver::kernel* get_function(const char* name);
|
||||
launch_information get_launch_info(const char* name);
|
||||
unsigned get_int(const char* name);
|
||||
|
||||
private:
|
||||
std::map<std::string, driver::module*> modules_;
|
||||
driver::context* driver_context_;
|
||||
llvm::LLVMContext llvm_context_;
|
||||
ir::context triton_context_;
|
||||
std::map<std::string, launch_information> launch_info_map_;
|
||||
std::map<std::string, unsigned> global_ints_;
|
||||
std::shared_ptr<triton::codegen::target> target_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -24,10 +24,13 @@ conv::conv(int B, int NC,
|
||||
shapes_c_ = {NB_, NF_, CD_, CH_, CW_};
|
||||
// swap a and c for bprop
|
||||
if(ty_ == BPROP){
|
||||
pad_d_ = (CD_ - AD_ + BD_ - 1) / 2;
|
||||
pad_h_ = (CH_ - AH_ + BH_ - 1) / 2;
|
||||
pad_w_ = (CW_ - AW_ + BW_ - 1) / 2;
|
||||
std::swap(AD_, CD_);
|
||||
std::swap(AH_, CH_);
|
||||
std::swap(AW_, CW_);
|
||||
shapes_a_.swap(shapes_c_);
|
||||
pad_d_ = (CD_*stride_d_ - AD_*upsample_d_ + BD_ - 1 - stride_d_ + 1)/2;
|
||||
pad_h_ = (CH_*stride_h_ - AH_*upsample_h_ + BH_ - 1 - stride_h_ + 1)/2;
|
||||
pad_w_ = (CW_*stride_w_ - AW_*upsample_w_ + BW_ - 1 - stride_w_ + 1)/2;
|
||||
}
|
||||
// swap b and c for wgrad
|
||||
if(ty_ == WGRAD){
|
||||
|
209
lib/runtime/jit.cpp
Normal file
209
lib/runtime/jit.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include <string>
|
||||
#include "triton/ast/ast.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/error.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/PassManager.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Support/TargetRegistry.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetMachine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/CodeGen/TargetPassConfig.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/Transforms/Scalar/EarlyCSE.h"
|
||||
#include "llvm/Analysis/LoopPass.h"
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
using triton::ast::translation_unit;
|
||||
extern translation_unit *ast_root;
|
||||
|
||||
namespace triton {
|
||||
|
||||
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
|
||||
size_t D = ranges.size();
|
||||
std::vector<size_t> values(D, 0);
|
||||
// Start with innermost loop
|
||||
size_t i = D - 1;
|
||||
while(true){
|
||||
//Execute function
|
||||
f(values);
|
||||
//Increment counters
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
return;
|
||||
values[i--] = 0;
|
||||
}
|
||||
i = D - 1;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){
|
||||
//Ranges to iterate over
|
||||
std::vector<size_t> ranges;
|
||||
for(auto const & x: iterates)
|
||||
ranges.push_back(x.size());
|
||||
//Proxy function
|
||||
auto proxy = [&](std::vector<size_t> const & idx){
|
||||
std::vector<T> x(iterates.size());
|
||||
for(size_t i = 0; i < x.size(); ++i)
|
||||
x[i] = iterates[i][idx[i]];
|
||||
f(x);
|
||||
};
|
||||
//Iterate
|
||||
loop_nest(ranges, proxy);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
|
||||
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
launch_information& info = launch_info_map_[result->getName()];
|
||||
info.global_range_size.clear();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
info.num_threads = passes.tune.get_num_threads();
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
|
||||
// create AST from Triton-C source
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
translation_unit *program = ast_root;
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module(name, triton_context_);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
|
||||
jit::jit(driver::context *context): driver_context_(context),
|
||||
target_(context->device()->make_target()) { }
|
||||
|
||||
jit::~jit(){ }
|
||||
|
||||
void jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
|
||||
// find metaparameters
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
auto mps = passes.tune.get_params(tt_module);
|
||||
// create parameter ranges
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
ranges.push_back(mp->get_space());
|
||||
// std::cout << ranges.size() << std::endl;
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
double best = 0;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
i = 0;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(params[i++]);
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.init(tt_module);
|
||||
if(!passes.tune.check_constraints(errors))
|
||||
return;
|
||||
// Deep copy of the module and tuner
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
i = 0;
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
|
||||
mp->set_value(params[i++]);
|
||||
}
|
||||
passes.tune.init(tt_module);
|
||||
passes.target_dependent(tt_module);
|
||||
driver::device* device = driver_context_->device();
|
||||
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
|
||||
return;
|
||||
if(passes.tune.get_num_threads() > device->max_threads_per_block())
|
||||
return;
|
||||
// Compile
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
|
||||
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
|
||||
launch_information info = launch_info_map_.at(name);
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
modules_.insert({name, module.get()});
|
||||
double perf;
|
||||
perf = benchmark(kernel.get(), info);
|
||||
best = std::max(perf, best);
|
||||
std::cout << perf << " [ " << best << " ] " << std::endl;
|
||||
modules_.erase(name);
|
||||
});
|
||||
}
|
||||
|
||||
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms) {
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
unsigned i = 0;
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
|
||||
mp->set_value(params[i++]);
|
||||
passes.tune.init(tt_module);
|
||||
passes.target_dependent(tt_module);
|
||||
// check constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
passes.tune.check_constraints(errors);
|
||||
for(auto x: errors){
|
||||
std::cout << x.first << std::endl;
|
||||
for(auto str: x.second)
|
||||
std::cout << str << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
// triton module -> llvm module
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
// llvm module -> machine code
|
||||
std::string name = tt_module.get_name();
|
||||
modules_.insert({name, driver::module::create(driver_context_, &*ll_module)});
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
}
|
||||
|
||||
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> ¶ms) {
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
add_module(*ptt_module, params);
|
||||
}
|
||||
|
||||
driver::kernel *jit::get_function(const char *name) {
|
||||
return driver::kernel::create(modules_.at(name), name);
|
||||
}
|
||||
|
||||
jit::launch_information jit::get_launch_info(const char *name) {
|
||||
return launch_info_map_.at(name);
|
||||
}
|
||||
|
||||
unsigned jit::get_int(const char *name){
|
||||
return global_ints_.at(name);
|
||||
}
|
||||
|
||||
}
|
Reference in New Issue
Block a user