[dnn/blocksparse/dot] prototype version seems to pass basic test

@@ -3,7 +3,8 @@
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/runtime/jit.h"
#include "triton/dnn/blocksparse/dot.h"

#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using GPUDevice = Eigen::GpuDevice;

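// Triton-C source for the block-sparse kernel. Each `tunable` constant lists
// the candidate values the Triton JIT can choose from when tuning the tile
// sizes (TM, TN, TK) and the reduction split (GZ); the bsmm body itself is
// still an empty stub in this prototype.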
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void bsmm(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
          int32 M, int32 N, int32 K,
          int32 lda, int32 ldb, int32 ldc,
          int32 *locks, int32 grid0, int32 grid1) {

}
)";

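// Shape inference for the ops below: output 0 is the shape of x with
// dimension `axis` replaced by K; output 1 (the temp/locks tensor) is left
// unknown since its size depends on the launch configuration.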
Status XpropShape(InferenceContext* ctx)
{
  int K;    TF_RETURN_IF_ERROR(ctx->GetAttr("K",    &K));
  int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));

  // C ==> K
  ShapeHandle x = ctx->input(0);
  int rank = ctx->Rank(x);
  //printf("XpropShape: %d\n", rank);
  if (rank > 0)
  {
    std::vector<DimensionHandle> shape;
    shape.reserve(rank);
    for (int i = 0; i < rank; i++)
      shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
    ctx->set_output(0, ctx->MakeShape(shape));
  }
  else
    ctx->set_output(0, ctx->UnknownShape());
  ctx->set_output(1, ctx->UnknownShape());
  return Status::OK();
}

REGISTER_OP("BlocksparseMatmul")
|
||||
.Input("x: T")
|
||||
.Input("w: T")
|
||||
.Input("lut: int64")
|
||||
.Input("lut_dx: int64")
|
||||
.Input("lut_dw: int64")
|
||||
.Input("gate: ngate * float")
|
||||
.Output("y: T")
|
||||
.Output("temp: int32")
|
||||
.Attr("T: {half, float, bfloat16}")
|
||||
.Attr("blocks: int >=0")
|
||||
.Attr("bsize: int")
|
||||
.Attr("segments: int = 0")
|
||||
.Attr("segments_dx: int = 0")
|
||||
.Attr("locks: int = 0")
|
||||
.Attr("locks_dx: int = 0")
|
||||
.Attr("axis: int = 1")
|
||||
.Attr("C: int >=0")
|
||||
.Attr("K: int >=0")
|
||||
.Attr("shared: int = 0")
|
||||
.Attr("shared_dx: int = 0")
|
||||
.Attr("alpha: float = 1.0")
|
||||
.Attr("beta: float = 0.0")
|
||||
.Attr("gated_dw: bool = false")
|
||||
.Attr("gate_grad: bool = false")
|
||||
.Attr("bench: int = 0")
|
||||
.Attr("ngate: int >= 0")
|
||||
.SetShapeFn(XpropShape)
|
||||
.Doc(R"doc(
|
||||
Multiply the matrix "a" by the blocksparse matrix "b".
|
||||
)doc");
|
||||
REGISTER_OP("TritonBlocksparseMatmul")
|
||||
.Input("x: T")
|
||||
.Input("w: T")
|
||||
.Input("lut: int64")
|
||||
.Input("lut_dx: int64")
|
||||
.Input("lut_dw: int64")
|
||||
.Input("gate: ngate * float")
|
||||
.Output("y: T")
|
||||
.Output("temp: int32")
|
||||
.Attr("T: {half, float, bfloat16}")
|
||||
.Attr("blocks: int >=0")
|
||||
.Attr("bsize: int")
|
||||
.Attr("segments: int = 0")
|
||||
.Attr("segments_dx: int = 0")
|
||||
.Attr("locks: int = 0")
|
||||
.Attr("locks_dx: int = 0")
|
||||
.Attr("axis: int = 1")
|
||||
.Attr("C: int >=0")
|
||||
.Attr("K: int >=0")
|
||||
.Attr("shared: int = 0")
|
||||
.Attr("shared_dx: int = 0")
|
||||
.Attr("alpha: float = 1.0")
|
||||
.Attr("beta: float = 0.0")
|
||||
.Attr("gated_dw: bool = false")
|
||||
.Attr("gate_grad: bool = false")
|
||||
.Attr("bench: int = 0")
|
||||
.Attr("ngate: int >= 0")
|
||||
.SetShapeFn(XpropShape)
|
||||
.Doc(R"doc(
|
||||
Multiply the matrix "a" by the blocksparse matrix "b".
|
||||
)doc");
|
||||
|
||||
|
||||
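// Kernel-side parameter block: look-up table, gating and lock pointers plus
// the dimensions and scaling factors collected from the op's attributes.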
typedef struct bsmm_params
{
  const int*   Lut;
  const float* Gate;
  int*         Lock;
  //float4* Scratch;
  int      blocks;
  int      bsize;
  int      segments;
  int      locks;
  int      C;
  int      K;
  int      N;
  int      shared;
  int      pcount;
  uint     blk_a;
  uint     blk_A;
  uint     blk_b;
  uint     blk_B;
  float    alpha;
  float    beta;
  CUstream stream;
} bsmm_params;

class BlocksparseMatmulOp : public OpKernel {
public:
  explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("locks",    &params_.locks   ));
@@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel {
  }

  void Compute(OpKernelContext* context){
    // get device/stream
    GPUDevice device = context->eigen_device<GPUDevice>();
    triton::driver::cu_stream sstream(device.stream(), false);
    triton::driver::context* ctx = sstream.context();
    triton::driver::stream* stream = &sstream;
    // get inputs
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    const Tensor& lut = context->input(2);
    // allocate c: same shape as a, with dimension `axis_` replaced by K
    TensorShape shape_c;
    int N = 1;
    int rank_a = a.dims();
    for (int i = 0; i < rank_a; i++)
      if (i != axis_) {
        shape_c.AddDim(a.dim_size(i));
        N *= a.dim_size(i);
      }
      else
        shape_c.AddDim(params_.K);
    Tensor* c = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
    // grid and block
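    // gridN = ceil(N / blkN); drop to 64-wide tiles when axis_ == 1, when the
    // 128-wide tail tile would be at most half full, or when the grid would
    // under-fill the SMs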
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
|
||||
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
|
||||
blkN = 64;
|
||||
gridN = (N + 63)/64;
|
||||
}
|
||||
// allocate locks
|
||||
Tensor* locks;
|
||||
TensorShape shape_l;
|
||||
if (params_.locks > 0)
|
||||
shape_l.AddDim(gridN * params_.locks * 2);
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
|
||||
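    // the Triton JIT below is constructed for this CUDA context; the tensor
    // buffers are handed to Triton as views of TensorFlow-owned memory (the
    // trailing `false` presumably means the wrapper takes no ownership)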
    // initialize default compute device
    triton::runtime::jit jit(ctx);
    // matrix multiplication parameters
    triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
    triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
    // triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
    triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
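    // build the block-sparse dot primitive for this (N, K, C) problem size and
    // launch it with auto-tuning disabled (triton::dnn::NO_TUNING)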
    // blocksparse matmul
    triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
    dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
  }

private:
@@ -157,4 +185,4 @@ private:
  char bench_string_[256];
};

REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
|
||||
|