[dnn/blocksparse/dot] prototype version seems to pass basic test

Philippe Tillet
2019-07-27 21:21:36 -07:00
parent 2a377bc8b1
commit 17cb2db356
18 changed files with 402 additions and 205 deletions


@@ -3,7 +3,8 @@
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/jit.h"
#include "triton/runtime/jit.h"
#include "triton/dnn/blocksparse/dot.h"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
}
)";
Status XpropShape(InferenceContext* ctx)
{
int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
// C ==> K
ShapeHandle x = ctx->input(0);
int rank = ctx->Rank(x);
//printf("XpropShape: %d\n", rank);
if (rank > 0)
{
std::vector<DimensionHandle> shape;
shape.reserve(rank);
for (int i = 0; i < rank; i++)
shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
ctx->set_output(0, ctx->MakeShape(shape));
}
else
ctx->set_output(0, ctx->UnknownShape());
ctx->set_output(1, ctx->UnknownShape());
return Status::OK();
}
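As a worked illustration, XpropShape keeps every input dimension except the one at axis, which is replaced by K (the "C ==> K" comment). Below is a minimal standalone sketch of that rule outside the TensorFlow shape-inference API; the helper name and the sample shapes are hypothetical.

// Hypothetical helper, for illustration only: copy the input shape and
// substitute K for the dimension at 'axis'.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> xprop_output_shape(std::vector<int64_t> x, int axis, int64_t K) {
  x[axis] = K;  // e.g. {64, 512} with axis = 1, K = 2048 becomes {64, 2048}
  return x;
}

int main() {
  auto out = xprop_output_shape({64, 512}, 1, 2048);
  std::printf("%lld x %lld\n", (long long)out[0], (long long)out[1]);  // 64 x 2048
  return 0;
}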
REGISTER_OP("BlocksparseMatmul")
.Input("x: T")
.Input("w: T")
.Input("lut: int64")
.Input("lut_dx: int64")
.Input("lut_dw: int64")
.Input("gate: ngate * float")
.Output("y: T")
.Output("temp: int32")
.Attr("T: {half, float, bfloat16}")
.Attr("blocks: int >=0")
.Attr("bsize: int")
.Attr("segments: int = 0")
.Attr("segments_dx: int = 0")
.Attr("locks: int = 0")
.Attr("locks_dx: int = 0")
.Attr("axis: int = 1")
.Attr("C: int >=0")
.Attr("K: int >=0")
.Attr("shared: int = 0")
.Attr("shared_dx: int = 0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.0")
.Attr("gated_dw: bool = false")
.Attr("gate_grad: bool = false")
.Attr("bench: int = 0")
.Attr("ngate: int >= 0")
.SetShapeFn(XpropShape)
.Doc(R"doc(
Multiply the matrix "a" by the blocksparse matrix "b".
)doc");
REGISTER_OP("TritonBlocksparseMatmul")
.Input("x: T")
.Input("w: T")
.Input("lut: int64")
.Input("lut_dx: int64")
.Input("lut_dw: int64")
.Input("gate: ngate * float")
.Output("y: T")
.Output("temp: int32")
.Attr("T: {half, float, bfloat16}")
.Attr("blocks: int >=0")
.Attr("bsize: int")
.Attr("segments: int = 0")
.Attr("segments_dx: int = 0")
.Attr("locks: int = 0")
.Attr("locks_dx: int = 0")
.Attr("axis: int = 1")
.Attr("C: int >=0")
.Attr("K: int >=0")
.Attr("shared: int = 0")
.Attr("shared_dx: int = 0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.0")
.Attr("gated_dw: bool = false")
.Attr("gate_grad: bool = false")
.Attr("bench: int = 0")
.Attr("ngate: int >= 0")
.SetShapeFn(XpropShape)
.Doc(R"doc(
Multiply the matrix "a" by the blocksparse matrix "b".
)doc");
typedef struct bsmm_params
{
const int* Lut;
const float* Gate;
int* Lock;
int blocks;
int bsize;
int segments;
int locks;
int C;
int K;
int N;
int shared;
int pcount;
uint blk_a;
uint blk_A;
uint blk_b;
uint blk_B;
float alpha;
float beta;
CUstream stream;
} bsmm_params;
class BlocksparseMatmulOp : public OpKernel {
public:
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", &params_.locks ));
@@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel {
}
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& lut = context->input(2);
// allocate c
TensorShape shape_c;
int N = 1;
int rank_a = a.dims();
for (int i = 0; i < rank_a; i++)
if (i != axis_) {
shape_c.AddDim(a.dim_size(i));
N *= a.dim_size(i);
}
else
shape_c.AddDim(params_.K);
Tensor* c = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
// grid and block
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
blkN = 64;
gridN = (N + 63)/64;
}
// allocate locks
Tensor* locks;
TensorShape shape_l;
if (params_.locks > 0)
shape_l.AddDim(gridN * params_.locks * 2);
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
// initialize default compute device
triton::runtime::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
// blocksparse matmul
triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
}
private:
@@ -157,4 +185,4 @@ private:
char bench_string_[256];
};
REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);