Made sure it works for FP16

This commit is contained in:
Philippe Tillet
2019-07-30 20:02:16 -07:00
parent 080bf1af88
commit 5af7e5adac
21 changed files with 118 additions and 101 deletions


@@ -14,6 +14,7 @@
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
using namespace tensorflow;
using shape_inference::DimensionHandle;
@@ -21,6 +22,7 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;
Status XpropShape(InferenceContext* ctx)
{
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
@@ -120,23 +122,20 @@ public:
shape_c.AddDim(params_.K);
Tensor* c = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
-// allocate locks
-int gridN = (N + 63)/64;
-Tensor* locks;
-TensorShape shape_l;
-if (params_.locks > 0)
-shape_l.AddDim(gridN * params_.locks * 2);
-OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
// wrap tensorflow handles
-triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<T>().data(), false);
-triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<T>().data(), false);
-triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<T>().data(), false);
-triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
-triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
+triton::driver::cu_buffer da(ctx, a.tensor_data().size(), (CUdeviceptr)a.tensor_data().data(), false);
+triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
+triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
+triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
// create profile
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks, OP);
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
// blocksparse matmul
dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
Tensor *tmp = nullptr;
TensorShape tmp_shapes;
tmp_shapes.AddDim(locks_buffer->size() / 4);
OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
}
private:
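
For context, a minimal sketch of the tensor-wrapping pattern introduced by the added lines above, assuming the cu_buffer constructor takes (context, byte size, device pointer, ownership flag) as those calls suggest; the helper name wrap_inputs and the chosen headers are illustrative only, not part of the commit:

#include "tensorflow/core/framework/tensor.h"
#include "triton/driver/buffer.h"

// Hypothetical helper: wraps existing TensorFlow GPU tensors as non-owning
// Triton device buffers, mirroring the new lines in the diff.
void wrap_inputs(triton::driver::context* ctx,
                 const tensorflow::Tensor& a, const tensorflow::Tensor& b) {
  // Byte size comes from tensor_data().size(), the device pointer from
  // tensor_data().data(); `false` means Triton never frees the memory.
  triton::driver::cu_buffer da(ctx, a.tensor_data().size(),
                               (CUdeviceptr)a.tensor_data().data(), false);
  triton::driver::cu_buffer db(ctx, b.tensor_data().size(),
                               (CUdeviceptr)b.tensor_data().data(), false);
  // ... pass {&da, &db, ...} to dot.enqueue(...) as in the code above ...
}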