[examples] added conv2d op in tensorflow
@@ -13,13 +13,13 @@ int main() {
   triton::jit jit(context);
   triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
-  int32_t B = 64, NF = 64;
-  int32_t D = 1, H = 8, W = 8;
-  int32_t NC = 3, T = 1, R = 3, S = 3;
+  int32_t B = 16, NF = 128;
+  int32_t D = 1, H = 16, W = 16;
+  int32_t NC = 64, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 0, pad_w = 0;
   int32_t stride_d = 1, stride_h = 1, stride_w = 1;
   int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, triton::dnn::conv::FPROP, 0);
+  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
   // triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
   // convolution configuration
   std::vector<float> hc(configuration.c_size());
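The only behavioral change in the example call is the pair of data-type strings inserted before the conv type. A minimal sketch of the updated call, with the parameter order taken from the commented-out line in this hunk (B, NC, spatial dims, filter dims, NF, strides, pads, upsamples, a_ty, b_ty, type, bias); the header hunk further down adds "fp32" defaults so older call sites keep compiling:

    #include "triton/dnn/conv.h"

    int main() {
      // hypothetical shapes, mirroring the example above
      int32_t B = 16, NC = 64, D = 1, H = 16, W = 16;
      int32_t T = 1, R = 3, S = 3, NF = 128;
      triton::dnn::conv conf(B, NC, D, H, W, T, R, S, NF,
                             /*strides*/  1, 1, 1,
                             /*pads*/     0, 0, 0,
                             /*upsample*/ 1, 1, 1,
                             "fp32", "fp32",            // a_ty, b_ty: new in this commit
                             triton::dnn::conv::FPROP,  // forward propagation
                             /*bias*/ false);
    }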
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TF_LIB})
     add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-    add_library(tf_blocksparse SHARED dot.cpp dense_conv)
+    add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp)
    target_link_libraries(tf_blocksparse tensorflow_framework triton)
    configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
                   ${CMAKE_CURRENT_BINARY_DIR}/run.py
@@ -20,21 +20,9 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;

-//torch::Tensor conv_common(
-// int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
-// int32_t T, int32_t R, int32_t S, int32_t NF,
-// int32_t stride_d, int32_t stride_h, int32_t stride_w,
-// int32_t pad_d, int32_t pad_h, int32_t pad_w,
-// triton::dnn::conv::type ty,
-// torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
-// bool autotune = false
-// ) {
-
-//}
-
-class DenseConvOp : public OpKernel {
- public:
-  explicit DenseConvOp(OpKernelConstruction* context) : OpKernel(context) {
+class Conv2dOp : public OpKernel {
+ public:
+  explicit Conv2dOp(OpKernelConstruction* context) : OpKernel(context) {
   }

   void Compute(OpKernelContext* context){
@@ -64,15 +52,19 @@ class DenseConvOp : public OpKernel {
     bool has_bias = false;

     // get conv configuration
-    triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF,
+    triton::dnn::conv configuration(B, C,
+                                    D, H, W,
+                                    T, R, S,
+                                    NF,
                                     stride_d, stride_h, stride_w,
                                     pad_d, pad_h, pad_w,
                                     1, 1, 1,
+                                    "fp16", "fp16",
                                     triton::dnn::conv::FPROP, has_bias);

     // Bind memory
-    triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<float>().data(), false);
-    triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<float>().data(), false);
+    triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
+    triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
 //    triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
 //    triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
     triton::driver::buffer* bias = nullptr;
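For reference, a sketch of the binding pattern this hunk switches to: the op now treats its inputs as fp16, so the raw pointer is taken from the tensor's Eigen::half view before being wrapped in a non-owning Triton buffer. This fragment assumes it runs inside Compute, with `context` the OpKernelContext and `ctx` the Triton driver context already used by the kernel; only the names tfa/ctx come from the hunk itself:

    const Tensor& tfa = context->input(0);
    // fp16 element view of the (device-resident) tensor data
    Eigen::half* pa = tfa.flat<Eigen::half>().data();
    // non-owning wrapper: Triton does not take ownership of TF's allocation
    triton::driver::cu_buffer a(ctx, (CUdeviceptr)pa, false);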
@@ -106,12 +98,16 @@ class DenseConvOp : public OpKernel {

+    triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
+    jit.add_module("conv", src.c_str(), best.params);
+//    jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
     triton::driver::kernel* kernel = jit.get_function("conv");
     triton::jit::launch_information info = jit.get_launch_info("conv");
     std::cout << benchmark(kernel, info) << std::endl;
   }
 };

-REGISTER_KERNEL_BUILDER(Name("DenseConv").Device(DEVICE_GPU), DenseConvOp);
-REGISTER_OP("DenseConv")
-  .Input("a: float32")
-  .Input("b: float32")
+REGISTER_KERNEL_BUILDER(Name("Conv2d").Device(DEVICE_GPU), Conv2dOp);
+REGISTER_OP("Conv2d")
+  .Input("a: float16")
+  .Input("b: float16")
   .Output("c: float32")
 ;
@@ -78,7 +78,7 @@ class DotOp : public OpKernel {
     jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
     triton::driver::kernel* kernel = jit.get_function("matmul");
     triton::jit::launch_information info = jit.get_launch_info("matmul");
-    std::cout << benchmark(kernel, info) << std::endl;;
+    std::cout << benchmark(kernel, info) << std::endl;
   }

 private:
@@ -32,7 +32,7 @@ def run_conv():
   R, S, NF = 3, 3, 32
   a = tf.placeholder(tf.float32, shape=[BS, C, H, W])
   b = tf.placeholder(tf.float32, shape=[C, R, S, NF])
-  c = module.dense_conv(a, b)
+  c = module.conv2d(a, b)
   # Reference
   ha = np.random.rand(BS, C, H, W)
   hb = np.random.rand(C, R, S, NF)
@@ -31,6 +31,7 @@ public:
        int stride_d, int stride_h, int stride_w,
        int pad_d, int pad_h, int pad_w,
        int upsample_d, int upsample_h, int upsample_w,
+       std::string a_ty = "fp32", std::string b_ty = "fp32",
        type ty = FPROP, bool bias = false);

   // accessors
@@ -126,7 +127,10 @@ private:
   bool is_a_deltas_cst;
   bool is_b_deltas_cst_;
   bool is_mask_cst_;
-  // type
+  // data type
+  std::string a_ty_;
+  std::string b_ty_;
+  // conv type
   type ty_;
   bool bias_;
   bool b_trans_;
@@ -347,7 +347,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
   }
   if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
     Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
-    return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
+    return (Instruction*)offset;
   }
   if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
     BasicBlock *current = builder.GetInsertBlock();
@@ -233,13 +233,13 @@ void tune::run(ir::module &mod) {
       continue;
     if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
+      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
       *params_.at(i).at("nts.d0") = *tmp;
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }
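If, as the (lo, hi) arguments suggest, each metaparameter is sampled over a power-of-two range, this hunk pins the per-thread tile sizes nts.d0/nts.d1 to 2 rather than letting the autotuner choose between 2 and 4. A sketch of the resulting search spaces (the enumeration logic here is an assumption for illustration, not Triton API):

    #include <cstdio>

    // hypothetical power-of-two enumeration of a metaparameter range [lo, hi]
    static void print_range(const char* name, int lo, int hi) {
      std::printf("%s:", name);
      for (int v = lo; v <= hi; v *= 2)
        std::printf(" %d", v);
      std::printf("\n");
    }

    int main() {
      print_range("nts.d0 before", 2, 4);  // prints: 2 4
      print_range("nts.d0 after",  2, 2);  // prints: 2
    }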
@@ -21,11 +21,13 @@ conv::conv(int B, int NC,
            int stride_d, int stride_h, int stride_w,
            int pad_d, int pad_h, int pad_w,
            int upsample_d, int upsample_h, int upsample_w,
+           std::string a_ty, std::string b_ty,
            type ty, bool bias)
   : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
     stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
     pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
     upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
+    a_ty_(a_ty), b_ty_(b_ty),
     ty_(ty), bias_(bias)
 {
   CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
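The last context line is the ceil-division output-shape formula; plugging in the 16x16 input and 3x3 filter from the updated example gives the expected 14x14 output. A quick check (values taken from the example hunk above):

    #include <cassert>

    int main() {
      // conv output size, ceil-division form as in conv::conv above
      int AH = 16, BH = 3, pad_h = 0, stride_h = 1, upsample_h = 1;
      int CH = (AH*upsample_h - BH + 1 + 2*pad_h + stride_h - 1)/stride_h;
      assert(CH == 14);  // 16x16 input, 3x3 filter, no padding, unit stride
    }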
@@ -281,8 +283,8 @@ void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
   d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
   d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
   d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
-  d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4);
-  ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4);
+  d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
+  ((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
 }

 void conv::set_arg(driver::kernel *kernel,
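The buffer doubles because each of the grid0 x grid1 tiles now needs two int32 slots instead of one: a spin-lock and an arrival counter (the generated kernel source below computes pcount = plock + grid0*grid1). A sketch of the implied layout:

    #include <cstddef>
    #include <cstdint>

    // implied layout of d_locks_: all spin-locks first, then all counters
    size_t lock_buffer_bytes(size_t grid0, size_t grid1) {
      size_t n_tiles  = grid0 * grid1;
      size_t locks    = n_tiles * sizeof(int32_t);  // plock  = locks + ridx + ridy*grid0
      size_t counters = n_tiles * sizeof(int32_t);  // pcount = plock + grid0*grid1
      return locks + counters;                      // == n_tiles * 4 * 2
    }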
@@ -336,8 +338,8 @@ void conv::set_arg(driver::kernel *kernel,
   kernel->setArg(39, (int32_t)0);
   kernel->setArg(40, (int32_t)0);
   kernel->setArg(41, d_locks_);
-  kernel->setArg(42, 0);
-  kernel->setArg(43, 0);
+  kernel->setArg(42, max_grid_0_);
+  kernel->setArg(43, max_grid_1_);
   size_t idx = 44;
   if(!is_a_deltas_cst)
     kernel->setArg(idx++, d_a_deltas_);
@@ -358,8 +360,6 @@ void conv::enqueue(driver::stream *stream, driver::kernel *kernel,
   grid[0] /= upsample_h_*upsample_w_;
   kernel->setArg(11, CH_/upsample_h_);
   kernel->setArg(12, CW_/upsample_w_);
-  kernel->setArg(42, (int32_t)grid[0]);
-  kernel->setArg(43, (int32_t)grid[1]);

   // initialize to zero if necessary
   bool init_zero = false;
@@ -526,7 +526,7 @@ void conv::src(std::ostream &os){
 R"(
 const tunable int32 TM = {16, 32, 64};
 const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {8};
+const tunable int32 TK = {16};
 const tunable int32 GZ = {1};
 )";
 if(is_a_deltas_cst)
@@ -537,8 +537,8 @@ if(is_mask_cst_)
   os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
 os << R"(

-void conv(read_only restrict fp32 *a,
-          read_only restrict fp32 *b,
+void conv(read_only restrict )" << a_ty_ << R"( *a,
+          read_only restrict )" << b_ty_ << R"( *b,
           fp32 *c,
           fp32 *bias,
           int32 M, int32 N, int32 K,
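This and the next three hunks follow one pattern: the kernel source is built as a stream of raw-string chunks, and the hard-coded fp32 element types of a and b are replaced by the a_ty_/b_ty_ strings. A self-contained sketch of the splicing, with std::ostringstream standing in for the real output stream:

    #include <iostream>
    #include <sstream>
    #include <string>

    int main() {
      std::string a_ty_ = "fp16", b_ty_ = "fp16";  // as configured by the TF op
      std::ostringstream os;
      os << R"(void conv(read_only restrict )" << a_ty_ << R"( *a,
              read_only restrict )" << b_ty_ << R"( *b,)";
      std::cout << os.str() << "\n";
      // prints: void conv(read_only restrict fp16 *a,
      //                   read_only restrict fp16 *b,
    }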
@@ -592,7 +592,7 @@ if(!is_mask_cst_)
     rar = )" + upar + R"( rar;
     ras = )" + upas + R"( ras;
     int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-    fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+    )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
 if(b_lut_){
   os << R"(
     int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
@@ -611,7 +611,7 @@ os << R"(
     int32 rb1[TK] = rkb)" + ldb0 + ";";
 }
 os << R"(
-    fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
+    )" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
     int32 offda[TK] = rka % ldlut;
     )" + a_delta_mem + R"( int32* pincd[TK] = delta + offda;
     )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
@@ -628,8 +628,8 @@ os << R"(
     int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
     int1 checkb0[TN] = rb0 < N;
     int1 checkb)" + BS + " = checkb0" + bcb0 + R"(;
-    fp32 a[TM, TK] = checka ? *pa : 0;
-    fp32 b)" + BS + R"( = checkb ? *pb : 0;
+    )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
+    )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
     int32 rkamin[TK] = rka - offk + TK;
     for(int32 k = K; k > 0; k = k - TK){
       C = dot(a, )" + useb + R"(, C);
@@ -672,8 +672,8 @@ if(b_lut_){
     int32 ridx = get_range_id(0);
     int32 ridy = get_range_id(1);
     int32 *plock = locks + ridx + ridy*grid0;
-    while(__atomic_cas(plock, 0, 1) == 1);
+    while(__atomic_cas(plock, 0, 1));
     int32 *pcount = plock + grid0*grid1;
     int32 count = *pcount;
     int32 countp1 = select(count == GZ - 1, 0, count + 1);
     if(count == 0) {)";
@@ -691,7 +691,7 @@ if(b_lut_){
     @checkc *pc = C + *pc;
     *pcount = countp1;
   }
-  __atomic_cas(plock, 1, 0);
+  *plock = 0;
 })";
 }
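Taken together, the last two hunks simplify the split-K reduction's lock protocol in the generated kernel: acquire by spinning on the return value of __atomic_cas(plock, 0, 1), count arrivals in pcount, and release with a plain store. A CPU-side analogue using std::atomic, with the tile arithmetic elided and names mirroring the kernel:

    #include <atomic>

    // CPU analogue of the per-tile reduction protocol in the generated kernel
    void reduce_tile(std::atomic<int>& lock, int& count, int GZ /*split-K depth*/) {
      int expected = 0;
      while (!lock.compare_exchange_strong(expected, 1))  // ~ while(__atomic_cas(plock, 0, 1));
        expected = 0;
      int countp1 = (count == GZ - 1) ? 0 : count + 1;    // ~ select(count == GZ - 1, 0, count + 1)
      if (count == 0) { /* first tile writes:  *pc = C       */ }
      else            { /* later tiles reduce: *pc = C + *pc */ }
      count = countp1;
      lock.store(0);                                      // ~ *plock = 0;
    }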