From 3e7a3ed67a14b3768b7470919fcd832b25eb0ac0 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 13 Jul 2019 21:05:34 -0700
Subject: [PATCH] [dnn/shift]: added support for fp16

---
 examples/cpp/shift.cpp               | 13 +++++++++++--
 examples/python/pytorch/shift.cpp    | 15 +++++++++++++--
 examples/python/pytorch/triton.py    |  2 --
 examples/python/tensorflow/run.py    | 23 +++++++++++++----------
 examples/python/tensorflow/shift.cpp | 26 +++++++++++++-------------
 include/triton/dnn/base.h            |  2 +-
 include/triton/dnn/shift.h           |  1 +
 lib/codegen/selection.cpp            | 11 ++++++++++-
 lib/dnn/base.cpp                     |  3 +--
 lib/dnn/shift.cpp                    | 22 +++++++++++++---------
 lib/runtime/jit.cpp                  |  1 -
 11 files changed, 76 insertions(+), 43 deletions(-)

diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index 4be4861cc..41c123fef 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -10,11 +10,11 @@ int main() {
   typedef float NumericT;
-  std::string numeric_t_str = "fp32";
+  std::string numeric_t_str = "fp16";
 
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  auto op = triton::dnn::shift::FPROP;
+  auto op = triton::dnn::shift::BPROP;
 
   // initialization
   int32_t R = 3, S = 3;
@@ -35,6 +35,15 @@ int main() {
                            numeric_t_str, numeric_t_str, op, false, triton::dnn::shift::NCHW);
   // host buffers
+  size_t a_size = B*C*H*W;
+  size_t b_size = C*F;
+  size_t c_size = B*F*H*W;
+  if(op == triton::dnn::shift::BPROP)
+    std::swap(a_size, c_size);
+  if(op == triton::dnn::shift::WGRAD){
+    std::swap(b_size, c_size);
+    std::swap(a_size, b_size);
+  }
   std::vector<NumericT> ha(B*C*H*W);
   std::vector<NumericT> hb(C*F);
   std::vector<NumericT> hc(B*F*H*W);
diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp
index d650ca9e6..e3e968db6 100644
--- a/examples/python/pytorch/shift.cpp
+++ b/examples/python/pytorch/shift.cpp
@@ -45,11 +45,20 @@ torch::Tensor shift_common(
   CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
   triton::driver::cu_stream stream(custream, false);
   triton::driver::context* ctx = stream.context();
+  // Data-type
+  std::string dtype;
+  at::ScalarType type = torcha.scalar_type();
+  switch(type){
+    case at::ScalarType::Double: dtype = "fp64"; break;
+    case at::ScalarType::Float: dtype = "fp32"; break;
+    case at::ScalarType::Half: dtype = "fp16"; break;
+    default: AT_ERROR("unknown data-type for shift-conv");
+  }
   // Get configuration
   bool has_bias = torchbias.storage().size() > 0;
   triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
                            stride_h, stride_w,
-                           shift_h, shift_w, "fp32", "fp32",
+                           shift_h, shift_w, dtype, dtype,
                            ty, has_bias, layout);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
@@ -61,7 +70,9 @@ torch::Tensor shift_common(
   std::vector<int64_t> c_shapes;
   for(auto x: _c_shapes)
     c_shapes.push_back(x);
-  torch::Tensor torchc = torch::empty(c_shapes).cuda();
+  torch::Tensor torchc = torch::empty(c_shapes, type).cuda();
+
+  triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   // Enqueue
   shift.enqueue(&stream, {&a, &b, &c});
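A note on the examples/cpp/shift.cpp hunk above: the two std::swap calls decide which operand each of the A/B/C buffers holds for the chosen pass. The sketch below restates that logic as a stand-alone function; the enum and the operand comments are my reading of the swaps, not something the patch states, and the example itself still sizes ha/hb/hc with the forward-pass shapes.

    // Illustration only (not part of the patch): which buffer size belongs to
    // which operand of the shift-conv GEMM, assuming unit strides as in the example.
    #include <cstddef>
    #include <utility>

    enum op_t { FPROP, BPROP, WGRAD };

    void operand_sizes(op_t op, size_t B, size_t C, size_t H, size_t W, size_t F,
                       size_t &a_size, size_t &b_size, size_t &c_size) {
      a_size = B*C*H*W;  // FPROP: activations x
      b_size = C*F;      // FPROP: weights w
      c_size = B*F*H*W;  // FPROP: output y
      if(op == BPROP)    // BPROP reads dy and w, writes dx: A and C trade places
        std::swap(a_size, c_size);
      if(op == WGRAD){   // WGRAD reads dy and x, writes dw: rotate the three sizes
        std::swap(b_size, c_size);
        std::swap(a_size, b_size);
      }
    }

Under this reading, BPROP ends up with a_size = B*F*H*W (dy) and c_size = B*C*H*W (dx), which is consistent with the AH_*AW_*B_*C_ element count that lib/dnn/shift.cpp zeroes for the BPROP output later in this patch.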
diff --git a/examples/python/pytorch/triton.py b/examples/python/pytorch/triton.py
index efeade389..2d78e58f7 100644
--- a/examples/python/pytorch/triton.py
+++ b/examples/python/pytorch/triton.py
@@ -123,8 +123,6 @@ class ShiftConvFunction(torch.autograd.Function):
             dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1],
                                                 stride[0], stride[1], shift_h, shift_w)
         if ctx.needs_input_grad[2]:
             dbias = torch.sum(dy, (1, 2, 3))
-        #print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
-        #print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
         return dx, dw, dbias, None, None, None, None
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 57850de9a..971ad2898 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -58,29 +58,32 @@ def blocksparse_matmul_grad(op, dy):
     return (dx, dw)
 
 def run_shift():
-    B, C, H, W = 16, 16, 2, 2
-    R, S, F = 3, 3, 32
+    B, C, H, W = 1, 16, 4, 4
+    R, S, F = 3, 3, 16
     stride_h, stride_w = 2, 2
     np.random.seed(2)
-    a = tf.placeholder(tf.float32, shape=[B, C, H, W])
-    b = tf.placeholder(tf.float32, shape=[C, F])
+    a = tf.placeholder(tf.float16, shape=[B, C, H, W])
+    b = tf.placeholder(tf.float16, shape=[C, F])
     hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
    hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
     #hshift_h = np.zeros(C, dtype=np.int32)
     #hshift_w = np.zeros(C, dtype=np.int32)
     c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
     # feed values
-    ha = np.random.rand(B, C, H, W)
-    hb = np.random.rand(C, F)
-    #ha = np.ones((B, C, H, W), dtype=np.float32)
-    #hb = np.ones((C, F), dtype=np.float32)
+    ha = np.random.rand(B, C, H, W)*0.1
+    hb = np.random.rand(C, F)*0.1
+    #ha = np.ones((B, C, H, W), dtype=np.float16)
+    #hb = np.ones((C, F), dtype=np.float16)
     sess = tf.InteractiveSession()
     # test
     grads = tf.test.compute_gradient([a, b], [(B, C, H, W), (C, F)], c, (B, F, H//stride_h, W//stride_w),
-                                     extra_feed_dict = {a: ha, b: hb})
+                                     extra_feed_dict = {a: ha, b: hb}, delta=1e-2)
     dw_t, dw_n = grads[1]
     dx_t, dx_n = grads[0]
-    print(dw_t, dw_n)
+    #import sys
+    #np.set_printoptions(threshold=sys.maxsize)
+    print(dx_t)
+    print(dx_n)
     print(np.max(np.abs(dw_t - dw_n)))
     print(np.max(np.abs(dx_t - dx_n)))
     # Run
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index d9014795e..d844e9aa1 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -106,7 +106,7 @@ public:
     triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
                              stride_h_, stride_w_,
                              shift_h_data, shift_w_data,
-                             "fp32", "fp32", OP, has_bias, layout_);
+                             "fp16", "fp16", OP, has_bias, layout_);
     // shapes for c
     std::vector<int64> c_shapes;
@@ -119,9 +119,9 @@ public:
     if (out_shapes.num_elements() == 0)
       return;
     // matrix multiplication parameters
-    triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
-    triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
-    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
+    triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
     shift.enqueue(stream, {&da, &db, &dc});
   }
@@ -137,31 +137,31 @@ private:
 REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp);
 REGISTER_OP("ShiftConv")
-    .Input("a: float32")
-    .Input("b: float32")
+    .Input("a: float16")
+    .Input("b: float16")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
     .Attr("stride_h: int")
     .Attr("stride_w: int")
-    .Output("c: float32");
+    .Output("c: float16");
 REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp);
 REGISTER_OP("ShiftConvDx")
-    .Input("a: float32")
-    .Input("b: float32")
+    .Input("a: float16")
+    .Input("b: float16")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
     .Attr("stride_h: int")
     .Attr("stride_w: int")
-    .Output("c: float32");
+    .Output("c: float16");
 REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp);
 REGISTER_OP("ShiftConvDw")
-    .Input("a: float32")
-    .Input("b: float32")
+    .Input("a: float16")
+    .Input("b: float16")
     .Attr("shift_h: tensor")
     .Attr("shift_w: tensor")
     .Attr("stride_h: int")
     .Attr("stride_w: int")
-    .Output("c: float32");
+    .Output("c: float16");
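Unlike the PyTorch binding, which now derives the dtype string from the input tensor, the TensorFlow op above hardcodes "fp16" and its registrations now only accept float16. A hedged sketch of the same dispatch on the TensorFlow side; triton_dtype is a hypothetical helper, not something this patch adds:

    // Hypothetical helper: map a TensorFlow DataType to the dtype string expected
    // by triton::dnn::shift, mirroring the at::ScalarType switch in the PyTorch binding.
    #include <stdexcept>
    #include <string>
    #include "tensorflow/core/framework/types.h"

    std::string triton_dtype(tensorflow::DataType dt) {
      switch(dt) {
        case tensorflow::DT_HALF:   return "fp16";
        case tensorflow::DT_FLOAT:  return "fp32";
        case tensorflow::DT_DOUBLE: return "fp64";
        default: throw std::invalid_argument("unsupported data-type for shift-conv");
      }
    }

The op could then pass triton_dtype(tf_a.dtype()) where the string literals sit today; combined with a type attribute on the op registration, that would let one definition serve both precisions instead of replacing float32 with float16.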
diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h
index e3c6ff9e1..7aeab2a14 100644
--- a/include/triton/dnn/base.h
+++ b/include/triton/dnn/base.h
@@ -60,7 +60,7 @@ public:
   // clone
   virtual base* clone() const = 0;
   // enqueue
-  void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
+  void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, bool autotune = false);
 
 private:
   std::string name_;
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index 57cb5ea0a..ec4ffc753 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -155,6 +155,7 @@ private:
   // data types
   std::string a_ty_;
   std::string b_ty_;
+  std::string c_ty_;
   // convolution type
   type op_;
   bool bias_;
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index 3bd010ebc..7ca8fb6ee 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -376,7 +376,15 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function
   if(auto* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
     Value *ptr = value(ii->get_operand(0));
     Value *val = value(ii->get_operand(1));
-    Value *atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
+    Value *atom_f_add;
+    if(val->getType()->isFloatTy())
+      atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
+    else if(val->getType()->isHalfTy()){
+      Type *fp16 = Type::getHalfTy(ctx);
+
+      FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
+      atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
+    }
     Value *res = builder.CreateCall(atom_f_add, {ptr, val});
     return (Instruction*)res;
   }
@@ -1110,6 +1118,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
     unsigned alignment = std::min(starting_multiple, max_contiguous);
     unsigned vector_size = std::min(result->axis(0).contiguous, alignment);
+    vector_size = 1;
 //    vector_size = result->axis(0).contiguous;
     std::map<unsigned, Value*> packets;
     distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index 61ab85b60..b3bf6c05a 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -22,9 +22,8 @@ void base::set_ld(const std::vector<int32_t>& shapes,
 
 base::base(const std::string& name)
   : name_(name) { }
 
-void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args) {
+void base::enqueue(driver::stream *stream, std::vector<driver::buffer*> args, bool autotune) {
   static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
-  bool autotune = true;
   driver::context* ctx = stream->context();
   triton::jit* jit;
   /* the current template has not already been compiled */
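The base::enqueue change above flips the default: autotuning used to be forced on inside the function (bool autotune = true) and is now an opt-in argument that defaults to false. A minimal usage sketch, assuming an already-configured operation and bound device buffers; the surrounding setup is elided:

    // Sketch only: how callers choose between the cached kernel and autotuning
    // after this change; `shift`, `stream`, `a`, `b`, `c` come from the usual setup
    // shown in examples/cpp/shift.cpp.
    #include "triton/dnn/shift.h"

    void launch(triton::dnn::shift &shift, triton::driver::stream *stream,
                triton::driver::buffer *a, triton::driver::buffer *b, triton::driver::buffer *c) {
      shift.enqueue(stream, {a, b, c});        // new default: reuse the jitted kernel, no tuning
      shift.enqueue(stream, {a, b, c}, true);  // pass true to recover the old always-autotune behavior
    }

Since the examples call enqueue without the extra argument, they all pick up the faster non-autotuning path after this patch.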
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index 6e209fef3..872189c89 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -22,7 +22,7 @@ shift::shift(int B, int C,
     F_(F),
     stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
     shift_h_(shift_h), shift_w_(shift_w),
-    a_ty_(a_ty), b_ty_(b_ty),
+    a_ty_(a_ty), b_ty_(b_ty), c_ty_(b_ty),
     op_(ty), bias_(bias), layout_(layout){
 //  std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
@@ -230,8 +230,10 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(26, CW_);
   unsigned TM = ranges[0], TN = ranges[1];
   std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
-  if(op_ == BPROP)
-    ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4);
+  if(op_ == BPROP){
+    size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
+    ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
+  }
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
 }
@@ -264,7 +266,7 @@ __constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R
 void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
            restrict read_only align(16) )" + b_ty_ + R"( *B,
-           fp32 *C,
+           )" + c_ty_ + R"( *C,
            int32 M, int32 N, int32 K,
            int32 stride_h, int32 stride_w,
            int32 lda_b, int32 lda_w, int32 lda_h, int32 lda_c,
@@ -278,7 +280,7 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
   int32 ryb[TN] = get_global_range[TN](1);
   int32 rka[TK] = 0 ... TK;
   int32 rkb[TK] = 0 ... TK;
-  fp32 c[TM, TN] = 0;
+  fp32 acc[TM, TN] = 0;
   int32 pad_h = BH / 2;
   int32 pad_w = BW / 2;)";
@@ -304,7 +306,7 @@ if(op_ == FPROP){
   int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
   int32 offa0[TM, TK] = offxa[:, newaxis];
   __constant__ int32* pd[TK] = delta_a + rka;
-  multiple_of(4) int32 d[TK] = *pd;
+  int32 d[TK] = *pd;
   int32 offa_interior[TM, TK] = d[newaxis, :];
   int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
 )";
@@ -424,7 +426,7 @@ if(op_ == WGRAD){
   )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
   for(int32 k = K; k > 0; k = k - TK){
-    c = dot()" + usea + "," + useb + R"(, c);
+    acc = dot()" + usea + "," + useb + R"(, acc);
     int1 checka[)" + AS + R"(] = k > TK;
     int1 checkb[)" + BS + R"(] = k > TK;)";
@@ -564,7 +566,8 @@ if(op_ == WGRAD){
   int32 offxc[TM] = rxc;)";
 }
   result += R"(
-  fp32* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
+  )" + c_ty_ + R"( c[TM, TN] = acc;
+  )" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
   int1 checkc0[TM] = rxc < M;
   int1 checkc1[TN] = ryc < N;
   int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
 if(op_ == BPROP){
   result += R"(
   int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
   __constant__ int32* pd[TN] = delta_a + ryc;
-  fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
+  )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
   pc = interior ? shift_pc : pc;
   @checkc __atomic_add(pc, c);
 )";
 }
 else{
   result += R"(
 })";
+//  std::cout << result << std::endl;
   os << result;
 }
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 30547a19e..536ad44b0 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -149,7 +149,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
   std::vector<std::vector<unsigned>> ranges;
   for(ir::metaparameter *mp: mps)
     ranges.push_back(mp->get_space());
-//  std::cout << ranges.size() << std::endl;
   // iterate over parameters
   unsigned i;
   tune_res_t best;
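One loose end worth noting: the BPROP path zeroes the output before the kernel scatters into it with __atomic_add, and the byte count now depends on c_ty_, which the patch handles with an inline (c_ty_ == "fp16") ? 2 : 4. A hedged sketch of a small element-size helper that would also cover the fp64 case the PyTorch binding already maps; dtype_size is illustrative and not part of this patch:

    // Illustrative only: byte width of the dtype strings used by triton::dnn
    // ("fp16", "fp32", "fp64"); the patch inlines the fp16-vs-fp32 case instead.
    #include <cstddef>
    #include <stdexcept>
    #include <string>

    size_t dtype_size(const std::string &ty) {
      if(ty == "fp16") return 2;
      if(ty == "fp32") return 4;
      if(ty == "fp64") return 8;
      throw std::invalid_argument("unsupported data-type: " + ty);
    }

With such a helper, the zeroing call in shift::enqueue_impl would read set_zero(stream, AH_*AW_*B_*C_*dtype_size(c_ty_)), and no further branch is needed when another data type is wired through.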