blabla
This commit is contained in:
@@ -65,55 +65,33 @@ def ShiftConv2d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilati
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NetReference(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NetReference, self).__init__()
|
super(Net, self).__init__()
|
||||||
#self.conv1 = ShiftConv2d(1, 32, 3, 2)
|
self.conv1 = ShiftConv2d(1, 32, 3, 1)
|
||||||
self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
|
self.conv2 = ShiftConv2d(32, 128, 3, 1)
|
||||||
self.bn1 = nn.BatchNorm2d(32)
|
self.conv3 = ShiftConv2d(128, 128, 3, 2)
|
||||||
self.conv2 = triton.ShiftConv2d(32, 32, 3, 2)
|
self.bn1 = nn.BatchNorm2d(128)
|
||||||
#self.conv2 = ShiftConv2d(32, 32, 3, 2)
|
self.conv4 = ShiftConv2d(128, 256, 3, 2)
|
||||||
self.bn2 = nn.BatchNorm2d(32)
|
self.bn2 = nn.BatchNorm2d(256)
|
||||||
self.fc1 = nn.Linear(32*7*7, 500)
|
self.fc1 = nn.Linear(256*7*7, 500)
|
||||||
self.fc2 = nn.Linear(500, 10)
|
self.fc2 = nn.Linear(500, 10)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.conv3(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
x = F.relu(x)
|
x = F.relu(x)
|
||||||
x = self.conv2(x)
|
x = self.conv4(x)
|
||||||
x = self.bn2(x)
|
x = self.bn2(x)
|
||||||
x = F.relu(x)
|
x = F.relu(x)
|
||||||
x = x.view(-1, 32*7*7)
|
x = x.view(-1, 256*7*7)
|
||||||
x = F.relu(self.fc1(x))
|
x = F.relu(self.fc1(x))
|
||||||
x = self.fc2(x)
|
x = self.fc2(x)
|
||||||
return F.log_softmax(x, dim=1)
|
return F.log_softmax(x, dim=1)
|
||||||
|
|
||||||
class NetTriton(nn.Module):
|
Net = Net()
|
||||||
def __init__(self):
|
|
||||||
super(NetTriton, self).__init__()
|
|
||||||
self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
|
|
||||||
self.bn1 = triton.BatchNorm2d(32)
|
|
||||||
self.conv2 = triton.ShiftConv2d(32, 64, 3, 2)
|
|
||||||
self.bn2 = triton.BatchNorm2d(64)
|
|
||||||
self.fc1 = nn.Linear(64*7*7, 500)
|
|
||||||
self.fc2 = nn.Linear(500, 10)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.permute(1, 2, 3, 0).contiguous()
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
x = F.relu(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
x = self.bn2(x)
|
|
||||||
x = F.relu(x)
|
|
||||||
x = x.permute(3, 0, 1, 2).contiguous()
|
|
||||||
x = x.view(-1, 64*7*7)
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
x = self.fc2(x)
|
|
||||||
return F.log_softmax(x, dim=1)
|
|
||||||
|
|
||||||
Net = NetReference()
|
|
||||||
|
|
||||||
def train(args, model, device, train_loader, optimizer, epoch):
|
def train(args, model, device, train_loader, optimizer, epoch):
|
||||||
model.train()
|
model.train()
|
||||||
|
@@ -58,7 +58,7 @@ def blocksparse_matmul_grad(op, dy):
|
|||||||
return (dx, dw)
|
return (dx, dw)
|
||||||
|
|
||||||
def run_shift():
|
def run_shift():
|
||||||
B, C, H, W = 16, 16, 4, 4
|
B, C, H, W = 16, 16, 2, 2
|
||||||
R, S, F = 3, 3, 32
|
R, S, F = 3, 3, 32
|
||||||
stride_h, stride_w = 2, 2
|
stride_h, stride_w = 2, 2
|
||||||
np.random.seed(2)
|
np.random.seed(2)
|
||||||
|
@@ -62,7 +62,7 @@ public:
|
|||||||
shift(int B, int NC,
|
shift(int B, int NC,
|
||||||
int D, int H, int W,
|
int D, int H, int W,
|
||||||
int T, int R, int S, int NF,
|
int T, int R, int S, int NF,
|
||||||
int stride_h, int stride_w,
|
int stride_h, int stride_w,
|
||||||
const int32_t* shift_h, const int32_t* shift_w,
|
const int32_t* shift_h, const int32_t* shift_w,
|
||||||
std::string a_ty = "fp32", std::string b_ty = "fp32",
|
std::string a_ty = "fp32", std::string b_ty = "fp32",
|
||||||
type ty = FPROP, bool bias = false, layout_t layout = CHWN);
|
type ty = FPROP, bool bias = false, layout_t layout = CHWN);
|
||||||
@@ -145,6 +145,8 @@ private:
|
|||||||
// shift values
|
// shift values
|
||||||
const int32_t* shift_h_;
|
const int32_t* shift_h_;
|
||||||
const int32_t* shift_w_;
|
const int32_t* shift_w_;
|
||||||
|
bool shift_edge_h_;
|
||||||
|
bool shift_edge_w_;
|
||||||
// look-up tables
|
// look-up tables
|
||||||
std::vector<int32_t> h_delta_a;
|
std::vector<int32_t> h_delta_a;
|
||||||
std::vector<int32_t> h_delta_b;
|
std::vector<int32_t> h_delta_b;
|
||||||
@@ -154,7 +156,7 @@ private:
|
|||||||
std::string a_ty_;
|
std::string a_ty_;
|
||||||
std::string b_ty_;
|
std::string b_ty_;
|
||||||
// convolution type
|
// convolution type
|
||||||
type ty_;
|
type op_;
|
||||||
bool bias_;
|
bool bias_;
|
||||||
// transpose
|
// transpose
|
||||||
bool AT_;
|
bool AT_;
|
||||||
|
@@ -23,8 +23,9 @@ shift::shift(int B, int C,
|
|||||||
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
||||||
shift_h_(shift_h), shift_w_(shift_w),
|
shift_h_(shift_h), shift_w_(shift_w),
|
||||||
a_ty_(a_ty), b_ty_(b_ty),
|
a_ty_(a_ty), b_ty_(b_ty),
|
||||||
ty_(ty), bias_(bias),
|
op_(ty), bias_(bias),
|
||||||
layout_(layout){
|
layout_(layout){
|
||||||
|
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
|
||||||
// max number of channels
|
// max number of channels
|
||||||
TK_ = 16;
|
TK_ = 16;
|
||||||
MAX_C_ = 8192 + TK_;
|
MAX_C_ = 8192 + TK_;
|
||||||
@@ -51,6 +52,9 @@ shift::shift(int B, int C,
|
|||||||
default:
|
default:
|
||||||
throw std::runtime_error("unsupported input layout");
|
throw std::runtime_error("unsupported input layout");
|
||||||
}
|
}
|
||||||
|
// Shift edge
|
||||||
|
shift_edge_h_ = (AH_ == stride_h_);
|
||||||
|
shift_edge_w_ = (AW_ == stride_w_);
|
||||||
// B memory strides: [C, F]
|
// B memory strides: [C, F]
|
||||||
ldb_n_ = 1;
|
ldb_n_ = 1;
|
||||||
ldb_h_ = 1;
|
ldb_h_ = 1;
|
||||||
@@ -88,7 +92,7 @@ shift::shift(int B, int C,
|
|||||||
if(layout_ == NCHW)
|
if(layout_ == NCHW)
|
||||||
shapes_c_ = {B, F, CH_, CW_};
|
shapes_c_ = {B, F, CH_, CW_};
|
||||||
// Weight gradient
|
// Weight gradient
|
||||||
if(ty_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
// b <-> c
|
// b <-> c
|
||||||
// b <-> a
|
// b <-> a
|
||||||
std::swap(ldb_n_, ldc_n_);
|
std::swap(ldb_n_, ldc_n_);
|
||||||
@@ -106,7 +110,7 @@ shift::shift(int B, int C,
|
|||||||
shapes_c_ = {C, F};
|
shapes_c_ = {C, F};
|
||||||
}
|
}
|
||||||
// Input gradient
|
// Input gradient
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
// a <-> c
|
// a <-> c
|
||||||
std::swap(lda_n_, ldc_n_);
|
std::swap(lda_n_, ldc_n_);
|
||||||
std::swap(lda_w_, ldc_w_);
|
std::swap(lda_w_, ldc_w_);
|
||||||
@@ -128,10 +132,12 @@ base* shift::clone() const {
|
|||||||
|
|
||||||
void shift::build_delta_a() {
|
void shift::build_delta_a() {
|
||||||
h_delta_a.resize(MAX_C_);
|
h_delta_a.resize(MAX_C_);
|
||||||
if(ty_ == FPROP){
|
auto shift_h = [&](int c) { return shift_edge_h_ ? std::max(0, shift_h_[c]) : shift_h_[c]; };
|
||||||
|
auto shift_w = [&](int c) { return shift_edge_w_ ? std::max(0, shift_w_[c]) : shift_w_[c]; };
|
||||||
|
if(op_ == FPROP){
|
||||||
// compute offset
|
// compute offset
|
||||||
auto offset = [&](unsigned c) {
|
auto offset = [&](unsigned c) {
|
||||||
return c*lda_c_ + shift_h_[c]*lda_h_ + shift_w_[c]*lda_w_;
|
return c*lda_c_ + shift_h(c)*lda_h_ + shift_w(c)*lda_w_;
|
||||||
};
|
};
|
||||||
// populate look-up table
|
// populate look-up table
|
||||||
for(unsigned c = 0; c < TK_; c++)
|
for(unsigned c = 0; c < TK_; c++)
|
||||||
@@ -139,14 +145,14 @@ void shift::build_delta_a() {
|
|||||||
for(unsigned c = 0; c < C_; c++)
|
for(unsigned c = 0; c < C_; c++)
|
||||||
h_delta_a[TK_ + c] = offset(c + TK_) - offset(c);
|
h_delta_a[TK_ + c] = offset(c + TK_) - offset(c);
|
||||||
}
|
}
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
for(unsigned c = 0; c < C_; c++){
|
for(unsigned c = 0; c < C_; c++){
|
||||||
h_delta_a[c] = shift_h_[c]*ldc_h_ + shift_w_[c]*ldc_w_;
|
h_delta_a[c] = shift_h(c)*ldc_h_ + shift_w(c)*ldc_w_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
for(unsigned c = 0; c < C_; c++)
|
for(unsigned c = 0; c < C_; c++)
|
||||||
h_delta_a[c] = shift_h_[c]*ldb_h_ + shift_w_[c]*ldb_w_;
|
h_delta_a[c] = shift_h(c)*ldb_h_ + shift_w(c)*ldb_w_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,10 +173,22 @@ bool shift::operator <(const base& other) const{
|
|||||||
auto *y = dynamic_cast<const shift*>(&other);
|
auto *y = dynamic_cast<const shift*>(&other);
|
||||||
if(!y)
|
if(!y)
|
||||||
return true;
|
return true;
|
||||||
return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
|
return std::tie(B_, C_, F_,
|
||||||
shift_h_, shift_w_, ty_, bias_)
|
AD_, AH_, AW_,
|
||||||
< std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
|
BD_, BH_, BW_,
|
||||||
y->shift_h_, y->shift_w_, y->ty_, y->bias_);
|
CD_, CH_, CW_,
|
||||||
|
shift_h_, shift_w_,
|
||||||
|
stride_h_, stride_w_,
|
||||||
|
layout_, op_,
|
||||||
|
bias_)
|
||||||
|
< std::tie(y->B_, y->C_, y->F_,
|
||||||
|
y->AD_, y->AH_, y->AW_,
|
||||||
|
y->BD_, y->BH_, y->BW_,
|
||||||
|
y->CD_, y->CH_, y->CW_,
|
||||||
|
y->shift_h_, y->shift_w_,
|
||||||
|
y->stride_h_, y->stride_w_,
|
||||||
|
y->layout_, y->op_,
|
||||||
|
y->bias_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
|
void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
|
||||||
@@ -212,7 +230,7 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
|||||||
kernel->setArg(26, CW_);
|
kernel->setArg(26, CW_);
|
||||||
unsigned TM = ranges[0], TN = ranges[1];
|
unsigned TM = ranges[0], TN = ranges[1];
|
||||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||||
if(ty_ == BPROP)
|
if(op_ == BPROP)
|
||||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4);
|
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4);
|
||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
}
|
}
|
||||||
@@ -263,7 +281,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *A,
|
|||||||
int32 pad_w = BW / 2;)";
|
int32 pad_w = BW / 2;)";
|
||||||
|
|
||||||
/* A offsets */
|
/* A offsets */
|
||||||
if(ty_ == FPROP){
|
if(op_ == FPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 rawh[TM] = rxa / NB;
|
int32 rawh[TM] = rxa / NB;
|
||||||
int32 rab[TM] = rxa % NB;
|
int32 rab[TM] = rxa % NB;
|
||||||
@@ -274,13 +292,20 @@ if(ty_ == FPROP){
|
|||||||
__constant__ int32* pd[TK] = delta_a + rka;
|
__constant__ int32* pd[TK] = delta_a + rka;
|
||||||
multiple_of(4) int32 d[TK] = *pd;
|
multiple_of(4) int32 d[TK] = *pd;
|
||||||
int32 offa_interior[TM, TK] = d[newaxis, :];
|
int32 offa_interior[TM, TK] = d[newaxis, :];
|
||||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
|
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;\n)";
|
||||||
int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
if(shift_edge_h_)
|
||||||
int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
|
os << " int1 interiorh[TM] = 1;";
|
||||||
|
else
|
||||||
|
os << " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));";
|
||||||
|
if(shift_edge_w_)
|
||||||
|
os << " int1 interiorw[TM] = 1;";
|
||||||
|
else
|
||||||
|
os << " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
|
||||||
|
os << R"(
|
||||||
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||||
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
||||||
}
|
}
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 rawh[TM] = rxa / NB;
|
int32 rawh[TM] = rxa / NB;
|
||||||
int32 rab[TM] = rxa % NB;
|
int32 rab[TM] = rxa % NB;
|
||||||
@@ -290,12 +315,12 @@ if(ty_ == BPROP){
|
|||||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD && layout_ == CHWN){
|
if(op_ == WGRAD && layout_ == CHWN){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD && layout_ == NCHW){
|
if(op_ == WGRAD && layout_ == NCHW){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||||
int32 rawh[TK] = rka / NB;
|
int32 rawh[TK] = rka / NB;
|
||||||
@@ -307,17 +332,17 @@ if(ty_ == WGRAD && layout_ == NCHW){
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* B offsets */
|
/* B offsets */
|
||||||
if(ty_ == FPROP){
|
if(op_ == FPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 offb0[TN, TK] = ryb[:, newaxis];
|
int32 offb0[TN, TK] = ryb[:, newaxis];
|
||||||
int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
|
int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
|
||||||
}
|
}
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||||
int32 offb1[TK, TN] = rkb[:, newaxis];)";
|
int32 offb1[TK, TN] = rkb[:, newaxis];)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
os << R"(
|
os << R"(
|
||||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||||
int32 d[TN] = *pd;
|
int32 d[TN] = *pd;
|
||||||
@@ -326,9 +351,16 @@ if(ty_ == WGRAD){
|
|||||||
int32 rbb[TK] = rkb % NB;
|
int32 rbb[TK] = rkb % NB;
|
||||||
int32 rbw[TK] = (rbwh % CW)*stride_w;
|
int32 rbw[TK] = (rbwh % CW)*stride_w;
|
||||||
int32 rbh[TK] = (rbwh / CW)*stride_h;
|
int32 rbh[TK] = (rbwh / CW)*stride_h;
|
||||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)";
|
||||||
int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
if(shift_edge_h_)
|
||||||
int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
os << " int1 interiorh[TK] = 1;\n";
|
||||||
|
else
|
||||||
|
os << " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||||
|
if(shift_edge_w_)
|
||||||
|
os << " int1 interiorw[TK] = 1;\n";
|
||||||
|
else
|
||||||
|
os << " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));\n";
|
||||||
|
os << R"(
|
||||||
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||||
int32 incb[TK, TN] = interior ? shift : 0;
|
int32 incb[TK, TN] = interior ? shift : 0;
|
||||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||||
@@ -349,7 +381,7 @@ if(ty_ == WGRAD){
|
|||||||
int1 checkb[)" << BS << R"(] = k > TK;)";
|
int1 checkb[)" << BS << R"(] = k > TK;)";
|
||||||
|
|
||||||
/* Increment A pointers */
|
/* Increment A pointers */
|
||||||
if(ty_ == FPROP){
|
if(op_ == FPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
pd = pd + TK;
|
pd = pd + TK;
|
||||||
d = *pd;
|
d = *pd;
|
||||||
@@ -358,15 +390,15 @@ if(ty_ == FPROP){
|
|||||||
int32 offa[TM, TK] = interior ? offa_interior : offa_exterior;
|
int32 offa[TM, TK] = interior ? offa_interior : offa_exterior;
|
||||||
pa = pa + offa;)";
|
pa = pa + offa;)";
|
||||||
}
|
}
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
pa = pa + TK * lda_c;)";
|
pa = pa + TK * lda_c;)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD && layout_ == CHWN){
|
if(op_ == WGRAD && layout_ == CHWN){
|
||||||
os << R"(
|
os << R"(
|
||||||
pa = pa + TK;)";
|
pa = pa + TK;)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD && layout_ == NCHW){
|
if(op_ == WGRAD && layout_ == NCHW){
|
||||||
os << R"(
|
os << R"(
|
||||||
rka = rka + TK;
|
rka = rka + TK;
|
||||||
rawh = rka / NB;
|
rawh = rka / NB;
|
||||||
@@ -380,25 +412,32 @@ if(ty_ == WGRAD && layout_ == NCHW){
|
|||||||
@checka a = *pa;)";
|
@checka a = *pa;)";
|
||||||
|
|
||||||
/* Increment B pointers */
|
/* Increment B pointers */
|
||||||
if(ty_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
os << R"(
|
os << R"(
|
||||||
rkb = rkb + TK;
|
rkb = rkb + TK;
|
||||||
rbwh = rkb / NB;
|
rbwh = rkb / NB;
|
||||||
rbb = rkb % NB;
|
rbb = rkb % NB;
|
||||||
rbw = (rbwh % CW)*stride_w;
|
rbw = (rbwh % CW)*stride_w;
|
||||||
rbh = (rbwh / CW)*stride_h;
|
rbh = (rbwh / CW)*stride_h;
|
||||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)";
|
||||||
interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
if(shift_edge_h_)
|
||||||
interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
os << " interiorh = 1;\n";
|
||||||
|
else
|
||||||
|
os << " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||||
|
if(shift_edge_w_)
|
||||||
|
os << " interiorw = 1;\n";
|
||||||
|
else
|
||||||
|
os << " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));\n";
|
||||||
|
os << R"(
|
||||||
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
|
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||||
incb = interior ? shift : 0;
|
incb = interior ? shift : 0;
|
||||||
pb = B + offb0 + offkb[:, newaxis] + incb;)";
|
pb = B + offb0 + offkb[:, newaxis] + incb;)";
|
||||||
}
|
}
|
||||||
if(ty_ == FPROP){
|
if(op_ == FPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
pb = pb + TK * ldb_c;)";
|
pb = pb + TK * ldb_c;)";
|
||||||
}
|
}
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
pb = pb + TK;)";
|
pb = pb + TK;)";
|
||||||
}
|
}
|
||||||
@@ -409,7 +448,7 @@ if(ty_ == BPROP){
|
|||||||
int32 ryc[TN] = get_global_range[TN](1);)";
|
int32 ryc[TN] = get_global_range[TN](1);)";
|
||||||
|
|
||||||
/* C offsets */
|
/* C offsets */
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 rcwh[TM] = rxc / NB;
|
int32 rcwh[TM] = rxc / NB;
|
||||||
int32 rcb[TM] = rxc % NB;
|
int32 rcb[TM] = rxc % NB;
|
||||||
@@ -418,7 +457,7 @@ if(ty_ == BPROP){
|
|||||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
|
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
if(ty_ == FPROP){
|
if(op_ == FPROP){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 rcwh[TM] = rxc / NB;
|
int32 rcwh[TM] = rxc / NB;
|
||||||
int32 rcb[TM] = rxc % NB;
|
int32 rcb[TM] = rxc % NB;
|
||||||
@@ -427,7 +466,7 @@ if(ty_ == FPROP){
|
|||||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
|
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
if(ty_ == WGRAD){
|
if(op_ == WGRAD){
|
||||||
os << R"(
|
os << R"(
|
||||||
int32 offxc[TM] = rxc;
|
int32 offxc[TM] = rxc;
|
||||||
)";
|
)";
|
||||||
@@ -437,10 +476,17 @@ if(ty_ == WGRAD){
|
|||||||
int1 checkc0[TM] = rxc < M;
|
int1 checkc0[TM] = rxc < M;
|
||||||
int1 checkc1[TN] = ryc < N;
|
int1 checkc1[TN] = ryc < N;
|
||||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||||
if(ty_ == BPROP){
|
if(op_ == BPROP){
|
||||||
|
os << "\n";
|
||||||
|
if(shift_edge_h_)
|
||||||
|
os << " int1 interiorh[TM] = 1;\n";
|
||||||
|
else
|
||||||
|
os << " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
|
||||||
|
if(shift_edge_w_)
|
||||||
|
os << " int1 interiorw[TM] = 1;\n";
|
||||||
|
else
|
||||||
|
os << " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));\n";
|
||||||
os << R"(
|
os << R"(
|
||||||
int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
|
|
||||||
int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
|
|
||||||
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||||
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||||
|
Reference in New Issue
Block a user