diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index e906493ac..d5f2bba3d 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -13,13 +13,13 @@ int main() {
   triton::jit jit(context);
   triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
-  int32_t B = 64, NF = 64;
-  int32_t D = 1, H = 8, W = 8;
-  int32_t NC = 3, T = 1, R = 3, S = 3;
+  int32_t B = 16, NF = 128;
+  int32_t D = 1, H = 16, W = 16;
+  int32_t NC = 64, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 0, pad_w = 0;
   int32_t stride_d = 1, stride_h = 1, stride_w = 1;
   int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, triton::dnn::conv::FPROP, 0);
+  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
 //  triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
   // convolution configuration
   std::vector<float> hc(configuration.c_size());
diff --git a/examples/python/tensorflow/CMakeLists.txt b/examples/python/tensorflow/CMakeLists.txt
index c531c23b1..bfd54f6a6 100644
--- a/examples/python/tensorflow/CMakeLists.txt
+++ b/examples/python/tensorflow/CMakeLists.txt
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TF_LIB})
     add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-    add_library(tf_blocksparse SHARED dot.cpp dense_conv)
+    add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp)
     target_link_libraries(tf_blocksparse tensorflow_framework triton)
     configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
                    ${CMAKE_CURRENT_BINARY_DIR}/run.py
diff --git a/examples/python/tensorflow/dense_conv.cpp b/examples/python/tensorflow/conv2d.cpp
similarity index 78%
rename from examples/python/tensorflow/dense_conv.cpp
rename to examples/python/tensorflow/conv2d.cpp
index 66e7bfdab..12b033f21 100644
--- a/examples/python/tensorflow/dense_conv.cpp
+++ b/examples/python/tensorflow/conv2d.cpp
@@ -20,21 +20,9 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
 
-//torch::Tensor conv_common(
-//    int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
-//    int32_t T, int32_t R, int32_t S, int32_t NF,
-//    int32_t stride_d, int32_t stride_h, int32_t stride_w,
-//    int32_t pad_d, int32_t pad_h, int32_t pad_w,
-//    triton::dnn::conv::type ty,
-//    torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
-//    bool autotune = false
-//    ) {
-
-//}
-
-class DenseConvOp : public OpKernel {
- public:
-  explicit DenseConvOp(OpKernelConstruction* context) : OpKernel(context) {
+class Conv2dOp : public OpKernel {
+public:
+  explicit Conv2dOp(OpKernelConstruction* context) : OpKernel(context) {
   }
 
   void Compute(OpKernelContext* context){
@@ -64,15 +52,19 @@ class DenseConvOp : public OpKernel {
     bool has_bias = false;
 
     // get conv configuration
-    triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF,
+    triton::dnn::conv configuration(B, C,
+                                    D, H, W,
+                                    T, R, S,
+                                    NF,
                                     stride_d, stride_h, stride_w,
                                     pad_d, pad_h, pad_w,
                                     1, 1, 1,
+                                    "fp16", "fp16",
                                     triton::dnn::conv::FPROP, has_bias);
 
     // Bind memory
-    triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<float>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<float>().data(), false);
+    triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
 //    triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
 //    triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
     triton::driver::buffer* bias = nullptr;
@@ -106,12 +98,16 @@ class DenseConvOp : public OpKernel {
     triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
     jit.add_module("conv", src.c_str(), best.params);
+//    jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
+    triton::driver::kernel* kernel = jit.get_function("conv");
+    triton::jit::launch_information info = jit.get_launch_info("conv");
+    std::cout << benchmark(kernel, info) << std::endl;
   }
 };
 
-REGISTER_KERNEL_BUILDER(Name("DenseConv").Device(DEVICE_GPU), DenseConvOp);
-REGISTER_OP("DenseConv")
-    .Input("a: float32")
-    .Input("b: float32")
+REGISTER_KERNEL_BUILDER(Name("Conv2d").Device(DEVICE_GPU), Conv2dOp);
+REGISTER_OP("Conv2d")
+    .Input("a: float16")
+    .Input("b: float16")
     .Output("c: float32")
 ;
diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index bdaab5921..09b9f47b4 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -78,7 +78,7 @@ class DotOp : public OpKernel {
     jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
     triton::driver::kernel* kernel = jit.get_function("matmul");
     triton::jit::launch_information info = jit.get_launch_info("matmul");
-    std::cout << benchmark(kernel, info) << std::endl;;
+    std::cout << benchmark(kernel, info) << std::endl;
   }
 
 private:
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 9756ee340..dca626b1a 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -32,7 +32,7 @@ def run_conv():
     R, S, NF = 3, 3, 32
     a = tf.placeholder(tf.float32, shape=[BS, C, H, W])
     b = tf.placeholder(tf.float32, shape=[C, R, S, NF])
-    c = module.dense_conv(a, b)
+    c = module.conv2d(a, b)
     # Reference
     ha = np.random.rand(BS, C, H, W)
     hb = np.random.rand(C, R, S, NF)
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index a950c3304..6a590f201 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -31,6 +31,7 @@ public:
        int stride_d, int stride_h, int stride_w,
        int pad_d, int pad_h, int pad_w,
        int upsample_d, int upsample_h, int upsample_w,
+       std::string a_ty = "fp32", std::string b_ty = "fp32",
        type ty = FPROP, bool bias = false);
 
   // accessors
@@ -126,7 +127,10 @@ private:
   bool is_a_deltas_cst;
   bool is_b_deltas_cst_;
   bool is_mask_cst_;
-  // type
+  // data type
+  std::string a_ty_;
+  std::string b_ty_;
+  // conv type
   type ty_;
   bool bias_;
   bool b_trans_;
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index d394c1ec0..b066a963a 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -347,7 +347,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function
   }
   if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
     Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
-    return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
+    return (Instruction*)offset;
   }
   if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
     BasicBlock *current = builder.GetInsertBlock();
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index e1d62f4cd..72267b23f 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -233,13 +233,13 @@ void tune::run(ir::module &mod) {
       continue;
     if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
+      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
       *params_.at(i).at("nts.d0") = *tmp;
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }
diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp
index 2acbfed94..c67d132c8 100644
--- a/lib/dnn/conv.cpp
+++ b/lib/dnn/conv.cpp
@@ -21,11 +21,13 @@ conv::conv(int B, int NC,
            int stride_d, int stride_h, int stride_w,
            int pad_d, int pad_h, int pad_w,
            int upsample_d, int upsample_h, int upsample_w,
+           std::string a_ty, std::string b_ty,
            type ty, bool bias)
   : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
     stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
     pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
     upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
+    a_ty_(a_ty), b_ty_(b_ty),
     ty_(ty), bias_(bias) {
   CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
@@ -281,8 +283,8 @@ void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
   d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
   d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
   d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
-  d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4);
-  ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4);
+  d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
+  ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
 }
 
 void conv::set_arg(driver::kernel *kernel,
@@ -336,8 +338,8 @@ void conv::set_arg(driver::kernel *kernel,
   kernel->setArg(39, (int32_t)0);
   kernel->setArg(40, (int32_t)0);
   kernel->setArg(41, d_locks_);
-  kernel->setArg(42, 0);
-  kernel->setArg(43, 0);
+  kernel->setArg(42, max_grid_0_);
+  kernel->setArg(43, max_grid_1_);
   size_t idx = 44;
   if(!is_a_deltas_cst)
     kernel->setArg(idx++, d_a_deltas_);
@@ -358,8 +360,6 @@ void conv::enqueue(driver::stream *stream, driver::kernel *kernel,
   grid[0] /= upsample_h_*upsample_w_;
   kernel->setArg(11, CH_/upsample_h_);
   kernel->setArg(12, CW_/upsample_w_);
-  kernel->setArg(42, (int32_t)grid[0]);
-  kernel->setArg(43, (int32_t)grid[1]);
 
   // initialize to zero if necessary
   bool init_zero = false;
@@ -526,7 +526,7 @@ void conv::src(std::ostream &os){
 R"(
 const tunable int32 TM = {16, 32, 64};
 const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {8};
+const tunable int32 TK = {16};
 const tunable int32 GZ = {1};
 )";
 if(is_a_deltas_cst)
@@ -537,8 +537,8 @@ if(is_mask_cst_)
   os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
 os << R"(
 
-void conv(read_only restrict fp32 *a,
-          read_only restrict fp32 *b,
+void conv(read_only restrict )" << a_ty_ << R"( *a,
+          read_only restrict )" << b_ty_ << R"( *b,
           fp32 *c,
           fp32 *bias,
           int32 M, int32 N, int32 K,
@@ -592,7 +592,7 @@ if(!is_mask_cst_)
     rar = )" + upar + R"( rar;
     ras = )" + upas + R"( ras;
     int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-    fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+    )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
 if(b_lut_){
   os << R"(
     int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
@@ -611,7 +611,7 @@ os << R"(
     int32 rb1[TK] = rkb)" + ldb0 + ";";
 }
 os << R"(
-    fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
+    )" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
     int32 offda[TK] = rka % ldlut;
     )" + a_delta_mem + R"( int32* pincd[TK] = delta + offda;
     )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
@@ -628,8 +628,8 @@ os << R"(
     int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
     int1 checkb0[TN] = rb0 < N;
    int1 checkb)" + BS + " = checkb0" + bcb0 + R"(;
-    fp32 a[TM, TK] = checka ? *pa : 0;
-    fp32 b)" + BS + R"( = checkb ? *pb : 0;
+    )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
+    )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
     int32 rkamin[TK] = rka - offk + TK;
     for(int32 k = K; k > 0; k = k - TK){
       C = dot(a, )" + useb + R"(, C);
@@ -672,8 +672,8 @@ if(b_lut_){
     int32 ridx = get_range_id(0);
     int32 ridy = get_range_id(1);
     int32 *plock = locks + ridx + ridy*grid0;
+    while(__atomic_cas(plock, 0, 1) == 1);
     int32 *pcount = plock + grid0*grid1;
-    while(__atomic_cas(plock, 0, 1));
     int32 count = *pcount;
     int32 countp1 = select(count == GZ - 1, 0, count + 1);
     if(count == 0) {)";
@@ -691,7 +691,7 @@ if(b_lut_){
       @checkc *pc = C + *pc;
       *pcount = countp1;
     }
-    __atomic_cas(plock, 1, 0);
+    *plock = 0;
 })";
 }