#include #include "triton/driver/buffer.h" #include "triton/driver/backend.h" #include "triton/driver/stream.h" #include "triton/jit.h" #define EIGEN_USE_GPU #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/framework/common_shape_fns.h" using namespace tensorflow; using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; using GPUDevice = Eigen::GpuDevice; const char* src = R"( const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TK = {8}; const tunable int32 GZ = {1}; void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, int32 M, int32 N, int32 K, int32 lda, int32 ldb, int32 ldc, int32 *locks, int32 grid0, int32 grid1) { } )"; Status XpropShape(InferenceContext* ctx) { int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K)); int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis)); // C ==> K ShapeHandle x = ctx->input(0); int rank = ctx->Rank(x); //printf("XpropShape: %d\n", rank); if (rank > 0) { std::vector shape; shape.reserve(rank); for (int i = 0; i < rank; i++) shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i)); ctx->set_output(0, ctx->MakeShape(shape)); } else ctx->set_output(0, ctx->UnknownShape()); ctx->set_output(1, ctx->UnknownShape()); return Status::OK(); } REGISTER_OP("BlocksparseMatmul") .Input("x: T") .Input("w: T") .Input("lut: int64") .Input("lut_dx: int64") .Input("lut_dw: int64") .Input("gate: ngate * float") .Output("y: T") .Output("temp: int32") .Attr("T: {half, float, bfloat16}") .Attr("blocks: int >=0") .Attr("bsize: int") .Attr("segments: int = 0") .Attr("segments_dx: int = 0") .Attr("locks: int = 0") .Attr("locks_dx: int = 0") .Attr("axis: int = 1") .Attr("C: int >=0") .Attr("K: int >=0") .Attr("shared: int = 0") .Attr("shared_dx: int = 0") .Attr("alpha: float = 1.0") .Attr("beta: float = 0.0") .Attr("gated_dw: bool = false") .Attr("gate_grad: bool = false") .Attr("bench: int = 0") .Attr("ngate: int >= 0") .SetShapeFn(XpropShape) .Doc(R"doc( Multiply the matrix "a" by the blocksparse matrix "b". )doc"); typedef struct bsmm_params { const int* Lut; const float* Gate; int* Lock; //float4* Scratch; int blocks; int bsize; int segments; int locks; int C; int K; int N; int shared; int pcount; uint blk_a; uint blk_A; uint blk_b; uint blk_B; float alpha; float beta; CUstream stream; } bsmm_params; class BlocksparseMatmulOp : public OpKernel { public: explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", ¶ms_.segments)); OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", ¶ms_.locks )); OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks", ¶ms_.blocks )); OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize", ¶ms_.bsize )); OP_REQUIRES_OK(ctx, ctx->GetAttr("C", ¶ms_.C )); OP_REQUIRES_OK(ctx, ctx->GetAttr("K", ¶ms_.K )); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared", ¶ms_.shared )); OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", ¶ms_.alpha )); OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", ¶ms_.beta )); OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_ )); OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_ )); OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_)); OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536")); OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536")); params_.pcount = 1; params_.blk_A = 0; is_gpu_ = ctx->device_type() == DEVICE_GPU; if (bench_) { repeat_ = bench_; flops_ = (float)(params_.blocks * params_.bsize*params_.bsize); const char* op = "FPROP"; sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks); } } void Compute(OpKernelContext* context){ } private: bsmm_params params_; int axis_, bench_, repeat_, SMs_, major_, grid_n_; float flops_; bool gated_dw_, is_gpu_; char bench_string_[256]; }; REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp);