[dnn][blocksparse] added dw code
@@ -46,6 +46,19 @@ Status XpropShape(InferenceContext* ctx)
    return Status::OK();
}

Status UpdatShape(InferenceContext* ctx)
{
    //printf("UpdatShape: %d\n", ctx->Rank(ctx->input(0)));

    int blocks, bsize;
    TF_RETURN_IF_ERROR(ctx->GetAttr("blocks", &blocks));
    TF_RETURN_IF_ERROR(ctx->GetAttr("bsize", &bsize));

    // output is always (blocks, block_size, block_size)
    DimensionHandle bsize_dim = ctx->MakeDim(bsize);
    ctx->set_output(0, ctx->MakeShape({ ctx->MakeDim(blocks), bsize_dim, bsize_dim }));
    return Status::OK();
}
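
// Note: UpdatShape deliberately ignores the shapes of the op's inputs; the
// weight gradient "dw" always has the fixed dense-block shape
// (blocks, bsize, bsize). Illustrative values (not from this commit): with
// blocks = 240 and bsize = 16, the output holds 240 dense 16x16 blocks.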

typedef struct bsmm_params
{
@@ -72,34 +85,46 @@ typedef struct bsmm_params

template<triton::dnn::blocksparse::op_t OP, typename T>
class BlocksparseMatmulOp : public OpKernel {
private:

  void ComputeDw(OpKernelContext* context){
    // get device/stream
    GPUDevice device = context->eigen_device<GPUDevice>();
    triton::driver::cu_stream sstream(device.stream(), false);
    triton::driver::context* ctx = sstream.context();
    triton::driver::stream* stream = &sstream;
    // extract inputs (gate is extracted but not yet used on this path)
    OpInputList x, dy, gate;
    context->input_list(   "x", &x);
    context->input_list(  "dy", &dy);
    context->input_list("gate", &gate);
    // sanity checks
    params_.pcount = x.size();
    OP_REQUIRES(context, params_.pcount <= 1,
                errors::Internal("No more than 1 input allowed."));
    OP_REQUIRES(context, params_.beta == 0.0f && params_.alpha == 1.0f,
                errors::Internal("Not supported yet"));
    // N = product of all non-axis dimensions of x
    int N = 1;
    int rank = x[0].dims();
    for (int i = 0; i < rank; i++)
      if (i != axis_)
        N *= x[0].dim_size(i);
    // allocate output
    Tensor* C;
    TensorShape shapeC({ params_.blocks, params_.bsize, params_.bsize });
    OP_REQUIRES_OK(context, context->allocate_output(0, shapeC, &C));
    // wrap tensorflow handles
    triton::driver::cu_buffer da(ctx, x[0].tensor_data().size(), (CUdeviceptr)x[0].tensor_data().data(), false);
    triton::driver::cu_buffer db(ctx, dy[0].tensor_data().size(), (CUdeviceptr)dy[0].tensor_data().data(), false);
    triton::driver::cu_buffer dc(ctx, C->tensor_data().size(), (CUdeviceptr)C->tensor_data().data(), false);
    const Tensor& lut = context->input(params_.pcount*2);
    triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
    // create profile (dtype is hard-coded to "half" here)
    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
    // enqueue
    dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
  }
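
  // Hypothetical example of the reduction size above (values illustrative,
  // not from this commit): for x of shape [64, 128, 1024] with axis_ == 2
  // (the feature axis), N = 64 * 128 = 8192, i.e. dw reduces over every
  // position that shares the same feature column.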
  void ComputeYDx(OpKernelContext* context){
    // get device/stream
    GPUDevice device = context->eigen_device<GPUDevice>();
    triton::driver::cu_stream sstream(device.stream(), false);
@@ -129,8 +154,8 @@ public:
    triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
    // create profile
    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
    // enqueue
    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
    triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
    Tensor *tmp = nullptr;
    TensorShape tmp_shapes;
@@ -138,6 +163,41 @@ public:
    OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
  }
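
  // Commentary (our reading, not asserted by this diff): the locks buffer
  // exposed as output 1 backs the spin-locks the kernel uses when several
  // thread blocks accumulate into the same output tile, so TensorFlow keeps
  // it alive for the lifetime of the op.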

public:

  explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("locks",    &params_.locks   ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks",   &params_.blocks  ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize",    &params_.bsize   ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("C",        &params_.C       ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("K",        &params_.K       ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared",   &params_.shared  ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha",    &params_.alpha   ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta",     &params_.beta    ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_       ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis",     &axis_           ));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bench",    &bench_));
    OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536"));
    OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536"));
    params_.pcount = 1;
    params_.blk_A  = 0;
    is_gpu_ = ctx->device_type() == DEVICE_GPU;
    if (bench_) {
      repeat_ = bench_;
      flops_  = (float)(params_.blocks * params_.bsize * params_.bsize);
      const char* op = "FPROP";
      sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks);
    }
  }

  void Compute(OpKernelContext* context) override {
    if(OP == triton::dnn::blocksparse::WGRAD)
      ComputeDw(context);
    else
      ComputeYDx(context);
  }
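
  // OP is a compile-time template parameter, so each kernel registration
  // below instantiates a class whose Compute() statically resolves to either
  // ComputeDw (WGRAD) or ComputeYDx (FPROP/BPROP); no runtime flag is read.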

private:
  bsmm_params params_;
  int axis_, bench_, repeat_, SMs_, major_, grid_n_;
@@ -212,3 +272,33 @@ Multiply the matrix "a" by the blocksparse matrix "b".
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<float>("T"),       BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, float>);
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<triton::dnn::blocksparse::BPROP, Eigen::half>);

REGISTER_OP("TritonBlocksparseMatmulDW")
    .Input("x: params * T")
    .Input("dy: params * T")
    .Input("lut: int64")
    .Input("gate: ngate * float")
    .Output("dw: T")
    .Attr("T: {half, float, bfloat16}")
    .Attr("params: int")
    .Attr("blocks: int >= 0")
    .Attr("bsize: int")
    .Attr("segments: int = 0")
    .Attr("locks: int = 0")
    .Attr("axis: int = 1")
    .Attr("C: int >= 0")
    .Attr("K: int >= 0")
    .Attr("shared: int = 0")
    .Attr("alpha: float = 1.0")
    .Attr("beta: float = 0.0")
    .Attr("gated_dw: bool = false")
    .Attr("gate_grad: bool = false")
    .Attr("bench: int = 0")
    .Attr("ngate: int >= 0")
    .SetShapeFn(UpdatShape)
    .Doc(R"doc(
Compute the gradient of the blocksparse matrix "b" for the matmul of "a" by "b".
)doc");

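// A hedged sketch of the graph-level contract (names and values illustrative,
// not from this commit): TensorFlow generates a Python wrapper for the op,
// roughly
//   dw = triton_blocksparse_matmul_dw(x=[x], dy=[dy], lut=lut, gate=[],
//                                     blocks=240, bsize=16, C=4096, K=4096,
//                                     axis=1)
// where dw comes back with shape (blocks, bsize, bsize) per UpdatShape above.
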
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDW").Device(DEVICE_GPU).TypeConstraint<float>("T"),BlocksparseMatmulOp<triton::dnn::blocksparse::WGRAD, float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDW").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),BlocksparseMatmulOp<triton::dnn::blocksparse::WGRAD, Eigen::half>);
|
||||
|