diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index efebc102e..771e44c1f 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -4,7 +4,7 @@
 #include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/dnn/gemm.h"
+#include "triton/dnn/dot.h"
 #include "triton/tools/bench.hpp"
 #include "cuda.h"
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->synchronize();
   triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
   // benchmark cublas
   NumericT alpha = 1;
   NumericT beta = 0;
@@ -98,8 +98,9 @@ int main() {
   // shapes to benchmark
   std::vector configs = {
 //    {false, false, 8192, 512, 512},
-    {false, true, 8192, 8192, 8192}
-//    {false, true, 32768, 256, 512}
+//    {false, true, 8192, 8192, 8192}
+    {false, true, 32768, 256, 256},
+    {false, true, 32768, 256, 512}
 //    {true, false, 8192, 512, 512},
 //    {true, true, 8192, 512, 512}
   };
diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index 91ed2daaa..38e0e37bf 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
   stream->write(dc, true, 0, hc);
   stream->synchronize();
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
   // benchmark cublas
-  NumericT alpha = 1;
-  NumericT beta = 0;
-  cublasGemmAlgo_t fastest;
-  cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-             &alpha, da, shift.lda(),
-             db, shift.ldb(), &beta,
-             dc, shift.ldc(), &fastest);
-  double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
-                                                             &alpha, da, shift.lda(),
-                                                             db, shift.ldb(),
-                                                             &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
+//  NumericT alpha = 1;
+//  NumericT beta = 0;
+//  cublasGemmAlgo_t fastest;
+//  cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
+//             &alpha, da, shift.lda(),
+//             db, shift.ldb(), &beta,
+//             dc, shift.ldc(), &fastest);
+//  double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
+//                                                             &alpha, da, shift.lda(),
+//                                                             db, shift.ldb(),
+//                                                             &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
   // result
   auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
   perf_t result;
-  result.cublas = tflops(cublas_ns);
+//  result.cublas = tflops(cublas_ns);
   result.triton = tflops(triton_ns);
   delete da;
   delete db;
@@ -133,8 +133,9 @@ int main() {
     {128, 1024, 8, 8, 3, 3, 1024, 1, 1}
   };
   for(config_t c: resnet18){
-    for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
+    for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
       configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
+    }
   }

   // initialize default compute device
diff --git a/examples/python/tensorflow/CMakeLists.txt b/examples/python/tensorflow/CMakeLists.txt
index 5c151f19b..0dad37f19 100644
--- a/examples/python/tensorflow/CMakeLists.txt
+++ b/examples/python/tensorflow/CMakeLists.txt
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TF_LIB})
     add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-    add_library(tf_blocksparse SHARED dot.cpp conv.cpp shift.cpp batchnorm.cpp)
+    add_library(tf_blocksparse SHARED blocksparse.cpp dot.cpp conv.cpp shift.cpp batchnorm.cpp)
     target_link_libraries(tf_blocksparse tensorflow_framework triton)
     configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
                    ${CMAKE_CURRENT_BINARY_DIR}/run.py
diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp
index 85e73d033..b86c6bcab 100644
--- a/examples/python/tensorflow/blocksparse.cpp
+++ b/examples/python/tensorflow/blocksparse.cpp
@@ -3,7 +3,8 @@
 #include "triton/driver/buffer.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
+#include "triton/dnn/blocksparse/dot.h"

 #define EIGEN_USE_GPU
 #include "tensorflow/core/framework/op.h"
@@ -20,106 +21,88 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 using GPUDevice = Eigen::GpuDevice;

-
-const char* src =
-R"(
-const tunable int32 TM = {16, 32, 64, 128};
-const tunable int32 TN = {16, 32, 64, 128};
-const tunable int32 TK = {8};
-const tunable int32 GZ = {1};
-
-void bsmm (restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
-           int32 M, int32 N, int32 K,
-           int32 lda, int32 ldb, int32 ldc,
-           int32 *locks, int32 grid0, int32 grid1) {
-
-}
-)";
-
 Status XpropShape(InferenceContext* ctx)
 {
-    int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
-    int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
+  int K; TF_RETURN_IF_ERROR(ctx->GetAttr( "K", &K));
+  int axis; TF_RETURN_IF_ERROR(ctx->GetAttr("axis", &axis));
-    // C ==> K
-    ShapeHandle x = ctx->input(0);
-    int rank = ctx->Rank(x);
-    //printf("XpropShape: %d\n", rank);
-    if (rank > 0)
-    {
-        std::vector shape;
-        shape.reserve(rank);
-        for (int i = 0; i < rank; i++)
-            shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
-
-        ctx->set_output(0, ctx->MakeShape(shape));
-    }
-    else
-        ctx->set_output(0, ctx->UnknownShape());
-    ctx->set_output(1, ctx->UnknownShape());
-    return Status::OK();
+  // C ==> K
+  ShapeHandle x = ctx->input(0);
+  int rank = ctx->Rank(x);
+  //printf("XpropShape: %d\n", rank);
+  if (rank > 0)
+  {
+    std::vector shape;
+    shape.reserve(rank);
+    for (int i = 0; i < rank; i++)
+      shape.push_back(i == axis ? ctx->MakeDim(K) : ctx->Dim(x, i));
+    ctx->set_output(0, ctx->MakeShape(shape));
+  }
+  else
+    ctx->set_output(0, ctx->UnknownShape());
+  ctx->set_output(1, ctx->UnknownShape());
+  return Status::OK();
 }

-REGISTER_OP("BlocksparseMatmul")
-    .Input("x: T")
-    .Input("w: T")
-    .Input("lut: int64")
-    .Input("lut_dx: int64")
-    .Input("lut_dw: int64")
-    .Input("gate: ngate * float")
-    .Output("y: T")
-    .Output("temp: int32")
-    .Attr("T: {half, float, bfloat16}")
-    .Attr("blocks: int >=0")
-    .Attr("bsize: int")
-    .Attr("segments: int = 0")
-    .Attr("segments_dx: int = 0")
-    .Attr("locks: int = 0")
-    .Attr("locks_dx: int = 0")
-    .Attr("axis: int = 1")
-    .Attr("C: int >=0")
-    .Attr("K: int >=0")
-    .Attr("shared: int = 0")
-    .Attr("shared_dx: int = 0")
-    .Attr("alpha: float = 1.0")
-    .Attr("beta: float = 0.0")
-    .Attr("gated_dw: bool = false")
-    .Attr("gate_grad: bool = false")
-    .Attr("bench: int = 0")
-    .Attr("ngate: int >= 0")
-    .SetShapeFn(XpropShape)
-    .Doc(R"doc(
-Multiply the matrix "a" by the blocksparse matrix "b".
-)doc");
+REGISTER_OP("TritonBlocksparseMatmul")
+.Input("x: T")
+.Input("w: T")
+.Input("lut: int64")
+.Input("lut_dx: int64")
+.Input("lut_dw: int64")
+.Input("gate: ngate * float")
+.Output("y: T")
+.Output("temp: int32")
+.Attr("T: {half, float, bfloat16}")
+.Attr("blocks: int >=0")
+.Attr("bsize: int")
+.Attr("segments: int = 0")
+.Attr("segments_dx: int = 0")
+.Attr("locks: int = 0")
+.Attr("locks_dx: int = 0")
+.Attr("axis: int = 1")
+.Attr("C: int >=0")
+.Attr("K: int >=0")
+.Attr("shared: int = 0")
+.Attr("shared_dx: int = 0")
+.Attr("alpha: float = 1.0")
+.Attr("beta: float = 0.0")
+.Attr("gated_dw: bool = false")
+.Attr("gate_grad: bool = false")
+.Attr("bench: int = 0")
+.Attr("ngate: int >= 0")
+.SetShapeFn(XpropShape)
+.Doc(R"doc(
+  Multiply the matrix "a" by the blocksparse matrix "b".
+ )doc"); typedef struct bsmm_params { - const int* Lut; - const float* Gate; - int* Lock; - //float4* Scratch; - int blocks; - int bsize; - int segments; - int locks; - int C; - int K; - int N; - int shared; - int pcount; - uint blk_a; - uint blk_A; - uint blk_b; - uint blk_B; - float alpha; - float beta; - CUstream stream; + const int* Lut; + const float* Gate; + int* Lock; + int blocks; + int bsize; + int segments; + int locks; + int C; + int K; + int N; + int shared; + int pcount; + uint blk_a; + uint blk_A; + uint blk_b; + uint blk_B; + float alpha; + float beta; + CUstream stream; } bsmm_params; class BlocksparseMatmulOp : public OpKernel { - public: +public: explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", ¶ms_.segments)); OP_REQUIRES_OK(ctx, ctx->GetAttr("locks", ¶ms_.locks )); @@ -147,6 +130,51 @@ class BlocksparseMatmulOp : public OpKernel { } void Compute(OpKernelContext* context){ + // get device/stream + GPUDevice device = context->eigen_device(); + triton::driver::cu_stream sstream(device.stream(), false); + triton::driver::context* ctx = sstream.context(); + triton::driver::stream* stream = &sstream; + // get inputs + const Tensor& a = context->input(0); + const Tensor& b = context->input(1); + const Tensor& lut = context->input(2); + // allocate c + TensorShape shape_c; + int N = 1; + int rank_a = a.dims(); + for (int i = 0; i < rank_a; i++) + if (i != axis_) { + shape_c.AddDim(a.dim_size(i)); + N *= a.dim_size(i); + } + else + shape_c.AddDim(params_.K); + Tensor* c = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c)); + // grid and block + int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127; + if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){ + blkN = 64; + gridN = (N + 63)/64; + } + // allocate locks + Tensor* locks; + TensorShape shape_l; + if (params_.locks > 0) + shape_l.AddDim(gridN * params_.locks * 2); + OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks)); + // initialize default compute device + triton::runtime::jit jit(ctx); + // matrix multiplication parameters + triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); + triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); + triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false); +// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat().data(), false); + triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat().data(), false); + // blocksparse matmul + triton::dnn::blocksparse::dot dot(N, params_.K, params_.C); + dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING); } private: @@ -157,4 +185,4 @@ private: char bench_string_[256]; }; -REGISTER_KERNEL_BUILDER(Name("BlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp); +REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp); diff --git a/examples/python/tensorflow/conv.cpp b/examples/python/tensorflow/conv.cpp index ebd60ac6d..4ed457021 100644 --- a/examples/python/tensorflow/conv.cpp +++ b/examples/python/tensorflow/conv.cpp @@ -5,7 +5,6 @@ #include "triton/driver/stream.h" #include "triton/runtime/jit.h" #include "triton/tools/bench.hpp" -#include "triton/dnn/gemm.h" #include "triton/dnn/conv.h" #define EIGEN_USE_GPU diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp index 84f67664c..368ef8be3 100644 --- 
a/examples/python/tensorflow/dot.cpp +++ b/examples/python/tensorflow/dot.cpp @@ -5,7 +5,7 @@ #include "triton/driver/stream.h" #include "triton/runtime/jit.h" #include "triton/tools/bench.hpp" -#include "triton/dnn/gemm.h" +#include "triton/dnn/dot.h" #define EIGEN_USE_GPU #include "tensorflow/core/framework/op.h" diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp index 1a9ebbe59..2a3973fbb 100644 --- a/examples/python/tensorflow/shift.cpp +++ b/examples/python/tensorflow/shift.cpp @@ -19,10 +19,10 @@ using namespace tensorflow; using GPUDevice = Eigen::GpuDevice; -template +template class ShiftConvOp : public OpKernel { public: - explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) { + explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::NCHW) { context->GetAttr("shift_h", &h_shift_h_); context->GetAttr("shift_w", &h_shift_w_); context->GetAttr("stride_h", &stride_h_); @@ -32,13 +32,13 @@ public: } void ExtractShapes(const Tensor &x, int64_t &C, int64_t &H, int64_t &W, int64_t &B) { - if(layout_ == triton::dnn::shift::CHWN){ + if(layout_ == triton::dnn::CHWN){ C = x.dim_size(0); H = x.dim_size(1); W = x.dim_size(2); B = x.dim_size(3); } - else if(layout_ == triton::dnn::shift::NCHW){ + else if(layout_ == triton::dnn::NCHW){ B = x.dim_size(0); C = x.dim_size(1); H = x.dim_size(2); @@ -52,7 +52,7 @@ public: void FillShapes(OpKernelContext* context, int64_t &C, int64_t &H, int64_t &W, int64_t &B, int64_t &F, const Tensor& tf_a, const Tensor& tf_b) { - if(OP == triton::dnn::shift::WGRAD) { + if(OP == triton::dnn::WGRAD) { int64_t Ha, Wa, Ba; int64_t Hb, Wb, Bb; ExtractShapes(tf_a, F, Ha, Wa, Ba); @@ -68,19 +68,19 @@ public: // shapes for a int64_t Ca; ExtractShapes(tf_a, Ca, H, W, B); - if(OP == triton::dnn::shift::BPROP){ + if(OP == triton::dnn::BPROP){ H *= stride_h_; W *= stride_w_; } // shapes for b int64_t Cb = tf_b.dim_size(0); F = tf_b.dim_size(1); - if(OP == triton::dnn::shift::BPROP) + if(OP == triton::dnn::BPROP) std::swap(Cb, F); // checks OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels")); C = Ca; - if(OP == triton::dnn::shift::BPROP) + if(OP == triton::dnn::BPROP) std::swap(C, F); } } @@ -122,7 +122,7 @@ public: triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat().data(), false); - shift.enqueue(stream, {&da, &db, &dc}, false); + shift.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING); } private: @@ -132,10 +132,10 @@ private: int stride_w_; int R_; int S_; - triton::dnn::shift::layout_t layout_; + triton::dnn::layout_t layout_; }; -REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp); +REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp); REGISTER_OP("ShiftConv") .Input("a: float16") .Input("b: float16") @@ -145,7 +145,7 @@ REGISTER_OP("ShiftConv") .Attr("stride_w: int") .Output("c: float16"); -REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp); +REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp); REGISTER_OP("ShiftConvDx") .Input("a: float16") .Input("b: float16") @@ -155,7 +155,7 @@ REGISTER_OP("ShiftConvDx") .Attr("stride_w: int") .Output("c: float16"); -REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), 
ShiftConvOp); +REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp); REGISTER_OP("ShiftConvDw") .Input("a: float16") .Input("b: float16") diff --git a/include/triton/dnn/blocksparse/dot.h b/include/triton/dnn/blocksparse/dot.h new file mode 100644 index 000000000..fbd388937 --- /dev/null +++ b/include/triton/dnn/blocksparse/dot.h @@ -0,0 +1,42 @@ +#include "triton/driver/stream.h" +#include "triton/driver/kernel.h" +#include "triton/dnn/base.h" +#include + +namespace triton{ +namespace dnn{ +namespace blocksparse{ + + +class dot: public base { +private: + void enqueue_impl(driver::stream *stream, driver::kernel *kernel, + std::vector args, + triton::runtime::launch_information info); + // number of flops + virtual size_t num_flops() const; + // comparison for maps + virtual bool operator<(const base& other) const; + // default parameters + virtual std::vector search_space() const; + virtual params_t heuristics() const; + +public: + // constructor + dot(int32_t M, int32_t N, int32_t K); + // triton-c source + virtual void triton_c_src(std::ostream &os) const; + // clone + virtual base* clone() const; + +private: + std::string ab_ty_; + std::string c_ty_; + int32_t M_; + int32_t N_; + int32_t K_; +}; + +} +} +} diff --git a/include/triton/dnn/gemm.h b/include/triton/dnn/dot.h similarity index 100% rename from include/triton/dnn/gemm.h rename to include/triton/dnn/dot.h diff --git a/include/triton/dnn/heuristics.h b/include/triton/dnn/heuristics.h index e2efe6df2..ab8af7d32 100644 --- a/include/triton/dnn/heuristics.h +++ b/include/triton/dnn/heuristics.h @@ -13,79 +13,95 @@ typedef std::tuple size_key_t; static const std::map> params = { /* NN */ {trans_key_t(false, false), std::map{ - {size_key_t(16, 16), {4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}}, - {size_key_t(16, 32), {2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1}}, + {size_key_t(16, 16), {2, 8, 16, 4, 16, 2, 2, 1, 1, 16, 32, 8, 4, 1}}, + {size_key_t(16, 32), {4, 4, 16, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}}, {size_key_t(16, 64), {4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1}}, - {size_key_t(16, 128), {4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1}}, - {size_key_t(32, 16), {4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1}}, - {size_key_t(32, 32), {4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1}}, - {size_key_t(32, 64), {8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1}}, - {size_key_t(32, 128), {8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1}}, - {size_key_t(64, 16), {8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1}}, + {size_key_t(16, 128), {2, 8, 16, 8, 128, 2, 2, 1, 1, 16, 32, 4, 8, 1}}, + {size_key_t(32, 16), {8, 4, 32, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}}, + {size_key_t(32, 32), {4, 8, 32, 4, 32, 2, 2, 1, 1, 8, 32, 8, 4, 1}}, + {size_key_t(32, 64), {8, 4, 32, 4, 64, 2, 2, 1, 1, 4, 32, 8, 4, 1}}, + {size_key_t(32, 128), {8, 4, 32, 32, 128, 2, 2, 2, 2, 16, 32, 4, 4, 1}}, + {size_key_t(32, 256), {4, 8, 32, 32, 256, 2, 2, 1, 4, 32, 32, 4, 8, 1}}, + {size_key_t(64, 16), {8, 8, 64, 8, 16, 2, 2, 1, 1, 4, 32, 4, 8, 1}}, {size_key_t(64, 32), {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1}}, - {size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1}}, - {size_key_t(64, 128), {16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1}}, - {size_key_t(128, 16), {8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1}}, - {size_key_t(128, 32), {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}}, - {size_key_t(128, 64), {8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1}}, - {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1}} + 
{size_key_t(64, 64), {8, 8, 64, 16, 64, 2, 2, 1, 2, 8, 32, 4, 8, 1}}, + {size_key_t(64, 128), {16, 4, 64, 32, 128, 2, 2, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(128, 16), {8, 8, 128, 16, 16, 2, 2, 2, 1, 8, 32, 4, 8, 1}}, + {size_key_t(128, 32), {32, 4, 128, 16, 32, 2, 2, 2, 1, 2, 32, 4, 8, 1}}, + {size_key_t(128, 64), {16, 8, 128, 16, 64, 2, 2, 2, 2, 8, 32, 8, 4, 1}}, + {size_key_t(128, 128), {8, 8, 128, 32, 128, 2, 2, 2, 2, 16, 32, 4, 8, 1}}, + {size_key_t(256, 16), {32, 8, 256, 16, 16, 2, 2, 4, 1, 4, 32, 8, 4, 1}}, + {size_key_t(256, 32), {32, 8, 256, 16, 32, 2, 2, 4, 1, 4, 32, 8, 4, 1}}, + {size_key_t(256, 64), {16, 8, 256, 32, 64, 2, 2, 4, 1, 8, 32, 4, 8, 1}} }}, /* NT */ {trans_key_t(false, true), std::map{ - {size_key_t(16, 16), {4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1}}, - {size_key_t(16, 32), {4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}}, - {size_key_t(16, 64), {4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1}}, - {size_key_t(16, 128), {4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1}}, - {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}}, - {size_key_t(32, 32), {4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1}}, - {size_key_t(32, 64), {16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1}}, - {size_key_t(32, 128), {4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1}}, - {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}}, - {size_key_t(64, 32), {8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1}}, - {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1}}, - {size_key_t(64, 128), {8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1}}, - {size_key_t(128, 16), {8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1}}, - {size_key_t(128, 32), {16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1}}, - {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1}}, - {size_key_t(128, 128), {8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1}} - }}, + {size_key_t(16, 16), {2, 4, 16, 2, 8, 16, 2, 2, 1, 1, 16, 32, 16, 1}}, + {size_key_t(16, 32), {4, 4, 16, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}}, + {size_key_t(16, 64), {2, 4, 16, 2, 8, 64, 2, 2, 1, 1, 16, 32, 16, 1}}, + {size_key_t(16, 128), {2, 8, 16, 8, 8, 128, 2, 2, 1, 1, 16, 32, 4, 1}}, + {size_key_t(32, 16), {8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}}, + {size_key_t(32, 32), {4, 8, 32, 8, 4, 32, 2, 2, 1, 1, 8, 32, 4, 1}}, + {size_key_t(32, 64), {16, 4, 64, 16, 4, 64, 2, 2, 4, 1, 8, 32, 8, 1}}, + {size_key_t(32, 128), {4, 8, 32, 16, 4, 128, 2, 2, 1, 2, 16, 32, 4, 1}}, + {size_key_t(32, 256), {4, 8, 32, 64, 4, 256, 2, 2, 1, 4, 32, 32, 2, 1}}, + {size_key_t(64, 16), {8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1}}, + {size_key_t(64, 32), {16, 4, 64, 4, 4, 32, 2, 2, 1, 1, 2, 32, 8, 1}}, + {size_key_t(64, 64), {8, 8, 64, 8, 8, 64, 2, 2, 2, 1, 8, 32, 8, 1}}, + {size_key_t(64, 128), {4, 4, 64, 8, 8, 128, 2, 2, 1, 4, 32, 32, 16, 1}}, + {size_key_t(64, 256), {8, 8, 64, 8, 8, 256, 2, 2, 1, 4, 16, 32, 16, 1}}, + {size_key_t(128, 16), {16, 8, 128, 2, 8, 16, 2, 2, 1, 1, 2, 32, 16, 1}}, + {size_key_t(128, 32), {32, 4, 128, 4, 8, 32, 2, 2, 2, 1, 2, 32, 16, 1}}, + {size_key_t(128, 64), {8, 8, 128, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}}, + {size_key_t(128, 128), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1}}, + {size_key_t(256, 16), {32, 4, 256, 4, 4, 16, 2, 2, 4, 1, 4, 32, 32, 1}}, + {size_key_t(256, 32), {16, 8, 256, 8, 4, 32, 2, 2, 4, 1, 8, 32, 16, 1}}, + {size_key_t(256, 64), {8, 8, 256, 8, 8, 64, 2, 2, 4, 1, 16, 32, 16, 1}} + }}, /* TN */ {trans_key_t(true, false), std::map{ - {size_key_t(16, 16), {8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 
32, 2, 8, 1}}, - {size_key_t(16, 32), {4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, - {size_key_t(16, 64), {4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}}, - {size_key_t(16, 128), {16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1}}, - {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, - {size_key_t(32, 32), {8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(32, 64), {8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(32, 128), {32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1}}, - {size_key_t(64, 16), {8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(64, 32), {8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(64, 64), {16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(64, 128), {32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1}}, - {size_key_t(128, 16), {16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(128, 32), {32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}}, - {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(16, 16), {4, 16, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}}, + {size_key_t(16, 32), {8, 16, 8, 32, 2, 2, 1, 1, 4, 4, 32, 4, 4, 1}}, + {size_key_t(16, 64), {4, 16, 8, 64, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(16, 128), {4, 16, 8, 128, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(32, 16), {4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(32, 32), {4, 32, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}}, + {size_key_t(32, 64), {4, 32, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}}, + {size_key_t(32, 128), {8, 32, 8, 128, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(32, 256), {32, 32, 32, 256, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}}, + {size_key_t(64, 16), {4, 64, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(64, 32), {4, 64, 4, 32, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1}}, + {size_key_t(64, 64), {8, 64, 16, 64, 2, 2, 2, 1, 8, 4, 32, 4, 8, 1}}, + {size_key_t(64, 128), {16, 64, 32, 128, 2, 2, 1, 4, 8, 4, 32, 4, 8, 1}}, + {size_key_t(128, 16), {8, 128, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(128, 32), {16, 128, 16, 32, 2, 2, 4, 1, 8, 4, 32, 8, 4, 1}}, + {size_key_t(128, 64), {32, 128, 32, 64, 2, 2, 2, 2, 4, 8, 32, 4, 8, 1}}, + {size_key_t(128, 128), {32, 128, 32, 128, 2, 2, 1, 4, 4, 8, 32, 4, 8, 1}}, + {size_key_t(256, 16), {16, 256, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1}}, + {size_key_t(256, 32), {16, 256, 32, 32, 2, 2, 4, 1, 8, 4, 32, 4, 8, 1}}, }}, /* TT */ {trans_key_t(true, true), std::map{ - {size_key_t(16, 16), {4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}}, - {size_key_t(16, 32), {8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, - {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}}, - {size_key_t(16, 128), {16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1}}, - {size_key_t(32, 16), {4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}}, - {size_key_t(32, 32), {8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, - {size_key_t(32, 64), {16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1}}, - {size_key_t(32, 128), {32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}}, - {size_key_t(64, 16), {8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}}, - {size_key_t(64, 32), {8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, - {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}}, - {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1}}, - {size_key_t(128, 16), {16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1}}, - {size_key_t(128, 32), {32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 
32, 16, 1}}, - {size_key_t(128, 64), {32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1}}, - {size_key_t(128, 128), {32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}} + {size_key_t(16, 16), {8, 16, 4, 4, 16, 2, 2, 1, 1, 4, 8, 32, 8, 1}}, + {size_key_t(16, 32), {8, 16, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}}, + {size_key_t(16, 64), {16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1}}, + {size_key_t(16, 128), {16, 16, 8, 8, 128, 2, 2, 1, 1, 2, 4, 32, 4, 1}}, + {size_key_t(32, 16), {4, 32, 4, 4, 16, 2, 2, 1, 1, 8, 4, 32, 8, 1}}, + {size_key_t(32, 32), {8, 32, 8, 4, 32, 2, 2, 1, 1, 4, 8, 32, 4, 1}}, + {size_key_t(32, 64), {64, 128, 8, 4, 64, 2, 2, 4, 1, 2, 8, 32, 16, 1}}, + {size_key_t(32, 128), {16, 32, 32, 4, 128, 2, 2, 1, 2, 4, 8, 32, 2, 1}}, + {size_key_t(32, 256), {32, 32, 32, 4, 256, 2, 2, 1, 4, 4, 8, 32, 4, 1}}, + {size_key_t(64, 16), {4, 64, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1}}, + {size_key_t(64, 32), {4, 64, 8, 4, 32, 2, 2, 1, 1, 8, 4, 32, 4, 1}}, + {size_key_t(64, 64), {16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1}}, + {size_key_t(64, 128), {32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 4, 32, 16, 1}}, + {size_key_t(64, 256), {64, 64, 8, 8, 256, 2, 2, 1, 4, 2, 8, 32, 16}}, + {size_key_t(128, 16), {8, 128, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1}}, + {size_key_t(128, 32), {16, 128, 8, 4, 32, 2, 2, 4, 1, 8, 4, 32, 16, 1}}, + {size_key_t(128, 64), {32, 128, 8, 8, 64, 2, 2, 4, 1, 4, 8, 32, 16, 1}}, + {size_key_t(128, 128), {32, 128, 16, 8, 128, 2, 2, 2, 2, 4, 8, 32, 8, 1}}, + {size_key_t(256, 16), {32, 256, 4, 4, 16, 2, 2, 4, 1, 4, 8, 32, 32, 1}}, + {size_key_t(256, 32), {32, 256, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1}} }} }; diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index 8f0f1ef73..de84d1788 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -72,6 +72,7 @@ public: void target_independent(ir::module &module) { optimize_dot.run(module); optimize_trans.run(module); +// optimize_dce.run(module); } void target_dependent(ir::module &module) { diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index b4e40a3f2..5ab9c55f8 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1190,9 +1190,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & Value *ptr = pointers->get_value(idx); ConstantInt *cst = nullptr; if(GetElementPtrInst *gep = dyn_cast(ptr)) - if(gep->getNumIndices() == 1){ + if(gep->getNumIndices() == 1) cst = dyn_cast(gep->idx_begin()); - } ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), ptr->getType()->getPointerAddressSpace())); packets[id] = builder.CreateLoad(ptr); @@ -1202,7 +1201,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & result->for_each([&](indices_t idx){ unsigned linear = result->get_linear_index(idx); unsigned id = linear / vector_size; -// result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size)); + result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size)); }); } // element-wise diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 1da6240dd..820db29b3 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -247,14 +247,14 @@ void tune::run(ir::module &mod) { size_t addr_space = ptr_ty->get_pointer_address_space(); if(addr_space < 4){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 8, 8)); + std::unique_ptr 
tmp(ir::metaparameter::create(ctx, ty, 4, 8)); *params_.at(i).at("nts.d0") = *tmp; } } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp1(ir::metaparameter::create(ctx, ty, 8, 8)); - std::unique_ptr tmp2(ir::metaparameter::create(ctx, ty, 8, 8)); + std::unique_ptr tmp1(ir::metaparameter::create(ctx, ty, 4, 8)); + std::unique_ptr tmp2(ir::metaparameter::create(ctx, ty, 4, 8)); *params_.at(i).at("nts.d0") = *tmp1; *params_.at(i).at("nts.d1") = *tmp2; } diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index d7ffc11d2..72e0d340e 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -59,8 +59,8 @@ void base::enqueue(driver::stream *stream, std::vector args, a jit->add_module(name_.c_str(), src.c_str(), best.params); } else { - params_t params = heuristics(); -// params_t params = jit->get_valid(name_.c_str(), src.c_str()); +// params_t params = heuristics(); + params_t params = jit->get_valid(name_.c_str(), src.c_str()); jit->add_module(name_.c_str(), src.c_str(), params); } triton::driver::kernel* kernel = jit->get_function(name_.c_str()); diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp new file mode 100644 index 000000000..9ddb2514b --- /dev/null +++ b/lib/dnn/blocksparse/dot.cpp @@ -0,0 +1,109 @@ +#include "triton/dnn/blocksparse/dot.h" + +namespace triton{ +namespace dnn{ +namespace blocksparse{ + + +size_t dot::num_flops() const { + +} + +bool dot::operator <(const base& other) const { + auto *y = dynamic_cast(&other); + if(!y) + return true; + return std::tie(M_, N_, K_) + < std::tie(y->M_, y->N_, y->K_); +} + +std::vector dot::search_space() const { + +} + +params_t dot::heuristics() const { + +} + +base * dot::clone() const { + return new dot(*this); +} + +dot::dot(int32_t M, int32_t N, int32_t K): + base("bsdot"), M_(M), N_(N), K_(K) { + ab_ty_ = "fp32"; + c_ty_ = "fp32"; +} + +void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel, + std::vector args, runtime::launch_information info) { + driver::buffer *a = args[0]; + driver::buffer *b = args[1]; + driver::buffer *c = args[2]; + driver::buffer *lut = args[3]; + int32_t lda = M_; + int32_t ldc = M_; + kernel->setArg(0, a); + kernel->setArg(1, b); + kernel->setArg(2, c); + kernel->setArg(3, lda); + kernel->setArg(4, ldc); + kernel->setArg(5, lut); + int32_t TM = info.globals["TM"]; + int32_t TN = info.globals["TN"]; + size_t grid_0 = (M_ + TM - 1) / TM; + size_t grid_1 = (N_ + TN - 1) / TN; + stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1}); + stream->synchronize(); +} + +void dot::triton_c_src(std::ostream &os) const { + std::string result = + + R"( + const tunable int32 TM = {64, 128}; + const tunable int32 TN = {32}; + const tunable int32 TK = {32}; + + void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A, + restrict read_only align(16) )" + ab_ty_ + R"( *B, + fp32* C, + int32 lda, int32 ldc, + int32* lut_base){ + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + fp32 c[TM, TN] = 0; + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + int32 rxa[TM] = ridx * TM + (0 ... TM); + int32 ryb[TN] = 0 ... 
TN; + int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda; + int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK; + int32 *header = lut_base + ridy * 4; + int32 offset = *(header + 0); + int32 K = *(header + 1); + int32 h2 = *(header + 2); + int32 h3 = *(header + 3); + int32 *lut = lut_base + offset*2; + for(int32 k = K; k > 0; k = k - 1){ + int32 ak = *(lut + 0); + int32 bk = *(lut + 1); + fp32* pa[TM, TK] = A + offa + ak * TK * lda; + fp32* pb[TK, TN] = B + offb + bk * TK * TN; + fp32 a[TM, TK] = *pa; + fp32 b[TK, TN] = *pb;; + c = dot(a, b, c); + lut = lut + 2; + } + int32 rxc[TM] = ridx * TM + (0 ... TM); + int32 ryc[TN] = ridy * TN + (0 ... TN); + fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc; + *pc = c; + })"; + + os << result; +} + +} +} +} diff --git a/lib/dnn/gemm.cpp b/lib/dnn/dot.cpp similarity index 97% rename from lib/dnn/gemm.cpp rename to lib/dnn/dot.cpp index 897a26402..114ec7450 100644 --- a/lib/dnn/gemm.cpp +++ b/lib/dnn/dot.cpp @@ -1,6 +1,6 @@ #include "triton/driver/stream.h" #include "triton/driver/kernel.h" -#include "triton/dnn/gemm.h" +#include "triton/dnn/dot.h" #include "triton/dnn/heuristics.h" #include @@ -101,8 +101,8 @@ void dot::triton_c_src(std::ostream &os) const { std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; std::string res = R"( -const tunable int32 TM = {16, 32, 64, 128}; -const tunable int32 TN = {16, 32, 64, 128}; +const tunable int32 TM = {16, 32, 64, 128, 256}; +const tunable int32 TN = {16, 32, 64, 128, 256}; const tunable int32 TK = {32}; const tunable int32 GZ = {1}; diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index fd45ea805..58c62dd46 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -467,7 +467,7 @@ if(op_ == WGRAD){ pa = pa_base + offxa[:, newaxis];)"; } result += R"( - @checka a = *pa;)"; + a = checka ? *pa : 0;)"; /* Increment B pointers */ if(op_ == WGRAD){ @@ -488,7 +488,7 @@ if(op_ == BPROP){ pb = pb + TK;)"; } result += R"( - @checkb b = *pb; + b = checkb ? *pb : 0; } int32 rxc[TM] = ridx*TM + (0 ... TM); int32 ryc[TN] = ridy*TN + (0 ... TN);)"; diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index f39f6c397..6fa727406 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -37,12 +37,13 @@ void parallel_loop_nest(std::vector const & ranges, size_t D = ranges.size(); std::vector values(D, 0); // thread pools - ThreadPool pool(nthreads); +// ThreadPool pool(nthreads); // Start with innermost loop size_t i = D - 1; while(true){ // Execute function - pool.enqueue(f,values); +// pool.enqueue(f,values); + f(values); while(values[i]++ == ranges[i] - 1){ if(i == 0) return; @@ -50,7 +51,7 @@ void parallel_loop_nest(std::vector const & ranges, } i = D - 1; // Short sleep so that the thread pool doesn't grow too big - std::this_thread::sleep_for(std::chrono::microseconds(1)); +// std::this_thread::sleep_for(std::chrono::microseconds(1)); } } @@ -211,9 +212,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben best.perf = perf; best.params = params; } - for(size_t i = 0; i < params.size(); i++) - std::cout << ((i==0)?"":", ") << params[i] << std::flush; - std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; +// for(size_t i = 0; i < params.size(); i++) +// std::cout << ((i==0)?"":", ") << params[i] << std::flush; +// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; } };