[examples] added conv2d op in tensorflow

Philippe Tillet
2019-06-26 18:50:53 -07:00
parent f1a8972267
commit 6300ec5080
9 changed files with 49 additions and 49 deletions

View File

@@ -13,13 +13,13 @@ int main() {
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
- int32_t B = 64, NF = 64;
- int32_t D = 1, H = 8, W = 8;
- int32_t NC = 3, T = 1, R = 3, S = 3;
+ int32_t B = 16, NF = 128;
+ int32_t D = 1, H = 16, W = 16;
+ int32_t NC = 64, T = 1, R = 3, S = 3;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
- triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, triton::dnn::conv::FPROP, 0);
+ triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
// triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
// convolution configuration
std::vector<float> hc(configuration.c_size());
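Note: with the updated constructor (see the header change further down), the commented-out explicit-parameter call above would also have to pass the two new data-type strings, since a_ty/b_ty sit between the upsample factors and the conv type in the argument list. A minimal sketch of that form, not part of this commit:

    // Hypothetical call using the named variables above; a_ty/b_ty are the new arguments.
    triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
                                    stride_d, stride_h, stride_w,
                                    pad_d, pad_h, pad_w,
                                    upsample_d, upsample_h, upsample_w,
                                    "fp32", "fp32",
                                    ty);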

View File

@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
include_directories("${CUDA_HOME}/include")
link_directories(${TF_LIB})
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
- add_library(tf_blocksparse SHARED dot.cpp dense_conv)
+ add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp)
target_link_libraries(tf_blocksparse tensorflow_framework triton)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
${CMAKE_CURRENT_BINARY_DIR}/run.py

View File

@@ -20,21 +20,9 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
- //torch::Tensor conv_common(
- // int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
- // int32_t T, int32_t R, int32_t S, int32_t NF,
- // int32_t stride_d, int32_t stride_h, int32_t stride_w,
- // int32_t pad_d, int32_t pad_h, int32_t pad_w,
- // triton::dnn::conv::type ty,
- // torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
- // bool autotune = false
- // ) {
- //}
- class DenseConvOp : public OpKernel {
- public:
- explicit DenseConvOp(OpKernelConstruction* context) : OpKernel(context) {
+ class Conv2dOp : public OpKernel {
+ public:
+ explicit Conv2dOp(OpKernelConstruction* context) : OpKernel(context) {
}
void Compute(OpKernelContext* context){
@@ -64,15 +52,19 @@ class DenseConvOp : public OpKernel {
bool has_bias = false;
// get conv configuration
- triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF,
+ triton::dnn::conv configuration(B, C,
+ D, H, W,
+ T, R, S,
+ NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
+ 1, 1, 1,
+ "fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// Bind memory
- triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<float>().data(), false);
- triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<float>().data(), false);
+ triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
+ triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
// triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
// triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
triton::driver::buffer* bias = nullptr;
@@ -106,12 +98,16 @@ class DenseConvOp : public OpKernel {
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << benchmark(kernel, info) << std::endl;
}
};
REGISTER_KERNEL_BUILDER(Name("DenseConv").Device(DEVICE_GPU), DenseConvOp);
REGISTER_OP("DenseConv")
.Input("a: float32")
.Input("b: float32")
REGISTER_KERNEL_BUILDER(Name("Conv2d").Device(DEVICE_GPU), Conv2dOp);
REGISTER_OP("Conv2d")
.Input("a: float16")
.Input("b: float16")
.Output("c: float32")
;
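Only the input side of the new kernel is visible in the hunks above; a rough sketch of how the float16 inputs and the float32 output would be fetched and bound on the TensorFlow side (standard OpKernel calls; out_shape and the output binding are assumptions, not code from this file):

    // Inside Compute(): fetch the fp16 inputs and allocate the fp32 output.
    const Tensor& tfa = context->input(0);   // a: float16
    const Tensor& tfb = context->input(1);   // b: float16
    Tensor* tfc = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
    // Bind the output the same way a and b are bound above.
    triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);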

View File

@@ -78,7 +78,7 @@ class DotOp : public OpKernel {
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
- std::cout << benchmark(kernel, info) << std::endl;;
+ std::cout << benchmark(kernel, info) << std::endl;
}
private:

View File

@@ -32,7 +32,7 @@ def run_conv():
R, S, NF = 3, 3, 32
a = tf.placeholder(tf.float32, shape=[BS, C, H, W])
b = tf.placeholder(tf.float32, shape=[C, R, S, NF])
- c = module.dense_conv(a, b)
+ c = module.conv2d(a, b)
# Reference
ha = np.random.rand(BS, C, H, W)
hb = np.random.rand(C, R, S, NF)

View File

@@ -31,6 +31,7 @@ public:
int stride_d, int stride_h, int stride_w,
int pad_d, int pad_h, int pad_w,
int upsample_d, int upsample_h, int upsample_w,
std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false);
// accessors
@@ -126,7 +127,10 @@ private:
bool is_a_deltas_cst;
bool is_b_deltas_cst_;
bool is_mask_cst_;
- // type
+ // data type
+ std::string a_ty_;
+ std::string b_ty_;
+ // conv type
type ty_;
bool bias_;
bool b_trans_;
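Because a_ty and b_ty default to "fp32", existing call sites that stop at the upsample factors keep compiling unchanged; only callers that also pass the conv type or bias flag positionally have to supply the two strings. A short sketch under that assumption, reusing the dimension names from the example above:

    // Defaults: behaves like the old signature.
    triton::dnn::conv c1(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, 0, 0, 0, 1, 1, 1);
    // Explicit data types are required once the conv type is passed positionally.
    triton::dnn::conv c2(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, 0, 0, 0, 1, 1, 1,
                         "fp16", "fp16", triton::dnn::conv::FPROP);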

View File

@@ -347,7 +347,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
- return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
+ return (Instruction*)offset;
}
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
BasicBlock *current = builder.GetInsertBlock();

View File

@@ -233,13 +233,13 @@ void tune::run(ir::module &mod) {
continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
- std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
+ std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
*params_.at(i).at("nts.d0") = *tmp;
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
- std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
- std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
+ std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
+ std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}

View File

@@ -21,11 +21,13 @@ conv::conv(int B, int NC,
int stride_d, int stride_h, int stride_w,
int pad_d, int pad_h, int pad_w,
int upsample_d, int upsample_h, int upsample_w,
+ std::string a_ty, std::string b_ty,
type ty, bool bias)
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
+ a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias)
{
CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
@@ -281,8 +283,8 @@ void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
- d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4);
- ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4);
+ d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
+ ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
}
void conv::set_arg(driver::kernel *kernel,
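The locks buffer is sized for one int32 lock plus one int32 counter per tile of the launch grid, which matches the pcount = plock + grid0*grid1 indexing in the generated kernel further down. A sketch of the assumed layout (illustrative only, not code from this file):

    // Assumed layout of d_locks_: grid0*grid1 spin locks followed by grid0*grid1 counters.
    size_t n_tiles = max_grid_0_ * max_grid_1_;
    size_t n_bytes = n_tiles * 4    // one int32 lock per tile
                   + n_tiles * 4;   // one int32 counter per tile  (== n_tiles*4*2)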
@@ -336,8 +338,8 @@ void conv::set_arg(driver::kernel *kernel,
kernel->setArg(39, (int32_t)0);
kernel->setArg(40, (int32_t)0);
kernel->setArg(41, d_locks_);
- kernel->setArg(42, 0);
- kernel->setArg(43, 0);
+ kernel->setArg(42, max_grid_0_);
+ kernel->setArg(43, max_grid_1_);
size_t idx = 44;
if(!is_a_deltas_cst)
kernel->setArg(idx++, d_a_deltas_);
@@ -358,8 +360,6 @@ void conv::enqueue(driver::stream *stream, driver::kernel *kernel,
grid[0] /= upsample_h_*upsample_w_;
kernel->setArg(11, CH_/upsample_h_);
kernel->setArg(12, CW_/upsample_w_);
- kernel->setArg(42, (int32_t)grid[0]);
- kernel->setArg(43, (int32_t)grid[1]);
// initialize to zero if necessary
bool init_zero = false;
@@ -526,7 +526,7 @@ void conv::src(std::ostream &os){
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
- const tunable int32 TK = {8};
+ const tunable int32 TK = {16};
const tunable int32 GZ = {1};
)";
if(is_a_deltas_cst)
@@ -537,8 +537,8 @@ if(is_mask_cst_)
os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
os << R"(
- void conv(read_only restrict fp32 *a,
- read_only restrict fp32 *b,
+ void conv(read_only restrict )" << a_ty_ << R"( *a,
+ read_only restrict )" << b_ty_ << R"( *b,
fp32 *c,
fp32 *bias,
int32 M, int32 N, int32 K,
@@ -592,7 +592,7 @@ if(!is_mask_cst_)
rar = )" + upar + R"( rar;
ras = )" + upas + R"( ras;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
- fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+ )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
if(b_lut_){
os << R"(
int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
@@ -611,7 +611,7 @@ os << R"(
int32 rb1[TK] = rkb)" + ldb0 + ";";
}
os << R"(
fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
)" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
int32 offda[TK] = rka % ldlut;
)" + a_delta_mem + R"( int32* pincd[TK] = delta + offda;
)" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
@@ -628,8 +628,8 @@ os << R"(
int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
int1 checkb0[TN] = rb0 < N;
int1 checkb)" + BS + " = checkb0" + bcb0 + R"(;
- fp32 a[TM, TK] = checka ? *pa : 0;
- fp32 b)" + BS + R"( = checkb ? *pb : 0;
+ )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
+ )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
int32 rkamin[TK] = rka - offk + TK;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, )" + useb + R"(, C);
@@ -672,8 +672,8 @@ if(b_lut_){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int32 *plock = locks + ridx + ridy*grid0;
- while(__atomic_cas(plock, 0, 1) == 1);
int32 *pcount = plock + grid0*grid1;
+ while(__atomic_cas(plock, 0, 1));
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
if(count == 0) {)";
@@ -691,7 +691,7 @@ if(b_lut_){
@checkc *pc = C + *pc;
*pcount = countp1;
}
- __atomic_cas(plock, 1, 0);
+ *plock = 0;
})";
}
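For intuition, the reduction epilogue above (acquire the per-tile lock with a compare-and-swap, let the first writer store the tile and later writers accumulate into it, wrap the counter at GZ, release with a plain store) corresponds roughly to the following host-side C++ analogue; a sketch with illustrative names, not code from the repository:

    #include <atomic>

    // One lock/counter pair per output tile, as allocated in d_locks_ above.
    void accumulate_tile(std::atomic<int>& lock, int& count,
                         float* c, const float* tile, int n, int GZ) {
      int expected = 0;
      while (!lock.compare_exchange_strong(expected, 1))   // while(__atomic_cas(plock, 0, 1));
        expected = 0;
      int cnt = count;
      if (cnt == 0)
        for (int i = 0; i < n; i++) c[i] = tile[i];        // first writer: store
      else
        for (int i = 0; i < n; i++) c[i] += tile[i];       // later writers: accumulate
      count = (cnt == GZ - 1) ? 0 : cnt + 1;               // countp1 = select(count == GZ - 1, 0, count + 1)
      lock.store(0);                                       // *plock = 0;
    }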