Made sure it works for FP16

This commit is contained in:
Philippe Tillet
2019-07-30 20:02:16 -07:00
parent 080bf1af88
commit 5af7e5adac
21 changed files with 118 additions and 101 deletions

View File

@@ -51,12 +51,12 @@ public:
OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
// triton handles
triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
triton::driver::cu_buffer b(ctx, fw_b.tensor_data().size(), (CUdeviceptr)fw_b.tensor_data().data(), false);
triton::driver::cu_buffer y(ctx, fw_y->tensor_data().size(), (CUdeviceptr)fw_y->tensor_data().data(), false);
triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
@@ -117,14 +117,14 @@ public:
OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
// triton handles
triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
triton::driver::cu_buffer dy(ctx, fw_dy.tensor_data().size(), (CUdeviceptr)fw_dy.tensor_data().data(), false);
triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
triton::driver::cu_buffer m(ctx, fw_m.tensor_data().size(), (CUdeviceptr)fw_m.tensor_data().data(), false);
triton::driver::cu_buffer v(ctx, fw_v.tensor_data().size(), (CUdeviceptr)fw_v.tensor_data().data(), false);
triton::driver::cu_buffer dx(ctx, fw_dx->tensor_data().size(), (CUdeviceptr)fw_dx->tensor_data().data(), false);
triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});

View File

@@ -14,6 +14,7 @@
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
using namespace tensorflow;
using shape_inference::DimensionHandle;
@@ -21,6 +22,7 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;
Status XpropShape(InferenceContext* ctx)
{
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
@@ -120,23 +122,20 @@ public:
shape_c.AddDim(params_.K);
Tensor* c = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
// allocate locks
int gridN = (N + 63)/64;
Tensor* locks;
TensorShape shape_l;
if (params_.locks > 0)
shape_l.AddDim(gridN * params_.locks * 2);
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
// wrap tensorflow handles
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<T>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<T>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<T>().data(), false);
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
// create profile
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks, OP);
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
// blocksparse matmul
dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
Tensor *tmp = nullptr;
TensorShape tmp_shapes;
tmp_shapes.AddDim(locks_buffer->size() / 4);
OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
}
private:

View File

@@ -50,8 +50,8 @@ public:
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false;
// wrap buffers
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer a(ctx, tfa.tensor_data().size(), (CUdeviceptr)tfa.tensor_data().data(), false);
triton::driver::cu_buffer b(ctx, tfb.tensor_data().size(), (CUdeviceptr)tfb.tensor_data().data(), false);
triton::driver::buffer* bias = nullptr;
// template
triton::dnn::conv conv(B, C,
@@ -68,7 +68,7 @@ public:
Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
triton::driver::cu_buffer c(ctx, tfc->tensor_data().size(), (CUdeviceptr)tfc->tensor_data().data(), false);
// enqueue
conv.enqueue(stream, {&a, &b, &c, bias});
}

View File

@@ -45,9 +45,9 @@ class DotOp : public OpKernel {
if (out_shape.num_elements() == 0)
return;
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
// template
triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
dot.enqueue(stream, {&da, &db, &dc});

View File

@@ -119,9 +119,9 @@ public:
if (out_shapes.num_elements() == 0)
return;
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
triton::driver::cu_buffer da(ctx, tf_a.tensor_data().size(), (CUdeviceptr)tf_a.tensor_data().data(), false);
triton::driver::cu_buffer db(ctx, tf_b.tensor_data().size(), (CUdeviceptr)tf_b.tensor_data().data(), false);
triton::driver::cu_buffer dc(ctx, tf_c->tensor_data().size(), (CUdeviceptr)tf_c->tensor_data().data(), false);
shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
}

View File

@@ -61,7 +61,7 @@ protected:
private:
// initialize
virtual void init_impl(driver::stream *, driver::cu_module *) = 0;
virtual void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) = 0;
// deinitialize
virtual void deinit_impl() = 0;
// enqueue
@@ -86,7 +86,7 @@ public:
// clone
virtual base* clone() const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
base* enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
// get profile
launch_context_t get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune = PARTIAL_TUNING);

View File

@@ -38,7 +38,7 @@ namespace dnn{
class batchnorm_forward: public base {
private:
// init
void init_impl(driver::stream *, driver::cu_module *) { }
void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
void deinit_impl() { }
// enqueue
@@ -74,7 +74,7 @@ private:
class batchnorm_backward: public base{
private:
// init
void init_impl(driver::stream *, driver::cu_module *) { }
void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information) { }
void deinit_impl() { }
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,

View File

@@ -26,14 +26,16 @@ private:
std::vector<params_t> search_space() const;
params_t heuristics() const;
// init
void init_impl(driver::stream *stream, driver::cu_module *module);
void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
// deinit
void deinit_impl();
public:
// constructor
dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, op_t op = FPROP);
dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op = FPROP);
// triton-c source
void triton_c_src(std::ostream &os) const;
// locks
driver::buffer* get_locks() const;
// clone
base* clone() const;
@@ -46,7 +48,8 @@ private:
int32_t K_;
int32_t BS_;
int32_t nlocks_;
driver::buffer *locks_;
int32_t nblocks_;
std::shared_ptr<driver::buffer> locks_;
op_t op_;
};

View File

@@ -24,7 +24,7 @@ private:
void build_b_deltas();
void build_a_deltas();
void build_masks();
void init_impl(driver::stream *, driver::cu_module *);
void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
void deinit_impl() { }
// enqueue

View File

@@ -9,7 +9,7 @@ namespace dnn{
class dot: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
void init_impl(driver::stream *, driver::cu_module *, triton::runtime::launch_information);
void deinit_impl() { }
// enqueue

View File

@@ -49,7 +49,7 @@ enum layout_t {
class shift: public base {
private:
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
void deinit_impl();
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,

View File

@@ -36,14 +36,16 @@ class stream;
// Base
class buffer : public polymorphic_resource<CUdeviceptr, cl_mem, host_buffer_t> {
public:
buffer(driver::context* ctx, CUdeviceptr cl, bool take_ownership);
buffer(driver::context* ctx, cl_mem cl, bool take_ownership);
buffer(driver::context* ctx, host_buffer_t hst, bool take_ownership);
buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership);
buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
static buffer* create(driver::context* ctx, size_t size);
driver::context* context();
size_t size();
protected:
driver::context* context_;
size_t size_;
};
// CPU
@@ -65,7 +67,7 @@ class cu_buffer: public buffer
{
public:
cu_buffer(driver::context* context, size_t size);
cu_buffer(driver::context* context, CUdeviceptr cu, bool take_ownership);
cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
void set_zero(triton::driver::stream *queue, size_t size);
};

View File

@@ -38,8 +38,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();

View File

@@ -44,11 +44,12 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
auto benchmark = [&](triton::driver::kernel* kernel,
rt::launch_information info) {
// launch info
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
clone->enqueue_impl(stream, kernel, args, info);
stream->synchronize();
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
clone->deinit_impl();
// std::cout << ts * 1e-6 << std::endl;
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
@@ -65,7 +66,8 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
rt::launch_information info = jit->get_launch_info(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
}
/* retrieved compiled template */
else {
@@ -75,9 +77,10 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
return {it->first, jit};
}
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
// Resolves the launch context for `args` (tuning according to `autotune`),
// enqueues the selected operation on `stream`, and returns that operation so
// callers can query state it owns after the launch.
base* base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
  launch_context_t launch = get_launch_context(stream, args, autotune);
  base* selected = launch.op;
  selected->enqueue_impl(stream, launch.kernel, args, launch.info);
  return selected;
}
launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {

View File

@@ -6,7 +6,7 @@ namespace blocksparse{
// Operation count reported for this block-sparse product:
// 2 * nblocks_ * BS_ * BS_ * N_.
size_t dot::num_flops() const {
  // The leading 2. literal makes the whole product evaluate in double
  // precision; the result is converted to size_t on return.
  const double flop_count = 2. * nblocks_ * BS_ * BS_ * N_;
  return flop_count;
}
bool dot::operator <(const base& other) const {
@@ -30,25 +30,23 @@ base * dot::clone() const {
}
dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
const std::string& ty, int32_t BS, int32_t nlocks, op_t op):
const std::string& ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op):
base("bsdot"),
N_(N), K_(K), S_(S), C_(C),
ab_ty_(ty), c_ty_(ty),
BS_(BS), nlocks_(nlocks), op_(op){
BS_(BS), nlocks_(nlocks), nblocks_(nblocks), op_(op){
}
void dot::init_impl(driver::stream *stream, driver::cu_module *module) {
// int32_t TM = info.globals["TM"];
// size_t grid_0 = (N_ + TM - 1) / TM;
// if(nlocks_){
// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
// ((driver::cu_buffer*)locks_)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
// }
// Prepares the lock buffer for the launch configuration described by `info`.
//
// The buffer must hold grid_0 * nlocks_ * 2 32-bit words, where grid_0 depends
// on the tunable tile size TM. Because TM differs between tuning candidates
// and deinit_impl() does not release the buffer, a buffer allocated for a
// large TM would be too small for a later candidate with a smaller TM (larger
// grid_0). The buffer is therefore (re)allocated whenever the required size
// exceeds the current one, and reused otherwise.
//
// stream : stream whose context owns the allocation and on which the zeroing
//          is enqueued.
// module : unused here.
// info   : launch information; only the "TM" global is read.
void dot::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
  int32_t TM = info.globals["TM"];
  size_t grid_0 = (N_ + TM - 1) / TM;
  // 2 lock words per (grid row, lock), 4 bytes each
  size_t nbytes = grid_0 * nlocks_ * 2 * 4;
  if(nlocks_ && (!locks_ || locks_->size() < nbytes)){
    locks_.reset(triton::driver::buffer::create(stream->context(), nbytes));
    ((driver::cu_buffer*)locks_.get())->set_zero(stream, nbytes);
  }
}
// Intentionally a no-op: nothing is released here. The manual deletion of
// locks_ below is commented out, so the lock buffer outlives this call.
void dot::deinit_impl() {
// if(locks_)
// delete locks_;
}
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
@@ -57,7 +55,6 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
driver::buffer *b = args[1];
driver::buffer *c = args[2];
driver::buffer *lut = args[3];
driver::buffer *locks = args[4];
int32_t lda = N_;
int32_t ldc = N_;
kernel->setArg(0, a);
@@ -67,16 +64,20 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(4, ldc);
kernel->setArg(5, N_);
kernel->setArg(6, lut);
kernel->setArg(7, locks);
kernel->setArg(7, locks_.get());
kernel->setArg(8, nlocks_);
int32_t TM = info.globals["TM"];
size_t grid_0 = (N_ + TM - 1) / TM;
size_t grid_1 = S_;
if(nlocks_)
((driver::cu_buffer*)locks)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
}
// Returns a non-owning pointer to the lock buffer held in locks_
// (null when no buffer has been created yet).
driver::buffer* dot::get_locks() const {
  driver::buffer* raw_locks = locks_.get();
  return raw_locks;
}
void dot::triton_c_src(std::ostream &os) const {
std::string usea = (op_ == WGRAD) ? "trans(a)" : "a";
std::string useb = (op_ == FPROP) ? "trans(b)" : "b";
@@ -90,7 +91,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
std::string result =
R"(
const tunable int32 TM = {64};
const tunable int32 TM = {32, 64, 128};
const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
@@ -106,6 +107,7 @@ void dot::triton_c_src(std::ostream &os) const {
int32 ryb[TN] = 0 ... TN;
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
int1 checka[TM, TK] = (rxa < N)[:, newaxis];
int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
int32 *header = lut + ridy * 4;
@@ -119,7 +121,7 @@ void dot::triton_c_src(std::ostream &os) const {
int32 bk = *(plut + 1);
)" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
)" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
)" + ab_ty_ + " a[" + sizea + R"(] = *pa;
)" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
)" + ab_ty_ + " b[" + sizeb + R"(] = *pb;
acc = dot()" + usea + ", " + useb + R"(, acc);
plut = plut + 2;

View File

@@ -278,7 +278,7 @@ size_t conv::num_flops() const{
return 2.*M_*N_*K_;
}
void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module) {
void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module, triton::runtime::launch_information info) {
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
if(host.empty())
return nullptr;
@@ -293,12 +293,16 @@ void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module)
stream->write(buffer, false, 0, nbytes, host.data());
return buffer;
};
d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
if(d_a_deltas_ == nullptr)
d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
if(d_b_deltas_ == nullptr)
d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
if(d_masks_ == nullptr)
d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
if(d_locks_ == nullptr){
d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
}
}
void conv::set_arg(driver::kernel *kernel,

View File

@@ -39,7 +39,7 @@ base* dot::clone() const {
return new dot(*this);
}
void dot::init_impl(driver::stream* stream, driver::cu_module *) {
void dot::init_impl(driver::stream* stream, driver::cu_module *, runtime::launch_information) {
std::vector<int32_t> hlocks(2048, 0);
if(locks_ == nullptr)
locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);

View File

@@ -226,7 +226,7 @@ bool shift::operator <(const base& other) const{
y->bias_);
}
void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
build_delta_a();
triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a");
stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data());

View File

@@ -36,20 +36,24 @@ namespace driver
//
buffer::buffer(driver::context* ctx, CUdeviceptr cu, bool take_ownership)
: polymorphic_resource(cu, take_ownership), context_(ctx) { }
buffer::buffer(driver::context* ctx, size_t size, CUdeviceptr cu, bool take_ownership)
: polymorphic_resource(cu, take_ownership), context_(ctx), size_(size) { }
buffer::buffer(driver::context* ctx, cl_mem cl, bool take_ownership)
: polymorphic_resource(cl, take_ownership), context_(ctx) { }
buffer::buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership)
: polymorphic_resource(cl, take_ownership), context_(ctx), size_(size) { }
buffer::buffer(driver::context* ctx, host_buffer_t hst, bool take_ownership)
: polymorphic_resource(hst, take_ownership), context_(ctx) { }
buffer::buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership)
: polymorphic_resource(hst, take_ownership), context_(ctx), size_(size) { }
// Returns the driver context stored in this buffer at construction.
driver::context* buffer::context() {
  driver::context* owner = context_;
  return owner;
}
// Returns the size value recorded in size_ when the buffer was constructed.
size_t buffer::size() {
  const size_t recorded = size_;
  return recorded;
}
buffer* buffer::create(driver::context* ctx, size_t size) {
switch(ctx->backend()){
case CUDA: return new cu_buffer(ctx, size);
@@ -62,14 +66,14 @@ buffer* buffer::create(driver::context* ctx, size_t size) {
//
host_buffer::host_buffer(driver::context *context, size_t size)
: buffer(context, host_buffer_t(), true){
: buffer(context, size, host_buffer_t(), true){
hst_->data = new char[size];
}
//
ocl_buffer::ocl_buffer(driver::context* context, size_t size)
: buffer(context, cl_mem(), true){
: buffer(context, size, cl_mem(), true){
cl_int err;
*cl_ = dispatch::clCreateBuffer(*context->cl(), CL_MEM_READ_WRITE, size, NULL, &err);
check(err);
@@ -79,13 +83,13 @@ ocl_buffer::ocl_buffer(driver::context* context, size_t size)
//
cu_buffer::cu_buffer(driver::context* context, size_t size)
: buffer(context, CUdeviceptr(), true) {
: buffer(context, size, CUdeviceptr(), true) {
cu_context::context_switcher ctx_switch(*context_);
dispatch::cuMemAlloc(&*cu_, size);
}
cu_buffer::cu_buffer(driver::context* context, CUdeviceptr cu, bool take_ownership)
: buffer(context, cu, take_ownership){
cu_buffer::cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership)
: buffer(context, size, cu, take_ownership){
}
void cu_buffer::set_zero(driver::stream* queue, size_t size)

View File

@@ -275,7 +275,7 @@ cu_buffer* cu_module::symbol(const char *name) const{
CUdeviceptr handle;
size_t size;
dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
return new cu_buffer(ctx_, handle, false);
return new cu_buffer(ctx_, size, handle, false);
}

View File

@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// thread pools
// ThreadPool pool(nthreads);
ThreadPool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
while(true){
// Execute function
// pool.enqueue(f,values);
f(values);
pool.enqueue(f,values);
// f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
@@ -51,7 +51,7 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
}
i = D - 1;
// Short sleep so that the thread pool doesn't grow too big
// std::this_thread::sleep_for(std::chrono::microseconds(1));
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
}
@@ -212,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf;
best.params = params;
}
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
};