[examples] added conv2d op in tensorflow
@@ -13,13 +13,13 @@ int main() {
 triton::jit jit(context);
 triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
 // initialization
-int32_t B = 64, NF = 64;
-int32_t D = 1, H = 8, W = 8;
-int32_t NC = 3, T = 1, R = 3, S = 3;
+int32_t B = 16, NF = 128;
+int32_t D = 1, H = 16, W = 16;
+int32_t NC = 64, T = 1, R = 3, S = 3;
 int32_t pad_d = 0, pad_h = 0, pad_w = 0;
 int32_t stride_d = 1, stride_h = 1, stride_w = 1;
 int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, triton::dnn::conv::FPROP, 0);
+triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
 // triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
 // convolution configuration
 std::vector<float> hc(configuration.c_size());
@@ -5,7 +5,7 @@ if(${TensorFlow_FOUND})
 include_directories("${CUDA_HOME}/include")
 link_directories(${TF_LIB})
 add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
-add_library(tf_blocksparse SHARED dot.cpp dense_conv)
+add_library(tf_blocksparse SHARED dot.cpp conv2d.cpp)
 target_link_libraries(tf_blocksparse tensorflow_framework triton)
 configure_file(${CMAKE_CURRENT_SOURCE_DIR}/run.py
 ${CMAKE_CURRENT_BINARY_DIR}/run.py
@@ -20,21 +20,9 @@
 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
 
-//torch::Tensor conv_common(
-// int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
-// int32_t T, int32_t R, int32_t S, int32_t NF,
-// int32_t stride_d, int32_t stride_h, int32_t stride_w,
-// int32_t pad_d, int32_t pad_h, int32_t pad_w,
-// triton::dnn::conv::type ty,
-// torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
-// bool autotune = false
-// ) {
-
-//}
-
-class DenseConvOp : public OpKernel {
-public:
-explicit DenseConvOp(OpKernelConstruction* context) : OpKernel(context) {
+class Conv2dOp : public OpKernel {
+public:
+explicit Conv2dOp(OpKernelConstruction* context) : OpKernel(context) {
 }
 
 void Compute(OpKernelContext* context){
@@ -64,15 +52,19 @@ class DenseConvOp : public OpKernel {
 bool has_bias = false;
 
 // get conv configuration
-triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF,
+triton::dnn::conv configuration(B, C,
+D, H, W,
+T, R, S,
+NF,
 stride_d, stride_h, stride_w,
 pad_d, pad_h, pad_w,
 1, 1, 1,
+"fp16", "fp16",
 triton::dnn::conv::FPROP, has_bias);
 
 // Bind memory
-triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<float>().data(), false);
-triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<float>().data(), false);
+triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
+triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
 // triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
 // triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
 triton::driver::buffer* bias = nullptr;
@@ -106,12 +98,16 @@ class DenseConvOp : public OpKernel {
 
 triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
 jit.add_module("conv", src.c_str(), best.params);
+// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
+triton::driver::kernel* kernel = jit.get_function("conv");
+triton::jit::launch_information info = jit.get_launch_info("conv");
+std::cout << benchmark(kernel, info) << std::endl;
 }
 };
 
-REGISTER_KERNEL_BUILDER(Name("DenseConv").Device(DEVICE_GPU), DenseConvOp);
-REGISTER_OP("DenseConv")
-.Input("a: float32")
-.Input("b: float32")
+REGISTER_KERNEL_BUILDER(Name("Conv2d").Device(DEVICE_GPU), Conv2dOp);
+REGISTER_OP("Conv2d")
+.Input("a: float16")
+.Input("b: float16")
 .Output("c: float32")
 ;
@@ -78,7 +78,7 @@ class DotOp : public OpKernel {
 jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
 triton::driver::kernel* kernel = jit.get_function("matmul");
 triton::jit::launch_information info = jit.get_launch_info("matmul");
-std::cout << benchmark(kernel, info) << std::endl;;
+std::cout << benchmark(kernel, info) << std::endl;
 }
 
 private:
@@ -32,7 +32,7 @@ def run_conv():
 R, S, NF = 3, 3, 32
 a = tf.placeholder(tf.float32, shape=[BS, C, H, W])
 b = tf.placeholder(tf.float32, shape=[C, R, S, NF])
-c = module.dense_conv(a, b)
+c = module.conv2d(a, b)
 # Reference
 ha = np.random.rand(BS, C, H, W)
 hb = np.random.rand(C, R, S, NF)
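For context, a minimal end-to-end invocation of the new op from Python could look like the sketch below. It is not part of the commit: the shared-library filename, the BS/C/H/W values, and the session boilerplate are assumptions, and the float16 placeholders follow the REGISTER_OP("Conv2d") signature above (the run.py hunk itself still builds float32 placeholders).

    # usage sketch (assumed: library name/path, tensor sizes)
    import numpy as np
    import tensorflow as tf

    module = tf.load_op_library('./libtf_blocksparse.so')  # assumed build artifact name

    BS, C, H, W = 16, 64, 16, 16   # assumed values; run.py defines its own
    R, S, NF = 3, 3, 32
    # the op is registered with float16 inputs and a float32 output
    a = tf.placeholder(tf.float16, shape=[BS, C, H, W])
    b = tf.placeholder(tf.float16, shape=[C, R, S, NF])
    c = module.conv2d(a, b)

    ha = np.random.rand(BS, C, H, W).astype(np.float16)
    hb = np.random.rand(C, R, S, NF).astype(np.float16)
    with tf.Session() as sess:
        print(sess.run(c, feed_dict={a: ha, b: hb}).shape)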
@@ -31,6 +31,7 @@ public:
 int stride_d, int stride_h, int stride_w,
 int pad_d, int pad_h, int pad_w,
 int upsample_d, int upsample_h, int upsample_w,
+std::string a_ty = "fp32", std::string b_ty = "fp32",
 type ty = FPROP, bool bias = false);
 
 // accessors
@@ -126,7 +127,10 @@ private:
 bool is_a_deltas_cst;
 bool is_b_deltas_cst_;
 bool is_mask_cst_;
-// type
+// data type
+std::string a_ty_;
+std::string b_ty_;
+// conv type
 type ty_;
 bool bias_;
 bool b_trans_;
@@ -347,7 +347,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
 }
 if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
 Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
-return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0));
+return (Instruction*)offset;
 }
 if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
 BasicBlock *current = builder.GetInsertBlock();
@@ -233,13 +233,13 @@ void tune::run(ir::module &mod) {
 continue;
 if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
 ir::type *ty = mod.get_builder().get_int32_ty();
-std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
+std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
 *params_.at(i).at("nts.d0") = *tmp;
 }
 if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
 ir::type *ty = mod.get_builder().get_int32_ty();
-std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
-std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
+std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 2));
+std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 2));
 *params_.at(i).at("nts.d0") = *tmp1;
 *params_.at(i).at("nts.d1") = *tmp2;
 }
@@ -21,11 +21,13 @@ conv::conv(int B, int NC,
 int stride_d, int stride_h, int stride_w,
 int pad_d, int pad_h, int pad_w,
 int upsample_d, int upsample_h, int upsample_w,
+std::string a_ty, std::string b_ty,
 type ty, bool bias)
 : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
 stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
 pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
 upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
+a_ty_(a_ty), b_ty_(b_ty),
 ty_(ty), bias_(bias)
 {
 CD_ = (AD_*upsample_d_ - BD_ + 1 + 2*pad_d_ + stride_d_ - 1)/stride_d_;
@@ -281,8 +283,8 @@ void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
 d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
 d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
 d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
-d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4);
-((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4);
+d_locks_ = triton::driver::buffer::create(stream->context(), max_grid_0_*max_grid_1_*4*2);
+((triton::driver::cu_buffer*)d_locks_)->set_zero(stream, max_grid_0_*max_grid_1_*4*2);
 }
 
 void conv::set_arg(driver::kernel *kernel,
@@ -336,8 +338,8 @@ void conv::set_arg(driver::kernel *kernel,
 kernel->setArg(39, (int32_t)0);
 kernel->setArg(40, (int32_t)0);
 kernel->setArg(41, d_locks_);
-kernel->setArg(42, 0);
-kernel->setArg(43, 0);
+kernel->setArg(42, max_grid_0_);
+kernel->setArg(43, max_grid_1_);
 size_t idx = 44;
 if(!is_a_deltas_cst)
 kernel->setArg(idx++, d_a_deltas_);
@@ -358,8 +360,6 @@ void conv::enqueue(driver::stream *stream, driver::kernel *kernel,
 grid[0] /= upsample_h_*upsample_w_;
 kernel->setArg(11, CH_/upsample_h_);
 kernel->setArg(12, CW_/upsample_w_);
-kernel->setArg(42, (int32_t)grid[0]);
-kernel->setArg(43, (int32_t)grid[1]);
 
 // initialize to zero if necessary
 bool init_zero = false;
@@ -526,7 +526,7 @@ void conv::src(std::ostream &os){
 R"(
 const tunable int32 TM = {16, 32, 64};
 const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {8};
+const tunable int32 TK = {16};
 const tunable int32 GZ = {1};
 )";
 if(is_a_deltas_cst)
@@ -537,8 +537,8 @@ if(is_mask_cst_)
 os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
 os << R"(
 
-void conv(read_only restrict fp32 *a,
-read_only restrict fp32 *b,
+void conv(read_only restrict )" << a_ty_ << R"( *a,
+read_only restrict )" << b_ty_ << R"( *b,
 fp32 *c,
 fp32 *bias,
 int32 M, int32 N, int32 K,
@@ -592,7 +592,7 @@ if(!is_mask_cst_)
 rar = )" + upar + R"( rar;
 ras = )" + upas + R"( ras;
 int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+)" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
 if(b_lut_){
 os << R"(
 int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
@@ -611,7 +611,7 @@ os << R"(
 int32 rb1[TK] = rkb)" + ldb0 + ";";
 }
 os << R"(
-fp32* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
+)" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
 int32 offda[TK] = rka % ldlut;
 )" + a_delta_mem + R"( int32* pincd[TK] = delta + offda;
 )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
@@ -628,8 +628,8 @@ os << R"(
 int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
 int1 checkb0[TN] = rb0 < N;
 int1 checkb)" + BS + " = checkb0" + bcb0 + R"(;
-fp32 a[TM, TK] = checka ? *pa : 0;
-fp32 b)" + BS + R"( = checkb ? *pb : 0;
+)" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
+)" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
 int32 rkamin[TK] = rka - offk + TK;
 for(int32 k = K; k > 0; k = k - TK){
 C = dot(a, )" + useb + R"(, C);
@@ -672,8 +672,8 @@ if(b_lut_){
 int32 ridx = get_range_id(0);
 int32 ridy = get_range_id(1);
 int32 *plock = locks + ridx + ridy*grid0;
+while(__atomic_cas(plock, 0, 1) == 1);
 int32 *pcount = plock + grid0*grid1;
-while(__atomic_cas(plock, 0, 1));
 int32 count = *pcount;
 int32 countp1 = select(count == GZ - 1, 0, count + 1);
 if(count == 0) {)";
@@ -691,7 +691,7 @@ if(b_lut_){
 @checkc *pc = C + *pc;
 *pcount = countp1;
 }
-__atomic_cas(plock, 1, 0);
+*plock = 0;
 })";
 }
 