[dnn/shift]: added support for fp16
This commit is contained in:
@@ -10,11 +10,11 @@
|
||||
|
||||
int main() {
|
||||
typedef float NumericT;
|
||||
std::string numeric_t_str = "fp32";
|
||||
std::string numeric_t_str = "fp16";
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
auto op = triton::dnn::shift::FPROP;
|
||||
auto op = triton::dnn::shift::BPROP;
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
@@ -35,6 +35,15 @@ int main() {
|
||||
numeric_t_str, numeric_t_str,
|
||||
op, false, triton::dnn::shift::NCHW);
|
||||
// host buffers
|
||||
size_t a_size = B*C*H*W;
|
||||
size_t b_size = C*F;
|
||||
size_t c_size = B*F*H*W;
|
||||
if(op == triton::dnn::shift::BPROP)
|
||||
std::swap(a_size, c_size);
|
||||
if(op == triton::dnn::shift::WGRAD){
|
||||
std::swap(b_size, c_size);
|
||||
std::swap(a_size, b_size);
|
||||
}
|
||||
std::vector<NumericT> ha(B*C*H*W);
|
||||
std::vector<NumericT> hb(C*F);
|
||||
std::vector<float> hc(B*F*H*W);
|
||||
|
@@ -45,11 +45,20 @@ torch::Tensor shift_common(
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
|
||||
triton::driver::cu_stream stream(custream, false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
// Data-type
|
||||
std::string dtype;
|
||||
at::ScalarType type = torcha.scalar_type();
|
||||
switch(type){
|
||||
case at::ScalarType::Double: dtype = "fp64"; break;
|
||||
case at::ScalarType::Float: dtype = "fp32"; break;
|
||||
case at::ScalarType::Half: dtype = "fp16"; break;
|
||||
default: AT_ERROR("unknown data-type for shift-conv");
|
||||
}
|
||||
// Get configuration
|
||||
bool has_bias = torchbias.storage().size() > 0;
|
||||
triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
|
||||
stride_h, stride_w,
|
||||
shift_h, shift_w, "fp32", "fp32",
|
||||
shift_h, shift_w, dtype, dtype,
|
||||
ty, has_bias, layout);
|
||||
// Bind memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
@@ -61,7 +70,9 @@ torch::Tensor shift_common(
|
||||
std::vector<long int> c_shapes;
|
||||
for(auto x: _c_shapes)
|
||||
c_shapes.push_back(x);
|
||||
torch::Tensor torchc = torch::empty(c_shapes).cuda();
|
||||
torch::Tensor torchc = torch::empty(c_shapes, type).cuda();
|
||||
|
||||
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
// Enqueue
|
||||
shift.enqueue(&stream, {&a, &b, &c});
|
||||
|
@@ -123,8 +123,6 @@ class ShiftConvFunction(torch.autograd.Function):
|
||||
dw = torch.ops.triton.shift_conv_dw(dy.contiguous(), input, bias, width[0], width[1], stride[0], stride[1], shift_h, shift_w)
|
||||
if ctx.needs_input_grad[2]:
|
||||
dbias = torch.sum(dy, (1, 2, 3))
|
||||
#print('dx', ctx.needs_input_grad[0], np.isnan(dx.cpu().numpy()).any())
|
||||
#print('dw', ctx.needs_input_grad[1], np.isnan(dw.cpu().numpy()).any())
|
||||
return dx, dw, dbias, None, None, None, None
|
||||
|
||||
|
||||
|
@@ -58,29 +58,32 @@ def blocksparse_matmul_grad(op, dy):
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 16, 16, 2, 2
|
||||
R, S, F = 3, 3, 32
|
||||
B, C, H, W = 1, 16, 4, 4
|
||||
R, S, F = 3, 3, 16
|
||||
stride_h, stride_w = 2, 2
|
||||
np.random.seed(2)
|
||||
a = tf.placeholder(tf.float32, shape=[B, C, H, W])
|
||||
b = tf.placeholder(tf.float32, shape=[C, F])
|
||||
a = tf.placeholder(tf.float16, shape=[B, C, H, W])
|
||||
b = tf.placeholder(tf.float16, shape=[C, F])
|
||||
hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
#hshift_h = np.zeros(C, dtype=np.int32)
|
||||
#hshift_w = np.zeros(C, dtype=np.int32)
|
||||
c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
|
||||
# feed values
|
||||
ha = np.random.rand(B, C, H, W)
|
||||
hb = np.random.rand(C, F)
|
||||
#ha = np.ones((B, C, H, W), dtype=np.float32)
|
||||
#hb = np.ones((C, F), dtype=np.float32)
|
||||
ha = np.random.rand(B, C, H, W)*0.1
|
||||
hb = np.random.rand(C, F)*0.1
|
||||
#ha = np.ones((B, C, H, W), dtype=np.float16)
|
||||
#hb = np.ones((C, F), dtype=np.float16)
|
||||
sess = tf.InteractiveSession()
|
||||
# test
|
||||
grads = tf.test.compute_gradient([a, b], [(B, C, H, W), (C, F)], c, (B, F, H//stride_h, W//stride_w),
|
||||
extra_feed_dict = {a: ha, b: hb})
|
||||
extra_feed_dict = {a: ha, b: hb}, delta=1e-2)
|
||||
dw_t, dw_n = grads[1]
|
||||
dx_t, dx_n = grads[0]
|
||||
print(dw_t, dw_n)
|
||||
#import sys
|
||||
#np.set_printoptions(threshold=sys.maxsize)
|
||||
print(dx_t)
|
||||
print(dx_n)
|
||||
print(np.max(np.abs(dw_t - dw_n)))
|
||||
print(np.max(np.abs(dx_t - dx_n)))
|
||||
# Run
|
||||
|
@@ -106,7 +106,7 @@ public:
|
||||
triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
|
||||
stride_h_, stride_w_,
|
||||
shift_h_data, shift_w_data,
|
||||
"fp32", "fp32", OP, has_bias, layout_);
|
||||
"fp16", "fp16", OP, has_bias, layout_);
|
||||
|
||||
// shapes for c
|
||||
std::vector<int64> c_shapes;
|
||||
@@ -119,9 +119,9 @@ public:
|
||||
if (out_shapes.num_elements() == 0)
|
||||
return;
|
||||
// matrix multiplication parameters
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
|
||||
shift.enqueue(stream, {&da, &db, &dc});
|
||||
}
|
||||
|
||||
@@ -137,31 +137,31 @@ private:
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConv").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::FPROP>);
|
||||
REGISTER_OP("ShiftConv")
|
||||
.Input("a: float32")
|
||||
.Input("b: float32")
|
||||
.Input("a: float16")
|
||||
.Input("b: float16")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Attr("stride_h: int")
|
||||
.Attr("stride_w: int")
|
||||
.Output("c: float32");
|
||||
.Output("c: float16");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
|
||||
REGISTER_OP("ShiftConvDx")
|
||||
.Input("a: float32")
|
||||
.Input("b: float32")
|
||||
.Input("a: float16")
|
||||
.Input("b: float16")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Attr("stride_h: int")
|
||||
.Attr("stride_w: int")
|
||||
.Output("c: float32");
|
||||
.Output("c: float16");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
|
||||
REGISTER_OP("ShiftConvDw")
|
||||
.Input("a: float32")
|
||||
.Input("b: float32")
|
||||
.Input("a: float16")
|
||||
.Input("b: float16")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Attr("stride_h: int")
|
||||
.Attr("stride_w: int")
|
||||
.Output("c: float32");
|
||||
.Output("c: float16");
|
||||
|
||||
|
@@ -60,7 +60,7 @@ public:
|
||||
// clone
|
||||
virtual base* clone() const = 0;
|
||||
// enqueue
|
||||
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
|
||||
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, bool autotune = false);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
|
@@ -155,6 +155,7 @@ private:
|
||||
// data types
|
||||
std::string a_ty_;
|
||||
std::string b_ty_;
|
||||
std::string c_ty_;
|
||||
// convolution type
|
||||
type op_;
|
||||
bool bias_;
|
||||
|
@@ -376,7 +376,15 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
|
||||
Value *ptr = value(ii->get_operand(0));
|
||||
Value *val = value(ii->get_operand(1));
|
||||
Value *atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
|
||||
Value *atom_f_add;
|
||||
if(val->getType()->isFloatTy())
|
||||
atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
|
||||
else if(val->getType()->isHalfTy()){
|
||||
Type *fp16 = Type::getHalfTy(ctx);
|
||||
|
||||
FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
|
||||
atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
|
||||
}
|
||||
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
|
||||
return (Instruction*)res;
|
||||
}
|
||||
@@ -1110,6 +1118,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
vector_size = 1;
|
||||
// vector_size = result->axis(0).contiguous;
|
||||
std::map<unsigned, Value*> packets;
|
||||
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
|
||||
|
@@ -22,9 +22,8 @@ void base::set_ld(const std::vector<int32_t>& shapes,
|
||||
base::base(const std::string& name)
|
||||
: name_(name) { }
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, bool autotune) {
|
||||
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
|
||||
bool autotune = true;
|
||||
driver::context* ctx = stream->context();
|
||||
triton::jit* jit;
|
||||
/* the current template has not already been compiled */
|
||||
|
@@ -22,7 +22,7 @@ shift::shift(int B, int C,
|
||||
F_(F),
|
||||
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
a_ty_(a_ty), b_ty_(b_ty), c_ty_(b_ty),
|
||||
op_(ty), bias_(bias),
|
||||
layout_(layout){
|
||||
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
||||
@@ -230,8 +230,10 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(26, CW_);
|
||||
unsigned TM = ranges[0], TN = ranges[1];
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
if(op_ == BPROP)
|
||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4);
|
||||
if(op_ == BPROP){
|
||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
|
||||
}
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
@@ -264,7 +266,7 @@ __constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"
|
||||
|
||||
void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + b_ty_ + R"( *B,
|
||||
fp32 *C,
|
||||
)" + c_ty_ + R"( *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 stride_h, int32 stride_w,
|
||||
int32 lda_b, int32 lda_w, int32 lda_h, int32 lda_c,
|
||||
@@ -278,7 +280,7 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 c[TM, TN] = 0;
|
||||
fp32 acc[TM, TN] = 0;
|
||||
int32 pad_h = BH / 2;
|
||||
int32 pad_w = BW / 2;)";
|
||||
|
||||
@@ -304,7 +306,7 @@ if(op_ == FPROP){
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta_a + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int32 d[TK] = *pd;
|
||||
int32 offa_interior[TM, TK] = d[newaxis, :];
|
||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
|
||||
)";
|
||||
@@ -424,7 +426,7 @@ if(op_ == WGRAD){
|
||||
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
c = dot()" + usea + "," + useb + R"(, c);
|
||||
acc = dot()" + usea + "," + useb + R"(, acc);
|
||||
int1 checka[)" + AS + R"(] = k > TK;
|
||||
int1 checkb[)" + BS + R"(] = k > TK;)";
|
||||
|
||||
@@ -564,7 +566,8 @@ if(op_ == WGRAD){
|
||||
int32 offxc[TM] = rxc;)";
|
||||
}
|
||||
result += R"("
|
||||
fp32* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
|
||||
)" + c_ty_ + R"( c[TM, TN] = acc;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||
@@ -581,7 +584,7 @@ if(op_ == BPROP){
|
||||
result += R"(
|
||||
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
pc = interior ? shift_pc : pc;
|
||||
@checkc __atomic_add(pc, c);
|
||||
)";
|
||||
@@ -593,6 +596,7 @@ else{
|
||||
result += R"(
|
||||
})";
|
||||
|
||||
// std::cout << result << std::endl;
|
||||
os << result;
|
||||
}
|
||||
|
||||
|
@@ -149,7 +149,6 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
ranges.push_back(mp->get_space());
|
||||
// std::cout << ranges.size() << std::endl;
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
tune_res_t best;
|
||||
|
Reference in New Issue
Block a user