diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 025ca0d4b..93f42b94e 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
@@ -10,11 +10,11 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   triton::jit jit(context);
-  triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
+  triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
-  int32_t B = 32, NF = 128;
+  int32_t B = 4, NF = 32;
   int32_t D = 1, H = 56, W = 56;
-  int32_t NC = 128, T = 1, R = 3, S = 3;
+  int32_t NC = 32, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 1, pad_w = 1;
   triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
   // convolution configuration
@@ -45,7 +45,7 @@ int main() {
     unsigned TN = info.global_range_size[1];
     unsigned nthreads = info.num_threads;
     std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-    configuration.init(stream, jit);
+    configuration.init(stream, (triton::driver::cu_module*)kernel->module());
     stream->synchronize();
     configuration.set_arg(kernel, da, db, dc);
     stream->enqueue(kernel, grid, {nthreads, 1, 1});
@@ -55,7 +55,7 @@ int main() {
     return configuration.get_nflops() / ts * 1e-3;
   };
   std::string src = configuration.src();
-  jit.autotune("conv", src.c_str(), benchmark);
+//  jit.autotune("conv", src.c_str(), benchmark);
   jit.add_module("conv", src.c_str(), configuration.default_params());
   triton::driver::kernel* kernel = jit.get_function("conv");
   triton::jit::launch_information info = jit.get_launch_info("conv");
diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index 3dde373ef..2beee1c8d 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/gemm.h"
diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index 026cdfaea..4391f775b 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 
@@ -158,8 +158,8 @@ int main() {
   unsigned TN = info.global_range_size[1];
   unsigned nthreads = info.num_threads;
   // initialize constant memory
-  triton::driver::buffer* delta = jit.get_buffer("delta");
-  triton::driver::buffer* masks = jit.get_buffer("masks");
+  triton::driver::buffer* delta = ((triton::driver::cu_module*)kernel->module())->symbol("delta");
+  triton::driver::buffer* masks = ((triton::driver::cu_module*)kernel->module())->symbol("masks");
   stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
   stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
   stream->synchronize();
diff --git a/examples/python/pytorch/conv.cpp b/examples/python/pytorch/conv.cpp
index f8636c482..577d23ee0 100644
--- a/examples/python/pytorch/conv.cpp
+++ b/examples/python/pytorch/conv.cpp
@@ -2,7 +2,7 @@
 #include <torch/script.h>
 #include "ATen/cuda/CUDAContext.h"
 #include <vector>
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
@@ -10,6 +10,16 @@
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
+typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t,
+                   triton::dnn::conv::type> conv_key_t;
+
+static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
+static std::map<conv_key_t, std::unique_ptr<triton::jit>>          m_jit;
+static std::map<conv_key_t, std::unique_ptr<triton::dnn::conv>>    m_config;
+
 torch::Tensor conv_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t NF,
@@ -18,41 +28,59 @@ torch::Tensor conv_common(
     triton::dnn::conv::type ty,
     torch::Tensor torcha, torch::Tensor torchb
     ) {
-  // Configuration
-  triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty);
+  // Wrap CUDA handles
+  c10::DeviceIndex device = torcha.storage().device().index();
+  // Get stream
+  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
+  triton::driver::stream* stream;
+  if(m_stream.find(custream) == m_stream.end())
+    stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
+  else
+    stream = m_stream.at(custream).get();
+  // Get context
+  triton::driver::context* ctx = stream->context();
+  // Get configuration
+  conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
+  triton::dnn::conv* configuration;
+  if(m_config.find(key) == m_config.end())
+    configuration = m_config.emplace(key, new triton::dnn::conv(
+                                            B, C, D, H, W, T, R, S, NF,
+                                            stride_d, stride_h, stride_w,
+                                            pad_d, pad_h, pad_w, ty)).first->second.get();
+  else
+    configuration = m_config.at(key).get();
+  // Get JIT
+  triton::jit* jit;
+  if(m_jit.find(key) == m_jit.end()){
+    jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
+    std::string src = configuration->src();
+    jit->add_module("conv", src.c_str(), configuration->default_params());
+  }
+  else
+    jit = m_jit.at(key).get();
+  // Get memory
+  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
+  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   // Allocate output
-  std::vector<int32_t> c_shapes = configuration.c_shapes();
+  std::vector<int32_t> c_shapes = configuration->c_shapes();
   torch::Tensor torchc;
   if(ty == triton::dnn::conv::WGRAD)
     torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
   else
     torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
-  // Wrap CUDA handles
-  c10::DeviceIndex device = torchc.storage().device().index();
-  triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
-  triton::driver::stream* stream = &sstream;
-  triton::driver::context* ctx = stream->context();
-  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
-  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
-  stream->synchronize();
-  // Create JIT
-  triton::jit jit(ctx);
-  std::string src = configuration.src();
-  jit.add_module("conv", src.c_str(), configuration.default_params());
-  triton::driver::kernel* kernel = jit.get_function("conv");
-  triton::jit::launch_information info = jit.get_launch_info("conv");
+  // Add module to JIT
+  triton::driver::kernel* kernel = jit->get_function("conv");
+  triton::jit::launch_information info = jit->get_launch_info("conv");
   // launch info
   unsigned TM = info.global_range_size[0];
   unsigned TN = info.global_range_size[1];
   // launch info
-  configuration.init(stream, jit);
+  configuration->init(stream, (triton::driver::cu_module*)kernel->module());
   unsigned nthreads = info.num_threads;
-  std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-  configuration.set_arg(kernel, &a, &b, &c);
-  stream->synchronize();
+  std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
+  configuration->set_arg(kernel, &a, &b, &c);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
-  stream->synchronize();
   return torchc;
 }
diff --git a/examples/python/pytorch/main.py b/examples/python/pytorch/main.py
index c0568f8b4..c4601fe0f 100644
--- a/examples/python/pytorch/main.py
+++ b/examples/python/pytorch/main.py
@@ -1,4 +1,5 @@
 import torch
+import time
 torch.manual_seed(0)
 
 class TritonConv(torch.autograd.Function):
@@ -14,9 +15,9 @@ class TritonConv(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         grad_input = grad_weight = None
         if ctx.needs_input_grad[0]:
-            grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
+            grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
         if ctx.needs_input_grad[1]:
-            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
+            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
         return grad_input, grad_weight
 
@@ -38,6 +39,7 @@
 x.grad.zero_()
 w.grad.zero_()
 culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
+print((tty - cuy).norm(2))
 print((ttdx - cudx).norm(2))
 print((ttdw.permute(3,0,1,2) - cudw).norm(2))
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index 20b430187..d01007aa6 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -4,7 +4,6 @@
 #include <vector>
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
-#include "triton/jit.h"
 
 namespace triton{
 namespace dnn{
@@ -34,7 +33,7 @@ public:
   // initialize
   void build_deltas();
   void build_masks();
-  void init(driver::stream *stream, triton::jit &jit);
+  void init(driver::stream *stream, driver::cu_module *module);
   std::array<size_t, 3> get_grid(size_t TM, size_t TN);
   void set_arg(driver::kernel *kernel,
                driver::buffer *a, driver::buffer *b, driver::buffer *c);
@@ -44,7 +43,144 @@ public:
   std::vector<unsigned> default_params();
 
   // source
-  std::string src();
+  std::string src(){
+    bool is_wgrad = ty_ == WGRAD;
+    std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
+    std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
+    std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
+    std::string ldb0 = b_trans_ ? "*ldb_s" : "";
+    std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
+    std::string useb = b_trans_ ? "trans(b)" : "b";
+    std::string flipr = b_trans_ ? "" : "BH - 1 -";
+    std::string flips = b_trans_ ? "" : "BW - 1 -";
+    std::string ax = b_trans_ ? "crs" : "rsc";
+    std::vector<std::string> redax;
+    if(b_trans_)
+      redax = {"C", "BH", "BW"};
+    else
+      redax = {"BH", "BW", "N"};
+    std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
+    std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
+    std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
+    std::string masks_mem = is_mask_cst_? "__constant__" : "";
+
+    std::string res =
+        R"(
+    const tunable int32 TM = {16, 32, 64};
+    const tunable int32 TN = {16, 32, 64};
+    const tunable int32 TK = {8};
+    )";
+    if(is_a_deltas_cst)
+      res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
+    if(is_wgrad && is_b_deltas_cst_)
+      res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
+    if(is_mask_cst_)
+      res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
+    res += R"(
+
+    void conv(read_only restrict fp32 *a,
+              read_only restrict fp32 *b,
+              fp32 *c,
+              int32 M, int32 N, int32 K,
+              int32 AH, int32 AW,
+              int32 BH, int32 BW,
+              int32 CH, int32 CW,
+              int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
+              int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
+              int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
+              int32 pad_h, int32 pad_w)";
+    if(!is_a_deltas_cst)
+      res += ", int32* delta";
+    if(is_wgrad && !is_b_deltas_cst_)
+      res += ", int32* b_delta";
+    if(!is_mask_cst_)
+      res += ", int32* masks";
+    res += R"(){
+      int32 rxa[TM] = get_global_range[TM](0);
+      int32 rb0[TN] = get_global_range[TN](1);
+      int32 rka[TK] = 0 ... TK;
+      int32 rkb[TK] = 0 ... TK;
+      fp32 C[TM, TN] = 0;
+      int32 ldlut = )" + std::to_string(Fs_) + R"(;
+      int32 rabh[TM] = rxa / CW;
+      int32 raw[TM] = rxa % CW - pad_w;
+      int32 rab[TM] = rabh / CH;
+      int32 rah[TM] = rabh % CH - pad_h;
+      int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
+      int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
+      int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
+      int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+      int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+      rar = )" + flipr + R"( rar;
+      ras = )" + flips + R"( ras;
+      int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
+      fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+    if(ty_ == WGRAD){
+      res += R"(
+      int32 rbcr[TK] = rkb / BW;
+      int32 rbs[TK] = rkb % BW;
+      int32 rbc[TK] = rbcr / BH;
+      int32 rbr[TK] = rbcr % BH;
+      int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
+      )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
+      int32 db[TK] = *pdb;)";
+    }
+    else{
+      res += R"(
+      int32 rb1[TK] = rkb;)";
+    }
+    res += R"(
+      fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
+      )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
+      )" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
+      int32 d[TK] = *pd;
+      int32 incd[TK] = *pincd;
+      int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
+      int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
+      )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
+      )" + a_delta_mem + R"( int32* pincm[TM] = delta;
+      int32 incm[TM] = *pincm;
+      int32 checka0[TM] = *pm;
+      int32 checka1[TK] = 1 << rka;
+      int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+      fp32 a[TM, TK] = checka ? *pa : 0;
+      fp32 b)" + BS + R"( = *pb;
+      for(int32 k = K; k > 0; k = k - TK){
+        C = dot(a, )" + useb + R"(, C);
+        pa = pa + d[newaxis, :];
+        pb = pb + )" + inc_pb + R"(;
+        b = *pb;
+        pd = pd + incd;)";
+    if(ty_ == WGRAD){
+      res += R"(
+        pdb = pdb + TK;
+        db = *pdb;)";
+    }
+    res += R"(
+        pincd = pincd + incd;
+        d = *pd;
+        incd = *pincd;
+        pm = pm + incm;
+        pincm = pincm + incm;
+        incm = *pincm;
+        checka0 = *pm;
+        checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+        checka = checka && (k > TK);
+        a = checka ? *pa : 0;
+      }
+      int32 rxc[TM] = get_global_range[TM](0);
+      int32 rc1[TN] = get_global_range[TN](1);
+      int32 rcn[TM] = rxc / (CH*CW);
+      int32 rcpq[TM] = rxc % (CH*CW);
+      int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
+      fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
+      int1 checkc0[TM] = rxc < M;
+      int1 checkc1[TN] = rc1 < N;
+      int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+      @checkc *pc = C;
+    })";
+    return res;
+  }
 
   // cpu check
   template<class IN_DTYPE, class OUT_DTYPE>
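
Note: the src() generator moved inline above assembles the Triton-C kernel by splicing configuration-time values into raw-string fragments: LUT sizes via std::to_string, the optional __constant__ qualifier, and transposition-dependent axis names. The reduced sketch below shows only that stitching pattern; make_kernel and its parameters are illustrative, not the real generator.

    #include <cstddef>
    #include <iostream>
    #include <string>

    std::string make_kernel(bool use_constant_mem, size_t lut_size) {
      // the qualifier is chosen at configuration time, like a_delta_mem above
      std::string mem = use_constant_mem ? "__constant__" : "";
      std::string res = R"(
    const tunable int32 TK = {8};
    )";
      if(use_constant_mem)  // declare the lookup table in constant memory
        res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(lut_size) + "];\n";
      res += mem + R"( int32* pd[TK] = delta;)";
      return res;
    }

    int main() {
      std::cout << make_kernel(true, 9) << std::endl;
    }
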
diff --git a/include/triton/jit.h b/include/triton/jit.h
deleted file mode 100644
index a3e554c67..000000000
--- a/include/triton/jit.h
+++ /dev/null
@@ -1,117 +0,0 @@
-#ifndef TDL_INCLUDE_JIT_H
-#define TDL_INCLUDE_JIT_H
-
-#include <map>
-#include <string>
-#include "llvm/IR/LLVMContext.h"
-#include "triton/ir/context.h"
-#include "triton/ir/print.h"
-#include "triton/driver/module.h"
-#include "triton/driver/kernel.h"
-#include "triton/codegen/selection.h"
-#include "triton/codegen/tune.h"
-#include "triton/codegen/optimize_dot.h"
-#include "triton/codegen/optimize_cse.h"
-#include "triton/codegen/optimize_trans.h"
-#include "triton/codegen/shmem_allocation.h"
-#include "triton/codegen/shmem_liveness.h"
-#include "triton/codegen/shmem_info.h"
-#include "triton/codegen/shmem_barriers.h"
-#include "triton/codegen/target.h"
-#include "triton/codegen/vectorize.h"
-#include <functional>
-
-namespace llvm {
-  class Module;
-}
-
-namespace triton {
-
-namespace codegen{
-class tune;
-}
-
-namespace ir {
-class module;
-class context;
-class metaparameter;
-}
-
-class jit {
-public:
-  struct launch_information{
-    std::vector<unsigned> global_range_size;
-    unsigned num_threads;
-  };
-  typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
-
-  struct passes_wrapper {
-    passes_wrapper(codegen::target* target)
-        : shmem_liveness(&shmem_info),
-          shmem_allocation(&shmem_liveness, &shmem_info),
-          shmem_barriers(&shmem_allocation, &shmem_info),
-          vectorize(&tune),
-          selection(&shmem_allocation, &tune, &shmem_info, target),
-          optimize_dot(&tune),
-          optimize_cse(),
-          optimize_trans(),
-          target_(target) { }
-
-    void target_independent(ir::module &module) {
-      optimize_dot.run(module);
-      optimize_trans.run(module);
-    }
-
-    void target_dependent(ir::module &module) {
-      if(target_->is_gpu()){
-        shmem_info.run(module);
-        shmem_liveness.run(module);
-        shmem_allocation.run();
-        shmem_barriers.run(module);
-      }
-      vectorize.run(module);
-    }
-
-    codegen::tune tune;
-    codegen::shmem_info shmem_info;
-    codegen::shmem_liveness shmem_liveness;
-    codegen::shmem_allocation shmem_allocation;
-    codegen::shmem_barriers shmem_barriers;
-    codegen::vectorize vectorize;
-    codegen::selection selection;
-    codegen::optimize_dot optimize_dot;
-    codegen::optimize_cse optimize_cse;
-    codegen::optimize_trans optimize_trans;
-    codegen::target* target_;
-  };
-
-private:
-  std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
-  std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
-  std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
-
-public:
-  jit(driver::context* context);
-  ~jit();
-  void autotune(const char* name, const char* src, benchmark_t benchmark);
-  void add_module(ir::module &module, const std::vector<unsigned>& params = {});
-  void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
-  driver::kernel* get_function(const char* name);
-  launch_information get_launch_info(const char* name);
-  unsigned get_int(const char* name);
-  driver::buffer* get_buffer(const char* name);
-
-private:
-  std::vector<driver::module*> modules_;
-  driver::context* driver_context_;
-  llvm::LLVMContext llvm_context_;
-  ir::context triton_context_;
-  std::map<std::string, launch_information> launch_info_map_;
-  std::map<std::string, unsigned> global_ints_;
-  std::unique_ptr<codegen::target> target_;
-};
-
-
-}
-
-#endif
diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp
index be47b95c5..2c551241c 100644
--- a/lib/dnn/conv.cpp
+++ b/lib/dnn/conv.cpp
@@ -207,7 +207,7 @@ std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN)
 
 size_t conv::get_nflops()
 { return 2.*M_*N_*K_; }
 
-void conv::init(driver::stream *stream, triton::jit &jit) {
+void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
   auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
     if(host.empty())
       return nullptr;
@@ -215,7 +215,7 @@ void conv::init(driver::stream *stream, triton::jit &jit) {
     // get buffer
     triton::driver::buffer* buffer;
     if(is_cst)
-      buffer = jit.get_buffer(name);
+      buffer = module->symbol(name);
     else
       buffer = triton::driver::buffer::create(stream->context(), nbytes);
     // copy
@@ -306,145 +306,6 @@ std::vector<unsigned> conv::default_params() {
 }
 
-std::string conv::src() {
-  bool is_wgrad = ty_ == WGRAD;
-  std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
-  std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
-  std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
-  std::string ldb0 = b_trans_ ? "*ldb_s" : "";
-  std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
-  std::string useb = b_trans_ ? "trans(b)" : "b";
-  std::string flipr = b_trans_ ? "" : "BH - 1 -";
-  std::string flips = b_trans_ ? "" : "BW - 1 -";
-  std::string ax = b_trans_ ? "crs" : "rsc";
-  std::vector<std::string> redax;
-  if(b_trans_)
-    redax = {"C", "BH", "BW"};
-  else
-    redax = {"BH", "BW", "N"};
-  std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
-  std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
-  std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
-  std::string masks_mem = is_mask_cst_? "__constant__" : "";
-
-  std::string res =
-      R"(
-const tunable int32 TM = {16, 32, 64};
-const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {8};
-)";
-if(is_a_deltas_cst)
-  res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
-if(is_wgrad && is_b_deltas_cst_)
-  res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
-if(is_mask_cst_)
-  res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
-res += R"(
-
- void conv(read_only restrict fp32 *a,
-           read_only restrict fp32 *b,
-           fp32 *c,
-           int32 M, int32 N, int32 K,
-           int32 AH, int32 AW,
-           int32 BH, int32 BW,
-           int32 CH, int32 CW,
-           int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
-           int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
-           int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
-           int32 pad_h, int32 pad_w)";
-if(!is_a_deltas_cst)
-  res += ", int32* delta";
-if(is_wgrad && !is_b_deltas_cst_)
-  res += ", int32* b_delta";
-if(!is_mask_cst_)
-  res += ", int32* masks";
-  res += R"(){
-    int32 rxa[TM] = get_global_range[TM](0);
-    int32 rb0[TN] = get_global_range[TN](1);
-    int32 rka[TK] = 0 ... TK;
-    int32 rkb[TK] = 0 ... TK;
-    fp32 C[TM, TN] = 0;
-    int32 ldlut = )" + std::to_string(Fs_) + R"(;
-    int32 rabh[TM] = rxa / CW;
-    int32 raw[TM] = rxa % CW - pad_w;
-    int32 rab[TM] = rabh / CH;
-    int32 rah[TM] = rabh % CH - pad_h;
-    int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
-    int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
-    int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
-    int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-    int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
-    rar = )" + flipr + R"( rar;
-    ras = )" + flips + R"( ras;
-    int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-    fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
-if(ty_ == WGRAD){
-  res += R"(
-    int32 rbcr[TK] = rkb / BW;
-    int32 rbs[TK] = rkb % BW;
-    int32 rbc[TK] = rbcr / BH;
-    int32 rbr[TK] = rbcr % BH;
-    int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
-    )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
-    int32 db[TK] = *pdb;)";
-}
-else{
-res += R"(
-    int32 rb1[TK] = rkb;)";
-}
-res += R"(
-    fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
-    )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
-    )" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
-    int32 d[TK] = *pd;
-    int32 incd[TK] = *pincd;
-    int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
-    int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
-    )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
-    )" + a_delta_mem + R"( int32* pincm[TM] = delta;
-    int32 incm[TM] = *pincm;
-    int32 checka0[TM] = *pm;
-    int32 checka1[TK] = 1 << rka;
-    int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-    fp32 a[TM, TK] = checka ? *pa : 0;
-    fp32 b)" + BS + R"( = *pb;
-    for(int32 k = K; k > 0; k = k - TK){
-      C = dot(a, )" + useb + R"(, C);
-      pa = pa + d[newaxis, :];
-      pb = pb + )" + inc_pb + R"(;
-      b = *pb;
-      pd = pd + incd;)";
-if(ty_ == WGRAD){
-  res += R"(
-      pdb = pdb + TK;
-      db = *pdb;)";
-}
-  res += R"(
-      pincd = pincd + incd;
-      d = *pd;
-      incd = *pincd;
-      pm = pm + incm;
-      pincm = pincm + incm;
-      incm = *pincm;
-      checka0 = *pm;
-      checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-      checka = checka && (k > TK);
-      a = checka ? *pa : 0;
-    }
-    int32 rxc[TM] = get_global_range[TM](0);
-    int32 rc1[TN] = get_global_range[TN](1);
-    int32 rcn[TM] = rxc / (CH*CW);
-    int32 rcpq[TM] = rxc % (CH*CW);
-    int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
-    fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
-    int1 checkc0[TM] = rxc < M;
-    int1 checkc1[TN] = rc1 < N;
-    int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-    @checkc *pc = C;
-})";
-  return res;
-}
-
 template<class IN_DTYPE, class OUT_DTYPE>
 void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
 {
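
Note: conv::init() above now resolves __constant__ lookup tables through the compiled module's symbol table instead of going through the jit object, mirroring the shift.cpp change earlier in this patch. A sketch of that upload pattern, using only driver calls that appear in the patch (upload_lut itself is an illustrative name):

    #include <cstdint>
    #include <vector>
    #include "triton/driver/module.h"
    #include "triton/driver/stream.h"

    // Assumes the module's Triton-C source declared
    // "__constant__ int32* delta = alloc_const int32[N];".
    void upload_lut(triton::driver::stream* stream,
                    triton::driver::cu_module* module,
                    const std::vector<int32_t>& h_delta) {
      // resolve the module-level constant symbol to a device buffer
      triton::driver::buffer* delta = module->symbol("delta");
      // non-blocking write of the host table into constant memory
      stream->write(delta, false, 0, h_delta.size() * sizeof(int32_t), h_delta.data());
      // make sure the table is visible before the kernel launches
      stream->synchronize();
    }
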
diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp
index 1df832aeb..3f595b318 100755
--- a/lib/driver/module.cpp
+++ b/lib/driver/module.cpp
@@ -100,10 +100,10 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
 }
 
 void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
-                          const std::string &proc, std::string layout,
-                          llvm::SmallVectorImpl<char> &buffer,
-                          const std::string& features,
-                          file_type_t ft) {
+                                 const std::string &proc, std::string layout,
+                                 llvm::SmallVectorImpl<char> &buffer,
+                                 const std::string& features,
+                                 file_type_t ft) {
   init_llvm();
   // debug
 //  llvm::legacy::PassManager pm;
diff --git a/lib/jit.cpp b/lib/jit.cpp
deleted file mode 100644
index 059f96a00..000000000
--- a/lib/jit.cpp
+++ /dev/null
@@ -1,216 +0,0 @@
-#include "triton/jit.h"
-#include <string>
-#include "triton/ast/ast.h"
-#include "triton/codegen/target.h"
-#include "triton/ir/context.h"
-#include "triton/ir/context_impl.h"
-#include "triton/driver/device.h"
-#include "triton/driver/error.h"
-#include "llvm/IR/IRPrintingPasses.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/PassManager.h"
-#include "llvm/Support/raw_ostream.h"
-#include "llvm/Support/TargetRegistry.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Target/TargetMachine.h"
-#include "llvm/Target/TargetOptions.h"
-#include "llvm/CodeGen/TargetPassConfig.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/Transforms/Scalar/EarlyCSE.h"
-#include "llvm/Analysis/LoopPass.h"
-
-typedef struct yy_buffer_state * YY_BUFFER_STATE;
-extern int yyparse();
-extern YY_BUFFER_STATE yy_scan_string(const char * str);
-extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
-using triton::ast::translation_unit;
-extern translation_unit *ast_root;
-
-namespace triton {
-
-void loop_nest(std::vector<size_t> const & ranges,
-               std::function<void(std::vector<size_t> const &)> const & f){
-  size_t D = ranges.size();
-  std::vector<size_t> values(D, 0);
-  // Start with innermost loop
-  size_t i = D - 1;
-  while(true){
-    //Execute function
-    f(values);
-    //Increment counters
-    while(values[i]++ == ranges[i] - 1){
-      if(i == 0)
-        return;
-      values[i--] = 0;
-    }
-    i = D - 1;
-  }
-}
-
-template<class T>
-void loop_nest(std::vector<std::vector<T>> const & iterates,
-               std::function<void(std::vector<T>)> const & f){
-  //Ranges to iterate over
-  std::vector<size_t> ranges;
-  for(auto const & x: iterates)
-    ranges.push_back(x.size());
-  //Proxy function
-  auto proxy = [&](std::vector<size_t> const & idx){
-    std::vector<T> x(iterates.size());
-    for(size_t i = 0; i < x.size(); ++i)
-      x[i] = iterates[i][idx[i]];
-    f(x);
-  };
-  //Iterate
-  loop_nest(ranges, proxy);
-}
-
-
-
-
-std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
-  llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
-  passes.selection.run(module, *result);
-  // launch information
-  launch_information& info = launch_info_map_[result->getName()];
-  info.global_range_size.clear();
-  for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
-    info.global_range_size.push_back(passes.tune.get_global_range_size(i));
-  info.num_threads = passes.tune.get_num_threads();
-  return std::unique_ptr<llvm::Module>(result);
-}
-
-std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
-  // create AST from Triton-C source
-  YY_BUFFER_STATE buffer = yy_scan_string(src);
-  yyparse();
-  yy_delete_buffer(buffer);
-  translation_unit *program = ast_root;
-  // create Triton-IR from AST
-  ir::module* module = new ir::module(name, triton_context_);
-  program->codegen(module);
-  return std::unique_ptr<ir::module>(module);
-}
-
-
-jit::jit(driver::context *context): driver_context_(context),
-                                    target_(context->device()->make_target()) { }
-
-jit::~jit(){ }
-
-void jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
-  // find metaparameters
-  auto ptt_module = make_triton_module(name, src);
-  ir::module &tt_module = *ptt_module;
-  // set parameters
-  passes_wrapper passes(target_.get());
-  passes.target_independent(tt_module);
-  passes.tune.run(tt_module);
-  auto mps = passes.tune.get_params(tt_module);
-  // create parameter ranges
-  std::vector<std::vector<unsigned>> ranges;
-  for(ir::metaparameter *mp: mps)
-    ranges.push_back(mp->get_space());
-//  std::cout << ranges.size() << std::endl;
-  // iterate over parameters
-  unsigned i;
-  double best = 0;
-  loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
-    std::map<ir::value*, std::vector<std::string>> errors;
-    i = 0;
-    for(ir::metaparameter *mp: mps)
-      mp->set_value(params[i++]);
-    passes.target_independent(tt_module);
-    passes.tune.init(tt_module);
-    if(!passes.tune.check_constraints(errors))
-      return;
-    // Deep copy of the module and tuner
-    auto ptt_module = make_triton_module(name, src);
-    ir::module &tt_module = *ptt_module;
-    passes_wrapper passes(target_.get());
-    passes.target_independent(tt_module);
-    passes.tune.run(tt_module);
-    i = 0;
-    for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
-      mp->set_value(params[i++]);
-    }
-    passes.tune.init(tt_module);
-    passes.target_dependent(tt_module);
-    driver::device* device = driver_context_->device();
-    if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
-      return;
-    if(passes.tune.get_num_threads() > device->max_threads_per_block())
-      return;
-    // Compile
-    auto ll_module = make_llvm_module(tt_module, passes);
-    std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
-    std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
-    launch_information info = launch_info_map_.at(name);
-    for(unsigned p: params)
-      std::cout << p << " " << std::flush;
-    // add globals
-    for(auto x: tt_module.globals())
-      global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
-    modules_.push_back(module.get());
-    double perf;
-    perf = benchmark(kernel.get(), info);
-    best = std::max(perf, best);
-    std::cout << perf << " [ " << best << " ] " << std::endl;
-    modules_.pop_back();
-  });
-}
-
-void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
-  // set parameters
-  passes_wrapper passes(target_.get());
-  passes.target_independent(tt_module);
-  passes.tune.run(tt_module);
-  unsigned i = 0;
-  for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
-    mp->set_value(params[i++]);
-  passes.tune.init(tt_module);
-  passes.target_dependent(tt_module);
-  // check constraints
-  std::map<ir::value*, std::vector<std::string>> errors;
-  passes.tune.check_constraints(errors);
-  for(auto x: errors){
-    std::cout << x.first << std::endl;
-    for(auto str: x.second)
-      std::cout << str << std::endl;
-  }
-  if(errors.size())
-    throw std::runtime_error("invalid parameters");
-//  driver::device* device = driver_context_->device();
-//  if(passes.allocation.get_allocated_size() > device->max_shared_memory())
-//    throw std::runtime_error("invalid parameters");
-  // triton module -> llvm module
-  auto ll_module = make_llvm_module(tt_module, passes);
-  // llvm module -> machine code
-  modules_.push_back(driver::module::create(driver_context_, &*ll_module));
-  // add globals
-  for(auto x: tt_module.globals())
-    global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
-}
-
-void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) {
-  auto ptt_module = make_triton_module(name, src);
-  add_module(*ptt_module, params);
-}
-
-driver::kernel *jit::get_function(const char *name) {
-  return driver::kernel::create(modules_.front(), name);
-}
-
-jit::launch_information jit::get_launch_info(const char *name) {
-  return launch_info_map_.at(name);
-}
-
-unsigned jit::get_int(const char *name){
-  return global_ints_.at(name);
-}
-
-driver::buffer *jit::get_buffer(const char *name){
-  driver::cu_module *mod = (driver::cu_module*)modules_.front();
-  return mod->symbol(name);
-}
-
-}
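
Note: the deleted lib/jit.cpp carried a small generic helper, loop_nest, which enumerated the cartesian product of metaparameter spaces for autotune(). Since it is now gone from the tree, the self-contained sketch below reproduces its index-based form together with a tiny driver showing the enumeration order; the main() harness is added for illustration only.

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <vector>

    // Visits every point of the grid {0..ranges[0]-1} x ... x {0..ranges[D-1]-1},
    // incrementing the innermost counter first, exactly as the deleted helper did.
    void loop_nest(std::vector<size_t> const& ranges,
                   std::function<void(std::vector<size_t> const&)> const& f) {
      size_t D = ranges.size();
      std::vector<size_t> values(D, 0);
      size_t i = D - 1;
      while(true) {
        f(values);
        // ripple-carry increment over the counters
        while(values[i]++ == ranges[i] - 1) {
          if(i == 0)
            return;
          values[i--] = 0;
        }
        i = D - 1;
      }
    }

    int main() {
      // prints 0,0 0,1 0,2 1,0 1,1 1,2 -- one line per parameter combination
      loop_nest({2, 3}, [](std::vector<size_t> const& idx) {
        std::cout << idx[0] << "," << idx[1] << "\n";
      });
    }
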