Made sure it works for FP16
@@ -51,12 +51,12 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
     OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
     // triton handles
-    triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-    triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
-    triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
-    triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
-    triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
+    triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+    triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+    triton::driver::cu_buffer b(ctx, fw_b.tensor_data().size(), (CUdeviceptr)fw_b.tensor_data().data(), false);
+    triton::driver::cu_buffer y(ctx, fw_y->tensor_data().size(), (CUdeviceptr)fw_y->tensor_data().data(), false);
+    triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
+    triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
     // create config
     triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
     batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
@@ -117,14 +117,14 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
     OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
     // triton handles
-    triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
-    triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
-    triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
-    triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
-    triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
-    triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
-    triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
+    triton::driver::cu_buffer dy(ctx, fw_dy.tensor_data().size(), (CUdeviceptr)fw_dy.tensor_data().data(), false);
+    triton::driver::cu_buffer x(ctx, fw_x.tensor_data().size(), (CUdeviceptr)fw_x.tensor_data().data(), false);
+    triton::driver::cu_buffer g(ctx, fw_g.tensor_data().size(), (CUdeviceptr)fw_g.tensor_data().data(), false);
+    triton::driver::cu_buffer m(ctx, fw_m.tensor_data().size(), (CUdeviceptr)fw_m.tensor_data().data(), false);
+    triton::driver::cu_buffer v(ctx, fw_v.tensor_data().size(), (CUdeviceptr)fw_v.tensor_data().data(), false);
+    triton::driver::cu_buffer dx(ctx, fw_dx->tensor_data().size(), (CUdeviceptr)fw_dx->tensor_data().data(), false);
+    triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
     // create config
     triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
     batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
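For reference, the recurring edit in the two hunks above swaps one cu_buffer overload for another. Below is a minimal sketch of the pattern, assuming only the constructor signatures visible in this diff; the helper name wrap_example and the ctx parameter type are hypothetical:

// Sketch only: the old overload forced each call site to name the dtype
// through flat<T>(), which breaks once the same kernel must serve FP16.
// The new overload takes an explicit byte size plus an untyped pointer
// from Tensor::tensor_data(), so the wrapping code is dtype-agnostic.
#include "tensorflow/core/framework/tensor.h"
#include "triton/driver/buffer.h"  // assumed header for cu_buffer

void wrap_example(triton::driver::context* ctx, const tensorflow::Tensor& t) {
  // before: triton::driver::cu_buffer buf(ctx, (CUdeviceptr)t.flat<float>().data(), false);
  // after: raw bytes plus explicit size, valid for float and Eigen::half alike
  triton::driver::cu_buffer buf(ctx, t.tensor_data().size(),
                                (CUdeviceptr)t.tensor_data().data(), false);
}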
@@ -14,6 +14,7 @@
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/allocation_description.pb.h"
 
 using namespace tensorflow;
 using shape_inference::DimensionHandle;
@@ -21,6 +22,7 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
+using GPUDevice = Eigen::GpuDevice;
 
 
 Status XpropShape(InferenceContext* ctx)
 {
   int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
@@ -120,23 +122,20 @@ public:
     shape_c.AddDim(params_.K);
     Tensor* c = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
-    // allocate locks
-    int gridN = (N + 63)/64;
-    Tensor* locks;
-    TensorShape shape_l;
-    if (params_.locks > 0)
-      shape_l.AddDim(gridN * params_.locks * 2);
-    OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
     // wrap tensorflow handles
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<T>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<T>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<T>().data(), false);
-    triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
-    triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
+    triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
+    triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
     // create profile
-    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks, OP);
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
     // blocksparse matmul
-    dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
+    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
+    triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
+    Tensor *tmp = nullptr;
+    TensorShape tmp_shapes;
+    tmp_shapes.AddDim(locks_buffer->size() / 4);
+    OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
   }
 
 private:
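One behavioral change in the hunk above is worth noting: the locks output is no longer pre-sized from gridN before launch; enqueue now returns the tuned op, and the output is sized from the lock buffer that op reports via get_locks(). A small sketch of the sizing arithmetic, assuming buffer::size() reports bytes (the helper name is hypothetical):

#include <cstdint>

// Sketch: the lock buffer holds int32 words, so dividing its byte size
// by 4 (as the new code does with locks_buffer->size() / 4) yields the
// element count for the TensorFlow output tensor that mirrors it.
int64_t lock_words(size_t lock_bytes) {
  return static_cast<int64_t>(lock_bytes / sizeof(int32_t));  // == bytes / 4
}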
@@ -50,8 +50,8 @@
     int32_t pad_d = 0, pad_h = 0, pad_w = 0;
     bool has_bias = false;
     // wrap buffers
-    triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer a(ctx, tfa.tensor_data().size(), (CUdeviceptr)tfa.tensor_data().data(), false);
+    triton::driver::cu_buffer b(ctx, tfb.tensor_data().size(), (CUdeviceptr)tfb.tensor_data().data(), false);
     triton::driver::buffer* bias = nullptr;
     // template
     triton::dnn::conv conv(B, C,
@@ -68,7 +68,7 @@ public:
     Tensor* tfc = nullptr;
     TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
-    triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
+    triton::driver::cu_buffer c(ctx, tfc->tensor_data().size(), (CUdeviceptr)tfc->tensor_data().data(), false);
     // enqueue
     conv.enqueue(stream, {&a, &b, &c, bias});
   }
@@ -45,9 +45,9 @@ class DotOp : public OpKernel {
     if (out_shape.num_elements() == 0)
       return;
     // matrix multiplication parameters
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
+    triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
     triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
@@ -119,9 +119,9 @@ public:
     if (out_shapes.num_elements() == 0)
       return;
     // matrix multiplication parameters
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer da(ctx, tf_a.tensor_data().size(), (CUdeviceptr)tf_a.tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, tf_b.tensor_data().size(), (CUdeviceptr)tf_b.tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, tf_c->tensor_data().size(), (CUdeviceptr)tf_c->tensor_data().data(), false);
     shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
   }
 
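Taken together, the hunks leave the dense dot kernel fully FP16: both inputs and the output are wrapped as raw byte buffers and the op is templated on "fp16" for both operands. A self-contained sketch of that call path, using only calls visible in this diff; the free function, parameter types, and header locations are assumptions:

#include "tensorflow/core/framework/tensor.h"
#include "triton/driver/buffer.h"  // assumed header locations
#include "triton/dnn/dot.h"

void dot_fp16_sketch(triton::driver::context* ctx, triton::driver::stream* stream,
                     const tensorflow::Tensor& a, const tensorflow::Tensor& b,
                     tensorflow::Tensor* c, int M, int N, int K) {
  // wrap Eigen::half tensors as untyped device buffers, sized in bytes
  triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
  triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
  triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
  // non-transposed FP16 x FP16 product, as configured in the hunk above
  triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
  dot.enqueue(stream, {&da, &db, &dc});
}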