From 5af7e5adac58d0234129d44c8a1088cd39fb8699 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 30 Jul 2019 20:02:16 -0700
Subject: [PATCH] Made sure it works for FP16

---
 examples/python/tensorflow/batchnorm.cpp   | 28 ++++++++---------
 examples/python/tensorflow/blocksparse.cpp | 27 ++++++++--------
 examples/python/tensorflow/conv.cpp        |  6 ++--
 examples/python/tensorflow/dot.cpp         |  6 ++--
 examples/python/tensorflow/shift.cpp       |  6 ++--
 include/triton/dnn/base.h                  |  4 +--
 include/triton/dnn/batchnorm.h             |  4 +--
 include/triton/dnn/blocksparse/dot.h       |  9 ++++--
 include/triton/dnn/conv.h                  |  2 +-
 include/triton/dnn/dot.h                   |  2 +-
 include/triton/dnn/shift.h                 |  2 +-
 include/triton/driver/buffer.h             | 10 +++---
 include/triton/tools/bench.hpp             |  4 +--
 lib/dnn/base.cpp                           |  9 ++++--
 lib/dnn/blocksparse/dot.cpp                | 36 ++++++++++++----------
 lib/dnn/conv.cpp                           | 18 ++++++-----
 lib/dnn/dot.cpp                            |  2 +-
 lib/dnn/shift.cpp                          |  2 +-
 lib/driver/buffer.cpp                      | 26 +++++++++-------
 lib/driver/module.cpp                      |  2 +-
 lib/runtime/jit.cpp                        | 14 ++++-----
 21 files changed, 118 insertions(+), 101 deletions(-)

diff --git a/examples/python/tensorflow/batchnorm.cpp b/examples/python/tensorflow/batchnorm.cpp
index 3e50aeb8b..137a84809 100644
--- a/examples/python/tensorflow/batchnorm.cpp
+++ b/examples/python/tensorflow/batchnorm.cpp
@@ -51,12 +51,12 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
     OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
     // triton handles
-    triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-    triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
-    triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
-    triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
-    triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
+    triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+    triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+    triton::driver::cu_buffer b(ctx, fw_b.tensor_data().size(), (CUdeviceptr)fw_b.tensor_data().data(), false);
+    triton::driver::cu_buffer y(ctx, fw_y->tensor_data().size(), (CUdeviceptr)fw_y->tensor_data().data(), false);
+    triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
+    triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
     // create config
     triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
     batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
@@ -117,14 +117,14 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
     OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
     // triton handles
-    triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
-    triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-    triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-    triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
-    triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
-    triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
-    triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
+    triton::driver::cu_buffer dy(ctx, fw_dy.tensor_data().size(), (CUdeviceptr)fw_dy.tensor_data().data(), false);
+    triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+    triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+    triton::driver::cu_buffer m(ctx, fw_m.tensor_data().size(), (CUdeviceptr)fw_m.tensor_data().data(), false);
+    triton::driver::cu_buffer v(ctx, fw_v.tensor_data().size(), (CUdeviceptr)fw_v.tensor_data().data(), false);
+    triton::driver::cu_buffer dx(ctx, fw_dx->tensor_data().size(), (CUdeviceptr)fw_dx->tensor_data().data(), false);
+    triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
     // create config
     triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
     batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp
index 5da231a12..0d37d382d 100644
--- a/examples/python/tensorflow/blocksparse.cpp
+++ b/examples/python/tensorflow/blocksparse.cpp
@@ -14,6 +14,7 @@
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/allocation_description.pb.h"
 
 using namespace tensorflow;
 using shape_inference::DimensionHandle;
@@ -21,6 +22,7 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 using GPUDevice = Eigen::GpuDevice;
 
+
 Status XpropShape(InferenceContext* ctx) {
   int K;
   TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
@@ -120,23 +122,20 @@ public:
     shape_c.AddDim(params_.K);
     Tensor* c = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
-    // allocate locks
-    int gridN = (N + 63)/64;
-    Tensor* locks;
-    TensorShape shape_l;
-    if (params_.locks > 0)
-      shape_l.AddDim(gridN * params_.locks * 2);
-    OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
     // wrap tensorflow handles
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
-    triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int32>().data(), false);
-    triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
+    triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
+    triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
     // create profile
-    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks, OP);
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
     // blocksparse matmul
-    dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
+    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
+    triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
+    Tensor *tmp = nullptr;
+    TensorShape tmp_shapes;
+    tmp_shapes.AddDim(locks_buffer->size() / 4);
+    OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
   }
 
 private:
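
Note: `enqueue()` now returns the `triton::dnn::base*` instance that actually ran (see the lib/dnn/base.cpp hunk below), which is what lets the kernel above recover the locks scratch buffer that is allocated lazily during autotuning. A minimal sketch of the retrieval pattern, reusing the names from the Compute() body above; the element count follows from the new byte-size accessor, since the buffer holds int32 words:

    // enqueue() hands back the op instance selected by the autotuner
    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut},
                                        triton::dnn::FULL_TUNING);
    // the locks buffer was sized inside init_impl once tuning fixed TM
    triton::driver::buffer* locks =
        ((triton::dnn::blocksparse::dot*)op)->get_locks();
    int64 n_lock_words = locks->size() / 4;  // int32 words, hence / 4
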
diff --git a/examples/python/tensorflow/conv.cpp b/examples/python/tensorflow/conv.cpp
index 4ed457021..f06bf679c 100644
--- a/examples/python/tensorflow/conv.cpp
+++ b/examples/python/tensorflow/conv.cpp
@@ -50,8 +50,8 @@ public:
     int32_t pad_d = 0, pad_h = 0, pad_w = 0;
     bool has_bias = false;
     // wrap buffers
-    triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<float>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<float>().data(), false);
+    triton::driver::cu_buffer a(ctx, tfa.tensor_data().size(), (CUdeviceptr)tfa.tensor_data().data(), false);
+    triton::driver::cu_buffer b(ctx, tfb.tensor_data().size(), (CUdeviceptr)tfb.tensor_data().data(), false);
     triton::driver::buffer* bias = nullptr;
     // template
     triton::dnn::conv conv(B, C,
@@ -68,7 +68,7 @@ public:
     Tensor* tfc = nullptr;
     TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
-    triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
+    triton::driver::cu_buffer c(ctx, tfc->tensor_data().size(), (CUdeviceptr)tfc->tensor_data().data(), false);
     // enqueue
     conv.enqueue(stream, {&a, &b, &c, bias});
   }
diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index 368ef8be3..7acedb7e9 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -45,9 +45,9 @@ class DotOp : public OpKernel {
     if (out_shape.num_elements() == 0)
       return;
     // matrix multiplication parameters
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
     triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index 2a3973fbb..28e10b679 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -119,9 +119,9 @@ public:
    if (out_shapes.num_elements() == 0)
      return;
    // matrix multiplication parameters
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
+    triton::driver::cu_buffer da(ctx, tf_a.tensor_data().size(), (CUdeviceptr)tf_a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, tf_b.tensor_data().size(), (CUdeviceptr)tf_b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, tf_c->tensor_data().size(), (CUdeviceptr)tf_c->tensor_data().data(), false);
    shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
  }
diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h
index 266f29803..b9e2b886b 100644
--- a/include/triton/dnn/base.h
+++ b/include/triton/dnn/base.h
@@ -61,7 +61,7 @@ protected:
 
 private:
   // initialize
-  virtual void init_impl(driver::stream *, driver::cu_module *) = 0;
+  virtual void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) = 0;
   // deinitialize
   virtual void deinit_impl() = 0;
   // enqueue
@@ -86,7 +86,7 @@ public:
   // clone
   virtual base* clone() const = 0;
   // enqueue
-  void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
+  base* enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
   // get profile
   launch_context_t get_launch_context(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
 
diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h
index 8f9053225..32c006b99 100644
--- a/include/triton/dnn/batchnorm.h
+++ b/include/triton/dnn/batchnorm.h
@@ -38,7 +38,7 @@ namespace dnn{
 class batchnorm_forward: public base {
 private:
   // init
-  void init_impl(driver::stream *, driver::cu_module *) { }
+  void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
   void deinit_impl() { }
 
   // enqueue
@@ -74,7 +74,7 @@ private:
 class batchnorm_backward: public base{
 private:
   // init
-  void init_impl(driver::stream *, driver::cu_module *) { }
+  void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
   void deinit_impl() { }
   // enqueue
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
diff --git a/include/triton/dnn/blocksparse/dot.h b/include/triton/dnn/blocksparse/dot.h
index 01c94a2fe..98a1ce6fa 100644
--- a/include/triton/dnn/blocksparse/dot.h
+++ b/include/triton/dnn/blocksparse/dot.h
@@ -26,14 +26,16 @@ private:
   std::vector<params_t> search_space() const;
   params_t heuristics() const;
   // init
-  void init_impl(driver::stream *stream, driver::cu_module *module);
+  void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
   // deinit
   void deinit_impl();
 
 public:
   // constructor
-  dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, op_t op = FPROP);
+  dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op = FPROP);
   // triton-c source
   void triton_c_src(std::ostream &os) const;
+  // locks
+  driver::buffer* get_locks() const;
   // clone
   base* clone() const;
 
@@ -46,7 +48,8 @@ private:
   int32_t K_;
   int32_t BS_;
   int32_t nlocks_;
-  driver::buffer *locks_;
+  int32_t nblocks_;
+  std::shared_ptr<driver::buffer> locks_;
   op_t op_;
 };
 
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index d81ff872d..2745d72bc 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -24,7 +24,7 @@ private:
   void build_b_deltas();
   void build_a_deltas();
   void build_masks();
-  void init_impl(driver::stream *, driver::cu_module *);
+  void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
   void deinit_impl() { }
 
   // enqueue
diff --git a/include/triton/dnn/dot.h b/include/triton/dnn/dot.h
index 6ba3f0b24..30836357f 100644
--- a/include/triton/dnn/dot.h
+++ b/include/triton/dnn/dot.h
@@ -9,7 +9,7 @@ namespace dnn{
 class dot: public base {
 private:
   // initialize
-  void init_impl(driver::stream *, driver::cu_module *);
+  void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information);
   void deinit_impl() { }
 
   // enqueue
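
Note: every `init_impl` override above gains a `triton::runtime::launch_information` parameter, which carries the tuned global parameters into initialization. A hedged sketch of what an override can do with it; `myop` and `N_` are illustrative placeholders rather than part of this patch, while `info.globals["TM"]` and the grid computation mirror what blocksparse::dot does further down:

    void myop::init_impl(triton::driver::stream *stream,
                         triton::driver::cu_module *module,
                         triton::runtime::launch_information info) {
      int32_t TM = info.globals["TM"];     // tile size chosen by the autotuner
      size_t grid_0 = (N_ + TM - 1) / TM;  // grid extent implied by that tile
      // allocate grid-dependent scratch (e.g. spin locks) here, once per module
    }
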
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index 25b9547f3..35ad312e0 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -49,7 +49,7 @@ enum layout_t {
 class shift: public base {
 private:
   // initialize and enqueue
-  void init_impl(driver::stream *stream, driver::cu_module *module);
+  void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
   void deinit_impl();
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
diff --git a/include/triton/driver/buffer.h b/include/triton/driver/buffer.h
index 0502f1ff4..a0502f789 100755
--- a/include/triton/driver/buffer.h
+++ b/include/triton/driver/buffer.h
@@ -36,14 +36,16 @@ class stream;
 // Base
 class buffer : public polymorphic_resource<CUdeviceptr, cl_mem, host_buffer_t> {
 public:
-  buffer(driver::context* ctx, CUdeviceptr cl, bool take_ownership);
-  buffer(driver::context* ctx, cl_mem cl, bool take_ownership);
-  buffer(driver::context* ctx, host_buffer_t hst, bool take_ownership);
+  buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
+  buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership);
+  buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
   static buffer* create(driver::context* ctx, size_t size);
   driver::context* context();
+  size_t size();
 
 protected:
   driver::context* context_;
+  size_t size_;
 };
 
 // CPU
@@ -65,7 +67,7 @@ class cu_buffer: public buffer
 {
 public:
   cu_buffer(driver::context* context, size_t size);
-  cu_buffer(driver::context* context, CUdeviceptr cu, bool take_ownership);
+  cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
   void set_zero(triton::driver::stream *queue, size_t size);
 };
 
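
Note: buffers now record their byte size at construction, so wrapping externally owned memory requires passing that size explicitly. A minimal sketch of the new size-aware constructor, assuming an existing `ctx`; the pointer and size are placeholder values, and `take_ownership=false` leaves deallocation to the caller, as in the TensorFlow examples above:

    CUdeviceptr ptr = 0;   // stands in for an externally owned device allocation
    size_t nbytes = 4096;  // its size in bytes (placeholder)
    triton::driver::cu_buffer buf(ctx, nbytes, ptr, /*take_ownership=*/false);
    size_t n = buf.size(); // new accessor: reports the registered byte size
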
diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp
index 6ac72fec7..6d71d27ae 100644
--- a/include/triton/tools/bench.hpp
+++ b/include/triton/tools/bench.hpp
@@ -38,8 +38,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
   while(total_time*1e-9 < 1e-3){
     float norm = 1;
     // normalize clock if possible to get roughly constant result
-    if(auto cu_device = dynamic_cast<driver::cu_device*>(device))
-      norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
+//    if(auto cu_device = dynamic_cast<driver::cu_device*>(device))
+//      norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
     tmr.start();
     op();
     stream->synchronize();
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index e5aa7ad45..a75334b90 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -44,11 +44,12 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
   auto benchmark = [&](triton::driver::kernel* kernel,
                        rt::launch_information info) {
     // launch info
-    clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
+    clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
     clone->enqueue_impl(stream, kernel, args, info);
     stream->synchronize();
     double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
     clone->deinit_impl();
+//    std::cout << ts * 1e-6 << std::endl;
     return num_flops() / ts * 1e-3;
   };
   // auto-tune and save result
@@ -65,7 +66,8 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
       jit->add_module(name_.c_str(), src.c_str(), params);
     }
     triton::driver::kernel* kernel = jit->get_function(name_.c_str());
-    clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
+    rt::launch_information info = jit->get_launch_info(name_.c_str());
+    clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
   }
   /* retrieved compiled template */
   else {
@@ -75,9 +77,10 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
   return {it->first, jit};
 }
 
-void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
+base* base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
   launch_context_t info = get_launch_context(stream, args, autotune);
   info.op->enqueue_impl(stream, info.kernel, args, info.info);
+  return info.op;
 }
 
 launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer*> args, autotuning_t autotune) {
diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp
index 2f69677a8..ff021cca8 100644
--- a/lib/dnn/blocksparse/dot.cpp
+++ b/lib/dnn/blocksparse/dot.cpp
@@ -6,7 +6,7 @@ namespace blocksparse{
 
 size_t dot::num_flops() const {
-
+  return 2.*nblocks_*BS_*BS_*N_;
 }
 
 bool dot::operator <(const base& other) const {
@@ -30,25 +30,23 @@ base * dot::clone() const {
 }
 
 dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
-         const std::string& ty, int32_t BS, int32_t nlocks, op_t op):
+         const std::string& ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op):
   base("bsdot"),
   N_(N), K_(K), S_(S), C_(C),
   ab_ty_(ty), c_ty_(ty),
-  BS_(BS), nlocks_(nlocks), op_(op){
+  BS_(BS), nlocks_(nlocks), nblocks_(nblocks), op_(op){
 }
 
-void dot::init_impl(driver::stream *stream, driver::cu_module *module) {
-//  int32_t TM = info.globals["TM"];
-//  size_t grid_0 = (N_ + TM - 1) / TM;
-//  if(nlocks_){
-//    locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
-//    ((driver::cu_buffer*)locks_)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
-//  }
+void dot::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
+  int32_t TM = info.globals["TM"];
+  size_t grid_0 = (N_ + TM - 1) / TM;
+  if(nlocks_ && !locks_){
+    locks_.reset(triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4));
+    ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
+  }
 }
 
 void dot::deinit_impl() {
-//  if(locks_)
-//    delete locks_;
 }
 
 void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
@@ -57,7 +55,6 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   driver::buffer *b = args[1];
   driver::buffer *c = args[2];
   driver::buffer *lut = args[3];
-  driver::buffer *locks = args[4];
   int32_t lda = N_;
   int32_t ldc = N_;
   kernel->setArg(0, a);
@@ -67,16 +64,20 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(4, ldc);
   kernel->setArg(5, N_);
   kernel->setArg(6, lut);
-  kernel->setArg(7, locks);
+  kernel->setArg(7, locks_.get());
   kernel->setArg(8, nlocks_);
   int32_t TM = info.globals["TM"];
   size_t grid_0 = (N_ + TM - 1) / TM;
   size_t grid_1 = S_;
   if(nlocks_)
-    ((driver::cu_buffer*)locks)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
+    ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
   stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
 }
 
+driver::buffer* dot::get_locks() const {
+  return locks_.get();
+}
+
 void dot::triton_c_src(std::ostream &os) const {
   std::string usea = (op_ == WGRAD) ? "trans(a)" : "a";
   std::string useb = (op_ == FPROP) ? "trans(b)" : "b";
@@ -90,7 +91,7 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
 
   std::string result = R"(
-  const tunable int32 TM = {64};
+  const tunable int32 TM = {32, 64, 128};
   const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
   const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
@@ -106,6 +107,7 @@ void dot::triton_c_src(std::ostream &os) const {
   int32 ryb[TN] = 0 ... TN;
   int32 rka[TK] = 0 ... TK;
   int32 rkb[TK] = 0 ... TK;
+  int1 checka[TM, TK] = (rxa < N)[:, newaxis];
   int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
   int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
   int32 *header = lut + ridy * 4;
@@ -119,7 +121,7 @@ void dot::triton_c_src(std::ostream &os) const {
     int32 bk = *(plut + 1);
     )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
     )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
-    )" + ab_ty_ + " a[" + sizea + R"(] = *pa;
+    )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
     )" + ab_ty_ + " b[" + sizeb + R"(] = *pb;
     acc = dot()" + usea + ", " + useb + R"(, acc);
     plut = plut + 2;
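
Note: the FLOP count filled in above follows directly from the sparsity pattern: each of the `nblocks_` stored `BS_ x BS_` blocks is multiplied against an `N_`-wide slab of the dense operand, at two operations (one multiply, one add) per element, giving flops = 2 * nblocks_ * BS_ * BS_ * N_.
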
"TK" : "1" ; std::string result = R"( - const tunable int32 TM = {64}; + const tunable int32 TM = {32, 64, 128}; const tunable int32 TN = {)" + std::to_string(BS_) + R"(}; const tunable int32 TK = {)" + std::to_string(BS_) + R"(}; @@ -106,6 +107,7 @@ void dot::triton_c_src(std::ostream &os) const { int32 ryb[TN] = 0 ... TN; int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; + int1 checka[TM, TK] = (rxa < N)[:, newaxis]; int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda; int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(; int32 *header = lut + ridy * 4; @@ -119,7 +121,7 @@ void dot::triton_c_src(std::ostream &os) const { int32 bk = *(plut + 1); )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda; )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN; - )" + ab_ty_ + " a[" + sizea + R"(] = *pa; + )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0; )" + ab_ty_ + " b[" + sizeb + R"(] = *pb; acc = dot()" + usea + ", " + useb + R"(, acc); plut = plut + 2; diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp index c20701a4b..f54c63560 100644 --- a/lib/dnn/conv.cpp +++ b/lib/dnn/conv.cpp @@ -278,7 +278,7 @@ size_t conv::num_flops() const{ return 2.*M_*N_*K_; } -void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module) { +void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module, triton::runtime::launch_information info) { auto init_lut = [&](bool is_cst, const char *name, std::vector host) -> triton::driver::buffer*{ if(host.empty()) return nullptr; @@ -293,12 +293,16 @@ void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module) stream->write(buffer, false, 0, nbytes, host.data()); return buffer; }; - - d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_); - d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_); - d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_); - d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2); - ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2); + if(d_a_deltas_ == nullptr) + d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_); + if(d_b_deltas_ == nullptr) + d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_); + if(d_masks_ == nullptr) + d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_); + if(d_locks_ == nullptr){ + d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2); + ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2); + } } void conv::set_arg(driver::kernel *kernel, diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp index 114ec7450..1b5e061d3 100644 --- a/lib/dnn/dot.cpp +++ b/lib/dnn/dot.cpp @@ -39,7 +39,7 @@ base* dot::clone() const { return new dot(*this); } -void dot::init_impl(driver::stream* stream, driver::cu_module *) { +void dot::init_impl(driver::stream* stream, driver::cu_module *, runtime::launch_information) { std::vector hlocks(2048, 0); if(locks_ == nullptr) locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4); diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 58c62dd46..3bf5e1035 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -226,7 +226,7 @@ bool shift::operator <(const base& other) const{ y->bias_); } -void shift::init_impl(driver::stream *stream, driver::cu_module *module) { +void shift::init_impl(driver::stream *stream, driver::cu_module *module, 
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index 58c62dd46..3bf5e1035 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -226,7 +226,7 @@ bool shift::operator <(const base& other) const{
          y->bias_);
 }
 
-void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
+void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
   build_delta_a();
   triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a");
   stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data());
diff --git a/lib/driver/buffer.cpp b/lib/driver/buffer.cpp
index cf96aa115..111091fdf 100755
--- a/lib/driver/buffer.cpp
+++ b/lib/driver/buffer.cpp
@@ -36,20 +36,24 @@ namespace driver
 
 //
 
-buffer::buffer(driver::context* ctx, CUdeviceptr cu, bool take_ownership)
-  : polymorphic_resource(cu, take_ownership), context_(ctx) { }
+buffer::buffer(driver::context* ctx, size_t size, CUdeviceptr cu, bool take_ownership)
+  : polymorphic_resource(cu, take_ownership), context_(ctx), size_(size) { }
 
-buffer::buffer(driver::context* ctx, cl_mem cl, bool take_ownership)
-  : polymorphic_resource(cl, take_ownership), context_(ctx) { }
+buffer::buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership)
+  : polymorphic_resource(cl, take_ownership), context_(ctx), size_(size) { }
 
-buffer::buffer(driver::context* ctx, host_buffer_t hst, bool take_ownership)
-  : polymorphic_resource(hst, take_ownership), context_(ctx) { }
+buffer::buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership)
+  : polymorphic_resource(hst, take_ownership), context_(ctx), size_(size) { }
 
 driver::context* buffer::context() {
   return context_;
 }
 
+size_t buffer::size() {
+  return size_;
+}
+
 buffer* buffer::create(driver::context* ctx, size_t size) {
   switch(ctx->backend()){
     case CUDA: return new cu_buffer(ctx, size);
@@ -62,14 +66,14 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
 
 //
 
 host_buffer::host_buffer(driver::context *context, size_t size)
-  : buffer(context, host_buffer_t(), true){
+  : buffer(context, size, host_buffer_t(), true){
   hst_->data = new char[size];
 }
 
 //
 
 ocl_buffer::ocl_buffer(driver::context* context, size_t size)
-  : buffer(context, cl_mem(), true){
+  : buffer(context, size, cl_mem(), true){
   cl_int err;
   *cl_ = dispatch::clCreateBuffer(*context->cl(), CL_MEM_READ_WRITE, size, NULL, &err);
   check(err);
@@ -79,13 +83,13 @@ ocl_buffer::ocl_buffer(driver::context* context, size_t size)
 
 //
 
 cu_buffer::cu_buffer(driver::context* context, size_t size)
-  : buffer(context, CUdeviceptr(), true) {
+  : buffer(context, size, CUdeviceptr(), true) {
   cu_context::context_switcher ctx_switch(*context_);
   dispatch::cuMemAlloc(&*cu_, size);
 }
 
-cu_buffer::cu_buffer(driver::context* context, CUdeviceptr cu, bool take_ownership)
-  : buffer(context, cu, take_ownership){
+cu_buffer::cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership)
+  : buffer(context, size, cu, take_ownership){
 }
 
 void cu_buffer::set_zero(driver::stream* queue, size_t size)
diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp
index 4ff863666..d2c31fadd 100755
--- a/lib/driver/module.cpp
+++ b/lib/driver/module.cpp
@@ -275,7 +275,7 @@ cu_buffer* cu_module::symbol(const char *name) const{
   CUdeviceptr handle;
   size_t size;
   dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
-  return new cu_buffer(ctx_, handle, false);
+  return new cu_buffer(ctx_, size, handle, false);
 }
 
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 6fa727406..12eebdd3c 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-//  ThreadPool pool(nthreads);
+  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-//    pool.enqueue(f,values);
-    f(values);
+    pool.enqueue(f,values);
+//    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
@@ -51,7 +51,7 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
     }
     i = D - 1;
     // Short sleep so that the thread pool doesn't grow too big
-//    std::this_thread::sleep_for(std::chrono::microseconds(1));
+    std::this_thread::sleep_for(std::chrono::microseconds(1));
   }
 }
 
@@ -212,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
       best.perf = perf;
       best.params = params;
     }
-//    for(size_t i = 0; i < params.size(); i++)
-//      std::cout << ((i==0)?"":", ") << params[i] << std::flush;
-//    std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
+    for(size_t i = 0; i < params.size(); i++)
+      std::cout << ((i==0)?"":", ") << params[i] << std::flush;
+    std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
   }
 };