diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp index 93f42b94e..70555fd0e 100644 --- a/examples/cpp/conv.cpp +++ b/examples/cpp/conv.cpp @@ -10,12 +10,12 @@ int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::jit jit(context); - triton::dnn::conv::type ty = triton::dnn::conv::FPROP; + triton::dnn::conv::type ty = triton::dnn::conv::BPROP; // initialization int32_t B = 4, NF = 32; int32_t D = 1, H = 56, W = 56; int32_t NC = 32, T = 1, R = 3, S = 3; - int32_t pad_d = 0, pad_h = 1, pad_w = 1; + int32_t pad_d = 0, pad_h = 0, pad_w = 0; triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty); // convolution configuration std::vector hc(configuration.c_size()); diff --git a/examples/python/pytorch/bench.py b/examples/python/pytorch/bench.py new file mode 100644 index 000000000..2c8c304b5 --- /dev/null +++ b/examples/python/pytorch/bench.py @@ -0,0 +1,117 @@ +import argparse +import triton +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms + +torch.manual_seed(0) + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 20, (5,5)) + self.conv2 = nn.Conv2d(20, 50, (5,5)) + self.fc1 = nn.Linear(4*4*50, 500) + self.fc2 = nn.Linear(500, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4*4*50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + +def test(args, model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch MNIST Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') + parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument('--momentum', type=float, default=0.5, metavar='M', + help='SGD momentum (default: 0.5)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.batch_size, shuffle=True, **kwargs) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.test_batch_size, shuffle=True, **kwargs) + + + model = Net().to(device) + optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) + + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(args, model, device, test_loader) + + if (args.save_model): + torch.save(model.state_dict(),"mnist_cnn.pt") + +if __name__ == '__main__': + main() diff --git a/examples/python/pytorch/conv.cpp b/examples/python/pytorch/conv.cpp index 577d23ee0..a3002e28a 100644 --- a/examples/python/pytorch/conv.cpp +++ b/examples/python/pytorch/conv.cpp @@ -86,7 +86,8 @@ torch::Tensor conv_common( torch::Tensor conv_fprop( const torch::Tensor data, - const torch::Tensor weight) { + const torch::Tensor weight, + int64_t pad_h, int64_t pad_w) { // Check CHECK_INPUT(data); CHECK_INPUT(weight); @@ -104,7 +105,7 @@ torch::Tensor conv_fprop( const int32_t NF = weight.size(3); // Configuration const int32_t stride_d = 1, stride_h = 1, stride_w = 1; - const int32_t pad_d = 0, pad_h = 1, pad_w = 1; + const int32_t pad_d = 0; // Check AT_CHECK(Ci == Cf, "Number of channels in data and weights must match"); return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight); @@ -112,7 +113,8 @@ torch::Tensor conv_fprop( torch::Tensor conv_bprop( const torch::Tensor derror, - const torch::Tensor weight){ + const torch::Tensor weight, + int64_t pad_h, int64_t pad_w){ // Check CHECK_INPUT(derror); CHECK_INPUT(weight); @@ -131,10 +133,12 @@ torch::Tensor conv_bprop( // Compute M, P, Q const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; const int32_t stride_d = 1, stride_h = 1, stride_w = 1; - const int32_t pad_d = 0, pad_h = 1, pad_w = 1; - const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d; - const int32_t H = P*stride_d + R - 1 - 2*pad_h + stride_h - 1 / upsample_h; - const int32_t W = Q*stride_d + S - 1 - 2*pad_w + stride_w - 1 / upsample_w; + int32_t pad_d = 0; + const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d; + const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h; + const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w; + + // Check AT_CHECK(Ki == Kw, "Number of channels in error and weights must match"); return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight); @@ -142,17 +146,18 @@ torch::Tensor conv_bprop( torch::Tensor conv_wgrad( const torch::Tensor data, - const torch::Tensor derror + const torch::Tensor derror, + int64_t pad_h, int64_t pad_w ){ // Check CHECK_INPUT(data); CHECK_INPUT(derror); // Unpack data shapes - const int32_t Ba = derror.size(0); - const int32_t C = derror.size(1); + const int32_t Ba = data.size(0); + const int32_t C = data.size(1); const int32_t D = 1; - const int32_t H = derror.size(2); - const int32_t W = derror.size(3); + const int32_t H = data.size(2); + const int32_t W = data.size(3); // Unpack error shapes const int32_t Bb = derror.size(0); const int32_t K = derror.size(1); @@ -162,10 +167,12 @@ torch::Tensor conv_wgrad( // Compute M, P, Q const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; const int32_t stride_d = 1, stride_h = 1, stride_w = 1; - const int32_t pad_d = 0, pad_h = 1, pad_w = 1; - const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d; - const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h; - const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w; + const int32_t pad_d = 0; + const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d; + const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h; + const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w; + + // Check AT_CHECK(Ba == Bb, "Number of channels in error and weights must match"); return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror); diff --git a/examples/python/pytorch/main.py b/examples/python/pytorch/main.py deleted file mode 100644 index c4601fe0f..000000000 --- a/examples/python/pytorch/main.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import time -torch.manual_seed(0) - -class TritonConv(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, weight): - ctx.save_for_backward(input, weight) - output = torch.ops.triton.conv_fprop(input, weight) - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - grad_input = grad_weight = None - if ctx.needs_input_grad[0]: - grad_input = torch.ops.triton.conv_bprop(grad_output, weight) - if ctx.needs_input_grad[1]: - grad_weight = torch.ops.triton.conv_wgrad(input, grad_output) - return grad_input, grad_weight - - -torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so") - -x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True) -w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True) -cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True) -y_target = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True) - -def run(x, w, conv): - y = conv(x, w) - loss = (y - y_target).norm(2) - loss.backward() - return loss, y.clone(), x.grad.clone(), w.grad.clone() - -ttyloss, tty, ttdx, ttdw = run(x, w, TritonConv.apply) -x.grad.zero_() -w.grad.zero_() -culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1)) - - -print((tty - cuy).norm(2)) -print((ttdx - cudx).norm(2)) -print((ttdw.permute(3,0,1,2) - cudw).norm(2)) -#print(ttdx) -#print(cudx) -#print(ttdw) -#print(cudw) -#print((ttdw.permute(3,0,1,2) - cudw).norm(2)) diff --git a/examples/python/pytorch/test.py b/examples/python/pytorch/test.py new file mode 100644 index 000000000..5086f8a5d --- /dev/null +++ b/examples/python/pytorch/test.py @@ -0,0 +1,22 @@ +import torch +import triton + +x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True) +w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True) +cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True) +y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True) + +def run(x, w, conv): + y = conv(x, w) + loss = (y - y_target).norm(2) + loss.backward() + return loss, y.clone(), x.grad.clone(), w.grad.clone() + +ttyloss, tty, ttdx, ttdw = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, 0)) +x.grad.zero_() +w.grad.zero_() +culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=0)) + +print((tty - cuy).norm(2)) +print((ttdx - cudx).norm(2)) +print((ttdw.permute(3,0,1,2) - cudw).norm(2)) diff --git a/examples/python/pytorch/triton.py b/examples/python/pytorch/triton.py new file mode 100644 index 000000000..57e17d515 --- /dev/null +++ b/examples/python/pytorch/triton.py @@ -0,0 +1,46 @@ +import torch +import math + +torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so") + +class ConvFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, weight, padding): + ctx.save_for_backward(input, weight) + ctx.padding = padding + output = torch.ops.triton.conv_fprop(input, weight, padding, padding) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + padding = ctx.padding + grad_input = grad_weight = None + if ctx.needs_input_grad[0]: + grad_input = torch.ops.triton.conv_bprop(grad_output, weight, padding, padding) + if ctx.needs_input_grad[1]: + grad_weight = torch.ops.triton.conv_wgrad(input, grad_output, padding, padding) + return grad_input, grad_weight, None + + +class Conv2d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding = 0): + super(Conv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = torch.nn.Parameter(torch.Tensor( + in_channels, kernel_size[0], kernel_size[1], out_channels)) + self.reset_parameters() + + def forward(self, input): + return ConvFunction.apply(input, self.weight, self.padding) + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h new file mode 100644 index 000000000..a114cca8c --- /dev/null +++ b/include/triton/runtime/jit.h @@ -0,0 +1,116 @@ +#ifndef TDL_INCLUDE_JIT_H +#define TDL_INCLUDE_JIT_H + +#include +#include +#include "llvm/IR/LLVMContext.h" +#include "triton/ir/context.h" +#include "triton/ir/print.h" +#include "triton/driver/module.h" +#include "triton/driver/kernel.h" +#include "triton/codegen/selection.h" +#include "triton/codegen/tune.h" +#include "triton/codegen/optimize_dot.h" +#include "triton/codegen/optimize_cse.h" +#include "triton/codegen/optimize_trans.h" +#include "triton/codegen/shmem_allocation.h" +#include "triton/codegen/shmem_liveness.h" +#include "triton/codegen/shmem_info.h" +#include "triton/codegen/shmem_barriers.h" +#include "triton/codegen/target.h" +#include "triton/codegen/vectorize.h" +#include + +namespace llvm { + class Module; +} + +namespace triton { + +namespace codegen{ +class tune; +} + +namespace ir { +class module; +class context; +class metaparameter; +} + +class jit { +public: + struct launch_information{ + std::vector global_range_size; + unsigned num_threads; + }; + typedef std::function benchmark_t; + + struct passes_wrapper { + passes_wrapper(codegen::target* target) + : shmem_liveness(&shmem_info), + shmem_allocation(&shmem_liveness, &shmem_info), + shmem_barriers(&shmem_allocation, &shmem_info), + vectorize(&tune), + selection(&shmem_allocation, &tune, &shmem_info, target), + optimize_dot(&tune), + optimize_cse(), + optimize_trans(), + target_(target) { } + + void target_independent(ir::module &module) { + optimize_dot.run(module); + optimize_trans.run(module); + } + + void target_dependent(ir::module &module) { + if(target_->is_gpu()){ + shmem_info.run(module); + shmem_liveness.run(module); + shmem_allocation.run(); + shmem_barriers.run(module); + } + vectorize.run(module); + } + + codegen::tune tune; + codegen::shmem_info shmem_info; + codegen::shmem_liveness shmem_liveness; + codegen::shmem_allocation shmem_allocation; + codegen::shmem_barriers shmem_barriers; + codegen::vectorize vectorize; + codegen::selection selection; + codegen::optimize_dot optimize_dot; + codegen::optimize_cse optimize_cse; + codegen::optimize_trans optimize_trans; + codegen::target* target_; + }; + +private: + std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true); + std::unique_ptr make_llvm_module(triton::ir::module &module, passes_wrapper &passes); + std::unique_ptr make_triton_module(const char* name, const char* src); + +public: + jit(driver::context* context); + ~jit(); + void autotune(const char* name, const char* src, benchmark_t benchmark); + void add_module(ir::module &module, const std::vector& params = {}); + void add_module(const char* name, const char* src, const std::vector& params = {}); + driver::kernel* get_function(const char* name); + launch_information get_launch_info(const char* name); + unsigned get_int(const char* name); + +private: + std::map modules_; + driver::context* driver_context_; + llvm::LLVMContext llvm_context_; + ir::context triton_context_; + std::map launch_info_map_; + std::map global_ints_; + std::shared_ptr target_; +}; + + +} + +#endif diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp index 2c551241c..621f4f14a 100644 --- a/lib/dnn/conv.cpp +++ b/lib/dnn/conv.cpp @@ -24,10 +24,13 @@ conv::conv(int B, int NC, shapes_c_ = {NB_, NF_, CD_, CH_, CW_}; // swap a and c for bprop if(ty_ == BPROP){ - pad_d_ = (CD_ - AD_ + BD_ - 1) / 2; - pad_h_ = (CH_ - AH_ + BH_ - 1) / 2; - pad_w_ = (CW_ - AW_ + BW_ - 1) / 2; + std::swap(AD_, CD_); + std::swap(AH_, CH_); + std::swap(AW_, CW_); shapes_a_.swap(shapes_c_); + pad_d_ = (CD_*stride_d_ - AD_*upsample_d_ + BD_ - 1 - stride_d_ + 1)/2; + pad_h_ = (CH_*stride_h_ - AH_*upsample_h_ + BH_ - 1 - stride_h_ + 1)/2; + pad_w_ = (CW_*stride_w_ - AW_*upsample_w_ + BW_ - 1 - stride_w_ + 1)/2; } // swap b and c for wgrad if(ty_ == WGRAD){ diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp new file mode 100644 index 000000000..58d3aef73 --- /dev/null +++ b/lib/runtime/jit.cpp @@ -0,0 +1,209 @@ +#include +#include "triton/ast/ast.h" +#include "triton/codegen/target.h" +#include "triton/ir/context.h" +#include "triton/ir/context_impl.h" +#include "triton/driver/device.h" +#include "triton/driver/error.h" +#include "triton/runtime/jit.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/Analysis/LoopPass.h" + +typedef struct yy_buffer_state * YY_BUFFER_STATE; +extern int yyparse(); +extern YY_BUFFER_STATE yy_scan_string(const char * str); +extern void yy_delete_buffer(YY_BUFFER_STATE buffer); +using triton::ast::translation_unit; +extern translation_unit *ast_root; + +namespace triton { + +void loop_nest(std::vector const & ranges, std::function const &)> const & f){ + size_t D = ranges.size(); + std::vector values(D, 0); + // Start with innermost loop + size_t i = D - 1; + while(true){ + //Execute function + f(values); + //Increment counters + while(values[i]++ == ranges[i] - 1){ + if(i == 0) + return; + values[i--] = 0; + } + i = D - 1; + } +} + +template +void loop_nest(std::vector> const & iterates, std::function)> const & f){ + //Ranges to iterate over + std::vector ranges; + for(auto const & x: iterates) + ranges.push_back(x.size()); + //Proxy function + auto proxy = [&](std::vector const & idx){ + std::vector x(iterates.size()); + for(size_t i = 0; i < x.size(); ++i) + x[i] = iterates[i][idx[i]]; + f(x); + }; + //Iterate + loop_nest(ranges, proxy); +} + + + + +std::unique_ptr jit::make_llvm_module(ir::module &module, passes_wrapper &passes) { + llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_); + passes.selection.run(module, *result); + // launch information + launch_information& info = launch_info_map_[result->getName()]; + info.global_range_size.clear(); + for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++) + info.global_range_size.push_back(passes.tune.get_global_range_size(i)); + info.num_threads = passes.tune.get_num_threads(); + return std::unique_ptr(result); +} + +std::unique_ptr jit::make_triton_module(const char *name, const char *src) { + // create AST from Triton-C source + YY_BUFFER_STATE buffer = yy_scan_string(src); + yyparse(); + yy_delete_buffer(buffer); + translation_unit *program = ast_root; + // create Triton-IR from AST + ir::module* module = new ir::module(name, triton_context_); + program->codegen(module); + return std::unique_ptr(module); +} + + +jit::jit(driver::context *context): driver_context_(context), + target_(context->device()->make_target()) { } + +jit::~jit(){ } + +void jit::autotune(const char *name, const char *src, benchmark_t benchmark) { + // find metaparameters + auto ptt_module = make_triton_module(name, src); + ir::module &tt_module = *ptt_module; + // set parameters + passes_wrapper passes(target_.get()); + passes.target_independent(tt_module); + passes.tune.run(tt_module); + auto mps = passes.tune.get_params(tt_module); + // create parameter ranges + std::vector> ranges; + for(ir::metaparameter *mp: mps) + ranges.push_back(mp->get_space()); +// std::cout << ranges.size() << std::endl; + // iterate over parameters + unsigned i; + double best = 0; + loop_nest(ranges, [&](const std::vector params){ + std::map> errors; + i = 0; + for(ir::metaparameter *mp: mps) + mp->set_value(params[i++]); + passes.target_independent(tt_module); + passes.tune.init(tt_module); + if(!passes.tune.check_constraints(errors)) + return; + // Deep copy of the module and tuner + auto ptt_module = make_triton_module(name, src); + ir::module &tt_module = *ptt_module; + passes_wrapper passes(target_.get()); + passes.target_independent(tt_module); + passes.tune.run(tt_module); + i = 0; + for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){ + mp->set_value(params[i++]); + } + passes.tune.init(tt_module); + passes.target_dependent(tt_module); + driver::device* device = driver_context_->device(); + if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory()) + return; + if(passes.tune.get_num_threads() > device->max_threads_per_block()) + return; + // Compile + auto ll_module = make_llvm_module(tt_module, passes); + std::unique_ptr module(driver::module::create(driver_context_, &*ll_module)); + std::unique_ptr kernel(driver::kernel::create(module.get(), name)); + launch_information info = launch_info_map_.at(name); + for(unsigned p: params) + std::cout << p << " " << std::flush; + // add globals + for(auto x: tt_module.globals()) + global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); + modules_.insert({name, module.get()}); + double perf; + perf = benchmark(kernel.get(), info); + best = std::max(perf, best); + std::cout << perf << " [ " << best << " ] " << std::endl; + modules_.erase(name); + }); +} + +void jit::add_module(ir::module &tt_module, const std::vector ¶ms) { + // set parameters + passes_wrapper passes(target_.get()); + passes.target_independent(tt_module); + passes.tune.run(tt_module); + unsigned i = 0; + for(ir::metaparameter* mp: passes.tune.get_params(tt_module)) + mp->set_value(params[i++]); + passes.tune.init(tt_module); + passes.target_dependent(tt_module); + // check constraints + std::map> errors; + passes.tune.check_constraints(errors); + for(auto x: errors){ + std::cout << x.first << std::endl; + for(auto str: x.second) + std::cout << str << std::endl; + } + if(errors.size()) + throw std::runtime_error("invalid parameters"); + // triton module -> llvm module + auto ll_module = make_llvm_module(tt_module, passes); + // llvm module -> machine code + std::string name = tt_module.get_name(); + modules_.insert({name, driver::module::create(driver_context_, &*ll_module)}); + // add globals + for(auto x: tt_module.globals()) + global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); +} + +void jit::add_module(const char *name, const char *src, const std::vector ¶ms) { + auto ptt_module = make_triton_module(name, src); + add_module(*ptt_module, params); +} + +driver::kernel *jit::get_function(const char *name) { + return driver::kernel::create(modules_.at(name), name); +} + +jit::launch_information jit::get_launch_info(const char *name) { + return launch_info_map_.at(name); +} + +unsigned jit::get_int(const char *name){ + return global_ints_.at(name); +} + +}