diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp index a6b05df7f..85e73d033 100644 --- a/examples/python/tensorflow/blocksparse.cpp +++ b/examples/python/tensorflow/blocksparse.cpp @@ -15,6 +15,9 @@ #include "tensorflow/core/framework/common_shape_fns.h" using namespace tensorflow; +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; using GPUDevice = Eigen::GpuDevice; @@ -25,139 +28,133 @@ const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TK = {8}; const tunable int32 GZ = {1}; -void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, +void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, int32 M, int32 N, int32 K, int32 lda, int32 ldb, int32 ldc, int32 *locks, int32 grid0, int32 grid1) { - int32 rxa[TM] = get_global_range[TM](0); - int32 ryb[TN] = get_global_range[TN](1); - int32 rz = get_global_range[1](2); - int32 rka[TK] = 0 ... TK; - int32 rkb[TK] = 0 ... TK; - fp32 c[TM, TN] = 0; - int32 div = K / GZ; - int32 rem = K % GZ; - K = select(rz < rem, div - 1, div); - int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); - fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis]; - fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis]; - fp32 a[TM, TK] = *pa; - fp32 b[TN, TK] = *pb; - int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; - int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; - last_a = last_a / TK * TK; - last_b = last_b / TK * TK; - int32 bound = K - max(last_a, last_b); - for(int32 k = K; k > bound; k = k - TK){ - c = dot(a, trans(b), c); - pa = pa + TK*lda; - pb = pb + TK*ldb; - a = *pa; - b = *pb; - } - int32 rxc[TM] = get_global_range[TM](0); - int32 ryc[TN] = get_global_range[TN](1); - for(int32 k = bound; k > 0; k = k - 1){ - int1 checka[TM, 1] = rxc[:, newaxis] < M; - int1 checkb[TN, 1] = ryc[:, newaxis] < N; - fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis]; - fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis]; - fp32 a[TM, 1] = checka ? *pa : 0; - fp32 b[TN, 1] = checkb ? *pb : 0; - c = dot(a, trans(b), c); - } - int32 ridx = get_range_id(0); - int32 ridy = get_range_id(1); - fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - int32 *plock = locks + ridx + ridy*grid0; - while(__atomic_cas(plock, 0, 1)); - int32 *pcount = plock + grid0*grid1; - int32 count = *pcount; - int32 countp1 = select(count == GZ - 1, 0, count + 1); - int1 checkc0[TM] = rxc < M; - int1 checkc1[TN] = ryc < N; - int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - if(count == 0) { - @checkc *pc = c; - *pcount = countp1; - } - else { - @checkc *pc = c + *pc; - *pcount = countp1; - } - __atomic_cas(plock, 1, 0); + } )"; -REGISTER_OP("BlockSparseMatMul") - .Input("a: T") - .Input("b: T") - .Input("locks: int32") - .Output("c: T") - .Attr("T: {float}") -; +Status XpropShape(InferenceContext* ctx) +{ + int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K)); + int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis)); -class BlockSparseGemmOp : public OpKernel { + // C ==> K + ShapeHandle x = ctx->input(0); + int rank = ctx->Rank(x); + //printf("XpropShape: %d\n", rank); + if (rank > 0) + { + std::vector shape; + shape.reserve(rank); + for (int i = 0; i < rank; i++) + shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i)); + + ctx->set_output(0, ctx->MakeShape(shape)); + } + else + ctx->set_output(0, ctx->UnknownShape()); + ctx->set_output(1, ctx->UnknownShape()); + return Status::OK(); +} + + +REGISTER_OP("BlocksparseMatmul") + .Input("x: T") + .Input("w: T") + .Input("lut: int64") + .Input("lut_dx: int64") + .Input("lut_dw: int64") + .Input("gate: ngate * float") + .Output("y: T") + .Output("temp: int32") + .Attr("T: {half, float, bfloat16}") + .Attr("blocks: int >=0") + .Attr("bsize: int") + .Attr("segments: int = 0") + .Attr("segments_dx: int = 0") + .Attr("locks: int = 0") + .Attr("locks_dx: int = 0") + .Attr("axis: int = 1") + .Attr("C: int >=0") + .Attr("K: int >=0") + .Attr("shared: int = 0") + .Attr("shared_dx: int = 0") + .Attr("alpha: float = 1.0") + .Attr("beta: float = 0.0") + .Attr("gated_dw: bool = false") + .Attr("gate_grad: bool = false") + .Attr("bench: int = 0") + .Attr("ngate: int >= 0") + .SetShapeFn(XpropShape) + .Doc(R"doc( +Multiply the matrix "a" by the blocksparse matrix "b". +)doc"); + + +typedef struct bsmm_params +{ + const int* Lut; + const float* Gate; + int* Lock; + //float4* Scratch; + int blocks; + int bsize; + int segments; + int locks; + int C; + int K; + int N; + int shared; + int pcount; + uint blk_a; + uint blk_A; + uint blk_b; + uint blk_B; + float alpha; + float beta; + CUstream stream; +} bsmm_params; + +class BlocksparseMatmulOp : public OpKernel { public: - explicit BlockSparseGemmOp(OpKernelConstruction* context) : OpKernel(context) { + explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", ¶ms_.segments)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", ¶ms_.locks )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks", ¶ms_.blocks )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize", ¶ms_.bsize )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("C", ¶ms_.C )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("K", ¶ms_.K )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared", ¶ms_.shared )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", ¶ms_.alpha )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", ¶ms_.beta )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_ )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_ )); + OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_)); + OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536")); + OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536")); + params_.pcount = 1; + params_.blk_A = 0; + is_gpu_ = ctx->device_type() == DEVICE_GPU; + if (bench_) { + repeat_ = bench_; + flops_ = (float)(params_.blocks * params_.bsize*params_.bsize); + const char* op = "FPROP"; + sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks); + } } void Compute(OpKernelContext* context){ - // get device/stream - GPUDevice device = context->eigen_device(); - triton::driver::cu_stream sstream(device.stream(), false); - triton::driver::context* ctx = sstream.context(); - triton::driver::stream* stream = &sstream; - // get inputs - const Tensor& a = context->input(0); - const Tensor& b = context->input(1); - const Tensor& locks = context->input(2); - // get shapes - const int32_t M = a.dim_size(0); - const int32_t N = b.dim_size(0); - const int32_t K = a.dim_size(1); - // allocate output - Tensor* c = nullptr; - TensorShape out_shape({(int64)M, (int64)N}); - OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c)); - // return early if possible - if (out_shape.num_elements() == 0) - return; - // initialize default compute device - triton::jit jit(ctx); - // matrix multiplication parameters - triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); - triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); - triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false); - triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat().data(), false); - stream->synchronize(); - // just-in-time compile source-code - jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1}); - triton::driver::kernel* kernel = jit.get_function("matmul"); - triton::jit::launch_information info = jit.get_launch_info("matmul"); - // launch info - unsigned TM = info.global_range_size[0]; - unsigned TN = info.global_range_size[1]; - unsigned nthreads = info.num_threads; - unsigned GZ = jit.get_int("GZ"); - std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ}; - // set argument - kernel->setArg(0, *da.cu()); - kernel->setArg(1, *db.cu()); - kernel->setArg(2, *dc.cu()); - kernel->setArg(3, M); - kernel->setArg(4, N); - kernel->setArg(5, K); - kernel->setArg(6, M); - kernel->setArg(7, N); - kernel->setArg(8, M); - kernel->setArg(9, *dlocks.cu()); - kernel->setArg(10, grid[0]); - kernel->setArg(11, grid[1]); - stream->enqueue(kernel, grid, {nthreads, 1, 1}); } private: + bsmm_params params_; + int axis_, bench_, repeat_, SMs_, major_, grid_n_; + float flops_; + bool gated_dw_, is_gpu_; + char bench_string_[256]; }; -REGISTER_KERNEL_BUILDER(Name("BlockSparseMatMul").Device(DEVICE_GPU).TypeConstraint("T"), BlockSparseGemmOp); +REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp); diff --git a/examples/python/tensorflow/blocksparse.py b/examples/python/tensorflow/blocksparse.py deleted file mode 100644 index 5a721def9..000000000 --- a/examples/python/tensorflow/blocksparse.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import tensorflow as tf -import numpy as np - -data_files_path = tf.resource_loader.get_data_files_path() -library_dir = '/home/philippe/Development/triton/build/examples/python/tensorflow' -module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so')) - -M, N, K = 512, 512, 512 -a = tf.placeholder(tf.float32, shape=[M, K]) -b = tf.placeholder(tf.float32, shape=[N, K]) -locks = tf.placeholder(tf.int32, shape=[4096]) -c = module.block_sparse_mat_mul(a, b, locks) -# Run -sess = tf.InteractiveSession() -sess.run(tf.global_variables_initializer()) -result = sess.run([c], feed_dict = {locks: np.zeros(4096), - a: np.random.rand(M, K), - b: np.random.rand(N, K)}) -print(result) diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp new file mode 100644 index 000000000..bc87c71e8 --- /dev/null +++ b/examples/python/tensorflow/dot.cpp @@ -0,0 +1,163 @@ +#include + +#include "triton/driver/buffer.h" +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" +#include "triton/jit.h" + +#define EIGEN_USE_GPU +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/framework/common_shape_fns.h" + +using namespace tensorflow; +using GPUDevice = Eigen::GpuDevice; + + +const char* src = +R"( +const tunable int32 TM = {16, 32, 64, 128}; +const tunable int32 TN = {16, 32, 64, 128}; +const tunable int32 TK = {8}; +const tunable int32 GZ = {1}; + +void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, + int32 M, int32 N, int32 K, + int32 lda, int32 ldb, int32 ldc, + int32 *locks, int32 grid0, int32 grid1) { + int32 rxa[TM] = get_global_range[TM](0); + int32 ryb[TN] = get_global_range[TN](1); + int32 rz = get_global_range[1](2); + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + fp32 c[TM, TN] = 0; + int32 div = K / GZ; + int32 rem = K % GZ; + K = select(rz < rem, div - 1, div); + int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); + fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis]; + fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis]; + fp32 a[TM, TK] = *pa; + fp32 b[TN, TK] = *pb; + int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; + int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; + last_a = last_a / TK * TK; + last_b = last_b / TK * TK; + int32 bound = K - max(last_a, last_b); + for(int32 k = K; k > bound; k = k - TK){ + c = dot(a, trans(b), c); + pa = pa + TK*lda; + pb = pb + TK*ldb; + a = *pa; + b = *pb; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + for(int32 k = bound; k > 0; k = k - 1){ + int1 checka[TM, 1] = rxc[:, newaxis] < M; + int1 checkb[TN, 1] = ryc[:, newaxis] < N; + fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis]; + fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis]; + fp32 a[TM, 1] = checka ? *pa : 0; + fp32 b[TN, 1] = checkb ? *pb : 0; + c = dot(a, trans(b), c); + } + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + int32 *plock = locks + ridx + ridy*grid0; + while(__atomic_cas(plock, 0, 1)); + int32 *pcount = plock + grid0*grid1; + int32 count = *pcount; + int32 countp1 = select(count == GZ - 1, 0, count + 1); + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + if(count == 0) { + @checkc *pc = c; + *pcount = countp1; + } + else { + @checkc *pc = c + *pc; + *pcount = countp1; + } + __atomic_cas(plock, 1, 0); +} +)"; + +REGISTER_OP("Dot") + .Input("a: T") + .Input("b: T") + .Input("locks: int32") + .Output("c: T") + .Attr("T: {float}") +; + +class BlockSparseGemmOp : public OpKernel { + public: + explicit BlockSparseGemmOp(OpKernelConstruction* context) : OpKernel(context) { + } + + void Compute(OpKernelContext* context){ + // get device/stream + GPUDevice device = context->eigen_device(); + triton::driver::cu_stream sstream(device.stream(), false); + triton::driver::context* ctx = sstream.context(); + triton::driver::stream* stream = &sstream; + // get inputs + const Tensor& a = context->input(0); + const Tensor& b = context->input(1); + const Tensor& locks = context->input(2); + // get shapes + const int32_t M = a.dim_size(0); + const int32_t N = b.dim_size(0); + const int32_t K = a.dim_size(1); + // allocate output + Tensor* c = nullptr; + TensorShape out_shape({(int64)M, (int64)N}); + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c)); + // return early if possible + if (out_shape.num_elements() == 0) + return; + // initialize default compute device + triton::jit jit(ctx); + // matrix multiplication parameters + triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); + triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); + triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false); + triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat().data(), false); + stream->synchronize(); + // just-in-time compile source-code + jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1}); + triton::driver::kernel* kernel = jit.get_function("matmul"); + triton::jit::launch_information info = jit.get_launch_info("matmul"); + // launch info + unsigned TM = info.global_range_size[0]; + unsigned TN = info.global_range_size[1]; + unsigned nthreads = info.num_threads; + unsigned GZ = jit.get_int("GZ"); + std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ}; + // set argument + kernel->setArg(0, *da.cu()); + kernel->setArg(1, *db.cu()); + kernel->setArg(2, *dc.cu()); + kernel->setArg(3, M); + kernel->setArg(4, N); + kernel->setArg(5, K); + kernel->setArg(6, M); + kernel->setArg(7, N); + kernel->setArg(8, M); + kernel->setArg(9, *dlocks.cu()); + kernel->setArg(10, grid[0]); + kernel->setArg(11, grid[1]); + stream->enqueue(kernel, grid, {nthreads, 1, 1}); + } + +private: +}; + +REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU).TypeConstraint("T"), BlockSparseGemmOp);