This commit is contained in:
Philippe Tillet
2019-07-12 17:42:29 -07:00
parent f36a646ffc
commit c1c7062914
4 changed files with 108 additions and 82 deletions

View File

@@ -65,55 +65,33 @@ def ShiftConv2d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilati
) )
class NetReference(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(NetReference, self).__init__() super(Net, self).__init__()
#self.conv1 = ShiftConv2d(1, 32, 3, 2) self.conv1 = ShiftConv2d(1, 32, 3, 1)
self.conv1 = triton.ShiftConv2d(1, 32, 3, 2) self.conv2 = ShiftConv2d(32, 128, 3, 1)
self.bn1 = nn.BatchNorm2d(32) self.conv3 = ShiftConv2d(128, 128, 3, 2)
self.conv2 = triton.ShiftConv2d(32, 32, 3, 2) self.bn1 = nn.BatchNorm2d(128)
#self.conv2 = ShiftConv2d(32, 32, 3, 2) self.conv4 = ShiftConv2d(128, 256, 3, 2)
self.bn2 = nn.BatchNorm2d(32) self.bn2 = nn.BatchNorm2d(256)
self.fc1 = nn.Linear(32*7*7, 500) self.fc1 = nn.Linear(256*7*7, 500)
self.fc2 = nn.Linear(500, 10) self.fc2 = nn.Linear(500, 10)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.bn1(x) x = self.bn1(x)
x = F.relu(x) x = F.relu(x)
x = self.conv2(x) x = self.conv4(x)
x = self.bn2(x) x = self.bn2(x)
x = F.relu(x) x = F.relu(x)
x = x.view(-1, 32*7*7) x = x.view(-1, 256*7*7)
x = F.relu(self.fc1(x)) x = F.relu(self.fc1(x))
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
class NetTriton(nn.Module): Net = Net()
def __init__(self):
super(NetTriton, self).__init__()
self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
self.bn1 = triton.BatchNorm2d(32)
self.conv2 = triton.ShiftConv2d(32, 64, 3, 2)
self.bn2 = triton.BatchNorm2d(64)
self.fc1 = nn.Linear(64*7*7, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = x.permute(1, 2, 3, 0).contiguous()
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = x.permute(3, 0, 1, 2).contiguous()
x = x.view(-1, 64*7*7)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
Net = NetReference()
def train(args, model, device, train_loader, optimizer, epoch): def train(args, model, device, train_loader, optimizer, epoch):
model.train() model.train()

View File

@@ -58,7 +58,7 @@ def blocksparse_matmul_grad(op, dy):
return (dx, dw) return (dx, dw)
def run_shift(): def run_shift():
B, C, H, W = 16, 16, 4, 4 B, C, H, W = 16, 16, 2, 2
R, S, F = 3, 3, 32 R, S, F = 3, 3, 32
stride_h, stride_w = 2, 2 stride_h, stride_w = 2, 2
np.random.seed(2) np.random.seed(2)

View File

@@ -62,7 +62,7 @@ public:
shift(int B, int NC, shift(int B, int NC,
int D, int H, int W, int D, int H, int W,
int T, int R, int S, int NF, int T, int R, int S, int NF,
int stride_h, int stride_w, int stride_h, int stride_w,
const int32_t* shift_h, const int32_t* shift_w, const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32", std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false, layout_t layout = CHWN); type ty = FPROP, bool bias = false, layout_t layout = CHWN);
@@ -145,6 +145,8 @@ private:
// shift values // shift values
const int32_t* shift_h_; const int32_t* shift_h_;
const int32_t* shift_w_; const int32_t* shift_w_;
bool shift_edge_h_;
bool shift_edge_w_;
// look-up tables // look-up tables
std::vector<int32_t> h_delta_a; std::vector<int32_t> h_delta_a;
std::vector<int32_t> h_delta_b; std::vector<int32_t> h_delta_b;
@@ -154,7 +156,7 @@ private:
std::string a_ty_; std::string a_ty_;
std::string b_ty_; std::string b_ty_;
// convolution type // convolution type
type ty_; type op_;
bool bias_; bool bias_;
// transpose // transpose
bool AT_; bool AT_;

View File

@@ -23,8 +23,9 @@ shift::shift(int B, int C,
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w), stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
shift_h_(shift_h), shift_w_(shift_w), shift_h_(shift_h), shift_w_(shift_w),
a_ty_(a_ty), b_ty_(b_ty), a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias), op_(ty), bias_(bias),
layout_(layout){ layout_(layout){
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
// max number of channels // max number of channels
TK_ = 16; TK_ = 16;
MAX_C_ = 8192 + TK_; MAX_C_ = 8192 + TK_;
@@ -51,6 +52,9 @@ shift::shift(int B, int C,
default: default:
throw std::runtime_error("unsupported input layout"); throw std::runtime_error("unsupported input layout");
} }
// Shift edge
shift_edge_h_ = (AH_ == stride_h_);
shift_edge_w_ = (AW_ == stride_w_);
// B memory strides: [C, F] // B memory strides: [C, F]
ldb_n_ = 1; ldb_n_ = 1;
ldb_h_ = 1; ldb_h_ = 1;
@@ -88,7 +92,7 @@ shift::shift(int B, int C,
if(layout_ == NCHW) if(layout_ == NCHW)
shapes_c_ = {B, F, CH_, CW_}; shapes_c_ = {B, F, CH_, CW_};
// Weight gradient // Weight gradient
if(ty_ == WGRAD){ if(op_ == WGRAD){
// b <-> c // b <-> c
// b <-> a // b <-> a
std::swap(ldb_n_, ldc_n_); std::swap(ldb_n_, ldc_n_);
@@ -106,7 +110,7 @@ shift::shift(int B, int C,
shapes_c_ = {C, F}; shapes_c_ = {C, F};
} }
// Input gradient // Input gradient
if(ty_ == BPROP){ if(op_ == BPROP){
// a <-> c // a <-> c
std::swap(lda_n_, ldc_n_); std::swap(lda_n_, ldc_n_);
std::swap(lda_w_, ldc_w_); std::swap(lda_w_, ldc_w_);
@@ -128,10 +132,12 @@ base* shift::clone() const {
void shift::build_delta_a() { void shift::build_delta_a() {
h_delta_a.resize(MAX_C_); h_delta_a.resize(MAX_C_);
if(ty_ == FPROP){ auto shift_h = [&](int c) { return shift_edge_h_ ? std::max(0, shift_h_[c]) : shift_h_[c]; };
auto shift_w = [&](int c) { return shift_edge_w_ ? std::max(0, shift_w_[c]) : shift_w_[c]; };
if(op_ == FPROP){
// compute offset // compute offset
auto offset = [&](unsigned c) { auto offset = [&](unsigned c) {
return c*lda_c_ + shift_h_[c]*lda_h_ + shift_w_[c]*lda_w_; return c*lda_c_ + shift_h(c)*lda_h_ + shift_w(c)*lda_w_;
}; };
// populate look-up table // populate look-up table
for(unsigned c = 0; c < TK_; c++) for(unsigned c = 0; c < TK_; c++)
@@ -139,14 +145,14 @@ void shift::build_delta_a() {
for(unsigned c = 0; c < C_; c++) for(unsigned c = 0; c < C_; c++)
h_delta_a[TK_ + c] = offset(c + TK_) - offset(c); h_delta_a[TK_ + c] = offset(c + TK_) - offset(c);
} }
if(ty_ == BPROP){ if(op_ == BPROP){
for(unsigned c = 0; c < C_; c++){ for(unsigned c = 0; c < C_; c++){
h_delta_a[c] = shift_h_[c]*ldc_h_ + shift_w_[c]*ldc_w_; h_delta_a[c] = shift_h(c)*ldc_h_ + shift_w(c)*ldc_w_;
} }
} }
if(ty_ == WGRAD){ if(op_ == WGRAD){
for(unsigned c = 0; c < C_; c++) for(unsigned c = 0; c < C_; c++)
h_delta_a[c] = shift_h_[c]*ldb_h_ + shift_w_[c]*ldb_w_; h_delta_a[c] = shift_h(c)*ldb_h_ + shift_w(c)*ldb_w_;
} }
} }
@@ -167,10 +173,22 @@ bool shift::operator <(const base& other) const{
auto *y = dynamic_cast<const shift*>(&other); auto *y = dynamic_cast<const shift*>(&other);
if(!y) if(!y)
return true; return true;
return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_, return std::tie(B_, C_, F_,
shift_h_, shift_w_, ty_, bias_) AD_, AH_, AW_,
< std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_, BD_, BH_, BW_,
y->shift_h_, y->shift_w_, y->ty_, y->bias_); CD_, CH_, CW_,
shift_h_, shift_w_,
stride_h_, stride_w_,
layout_, op_,
bias_)
< std::tie(y->B_, y->C_, y->F_,
y->AD_, y->AH_, y->AW_,
y->BD_, y->BH_, y->BW_,
y->CD_, y->CH_, y->CW_,
y->shift_h_, y->shift_w_,
y->stride_h_, y->stride_w_,
y->layout_, y->op_,
y->bias_);
} }
void shift::init_impl(driver::stream *stream, driver::cu_module *module) { void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
@@ -212,7 +230,7 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(26, CW_); kernel->setArg(26, CW_);
unsigned TM = ranges[0], TN = ranges[1]; unsigned TM = ranges[0], TN = ranges[1];
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
if(ty_ == BPROP) if(op_ == BPROP)
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4); ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4);
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
@@ -263,7 +281,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *A,
int32 pad_w = BW / 2;)"; int32 pad_w = BW / 2;)";
/* A offsets */ /* A offsets */
if(ty_ == FPROP){ if(op_ == FPROP){
os << R"( os << R"(
int32 rawh[TM] = rxa / NB; int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB; int32 rab[TM] = rxa % NB;
@@ -274,13 +292,20 @@ if(ty_ == FPROP){
__constant__ int32* pd[TK] = delta_a + rka; __constant__ int32* pd[TK] = delta_a + rka;
multiple_of(4) int32 d[TK] = *pd; multiple_of(4) int32 d[TK] = *pd;
int32 offa_interior[TM, TK] = d[newaxis, :]; int32 offa_interior[TM, TK] = d[newaxis, :];
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c; int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;\n)";
int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); if(shift_edge_h_)
int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); os << " int1 interiorh[TM] = 1;";
else
os << " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));";
if(shift_edge_w_)
os << " int1 interiorw[TM] = 1;";
else
os << " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
os << R"(
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis]; int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
} }
if(ty_ == BPROP){ if(op_ == BPROP){
os << R"( os << R"(
int32 rawh[TM] = rxa / NB; int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB; int32 rab[TM] = rxa % NB;
@@ -290,12 +315,12 @@ if(ty_ == BPROP){
int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
} }
if(ty_ == WGRAD && layout_ == CHWN){ if(op_ == WGRAD && layout_ == CHWN){
os << R"( os << R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offa1[TK, TM] = rka[:, newaxis];)"; int32 offa1[TK, TM] = rka[:, newaxis];)";
} }
if(ty_ == WGRAD && layout_ == NCHW){ if(op_ == WGRAD && layout_ == NCHW){
os << R"( os << R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 rawh[TK] = rka / NB; int32 rawh[TK] = rka / NB;
@@ -307,17 +332,17 @@ if(ty_ == WGRAD && layout_ == NCHW){
} }
/* B offsets */ /* B offsets */
if(ty_ == FPROP){ if(op_ == FPROP){
os << R"( os << R"(
int32 offb0[TN, TK] = ryb[:, newaxis]; int32 offb0[TN, TK] = ryb[:, newaxis];
int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)"; int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
} }
if(ty_ == BPROP){ if(op_ == BPROP){
os << R"( os << R"(
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = rkb[:, newaxis];)"; int32 offb1[TK, TN] = rkb[:, newaxis];)";
} }
if(ty_ == WGRAD){ if(op_ == WGRAD){
os << R"( os << R"(
__constant__ int32* pd[TN] = delta_a + ryb; __constant__ int32* pd[TN] = delta_a + ryb;
int32 d[TN] = *pd; int32 d[TN] = *pd;
@@ -326,9 +351,16 @@ if(ty_ == WGRAD){
int32 rbb[TK] = rkb % NB; int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = (rbwh % CW)*stride_w; int32 rbw[TK] = (rbwh % CW)*stride_w;
int32 rbh[TK] = (rbwh / CW)*stride_h; int32 rbh[TK] = (rbwh / CW)*stride_h;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)";
int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h)); if(shift_edge_h_)
int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w)); os << " int1 interiorh[TK] = 1;\n";
else
os << " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
if(shift_edge_w_)
os << " int1 interiorw[TK] = 1;\n";
else
os << " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));\n";
os << R"(
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
int32 incb[TK, TN] = interior ? shift : 0; int32 incb[TK, TN] = interior ? shift : 0;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
@@ -349,7 +381,7 @@ if(ty_ == WGRAD){
int1 checkb[)" << BS << R"(] = k > TK;)"; int1 checkb[)" << BS << R"(] = k > TK;)";
/* Increment A pointers */ /* Increment A pointers */
if(ty_ == FPROP){ if(op_ == FPROP){
os << R"( os << R"(
pd = pd + TK; pd = pd + TK;
d = *pd; d = *pd;
@@ -358,15 +390,15 @@ if(ty_ == FPROP){
int32 offa[TM, TK] = interior ? offa_interior : offa_exterior; int32 offa[TM, TK] = interior ? offa_interior : offa_exterior;
pa = pa + offa;)"; pa = pa + offa;)";
} }
if(ty_ == BPROP){ if(op_ == BPROP){
os << R"( os << R"(
pa = pa + TK * lda_c;)"; pa = pa + TK * lda_c;)";
} }
if(ty_ == WGRAD && layout_ == CHWN){ if(op_ == WGRAD && layout_ == CHWN){
os << R"( os << R"(
pa = pa + TK;)"; pa = pa + TK;)";
} }
if(ty_ == WGRAD && layout_ == NCHW){ if(op_ == WGRAD && layout_ == NCHW){
os << R"( os << R"(
rka = rka + TK; rka = rka + TK;
rawh = rka / NB; rawh = rka / NB;
@@ -380,25 +412,32 @@ if(ty_ == WGRAD && layout_ == NCHW){
@checka a = *pa;)"; @checka a = *pa;)";
/* Increment B pointers */ /* Increment B pointers */
if(ty_ == WGRAD){ if(op_ == WGRAD){
os << R"( os << R"(
rkb = rkb + TK; rkb = rkb + TK;
rbwh = rkb / NB; rbwh = rkb / NB;
rbb = rkb % NB; rbb = rkb % NB;
rbw = (rbwh % CW)*stride_w; rbw = (rbwh % CW)*stride_w;
rbh = (rbwh / CW)*stride_h; rbh = (rbwh / CW)*stride_h;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)";
interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h)); if(shift_edge_h_)
interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w)); os << " interiorh = 1;\n";
else
os << " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
if(shift_edge_w_)
os << " interiorw = 1;\n";
else
os << " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));\n";
os << R"(
interior = interiorh[:, newaxis] && interiorw[:, newaxis]; interior = interiorh[:, newaxis] && interiorw[:, newaxis];
incb = interior ? shift : 0; incb = interior ? shift : 0;
pb = B + offb0 + offkb[:, newaxis] + incb;)"; pb = B + offb0 + offkb[:, newaxis] + incb;)";
} }
if(ty_ == FPROP){ if(op_ == FPROP){
os << R"( os << R"(
pb = pb + TK * ldb_c;)"; pb = pb + TK * ldb_c;)";
} }
if(ty_ == BPROP){ if(op_ == BPROP){
os << R"( os << R"(
pb = pb + TK;)"; pb = pb + TK;)";
} }
@@ -409,7 +448,7 @@ if(ty_ == BPROP){
int32 ryc[TN] = get_global_range[TN](1);)"; int32 ryc[TN] = get_global_range[TN](1);)";
/* C offsets */ /* C offsets */
if(ty_ == BPROP){ if(op_ == BPROP){
os << R"( os << R"(
int32 rcwh[TM] = rxc / NB; int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB; int32 rcb[TM] = rxc % NB;
@@ -418,7 +457,7 @@ if(ty_ == BPROP){
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
)"; )";
} }
if(ty_ == FPROP){ if(op_ == FPROP){
os << R"( os << R"(
int32 rcwh[TM] = rxc / NB; int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB; int32 rcb[TM] = rxc % NB;
@@ -427,7 +466,7 @@ if(ty_ == FPROP){
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;
)"; )";
} }
if(ty_ == WGRAD){ if(op_ == WGRAD){
os << R"( os << R"(
int32 offxc[TM] = rxc; int32 offxc[TM] = rxc;
)"; )";
@@ -437,10 +476,17 @@ if(ty_ == WGRAD){
int1 checkc0[TM] = rxc < M; int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N; int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
if(ty_ == BPROP){ if(op_ == BPROP){
os << "\n";
if(shift_edge_h_)
os << " int1 interiorh[TM] = 1;\n";
else
os << " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
if(shift_edge_w_)
os << " int1 interiorw[TM] = 1;\n";
else
os << " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));\n";
os << R"( os << R"(
int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
__constant__ int32* pd[TN] = delta_a + ryc; __constant__ int32* pd[TN] = delta_a + ryc;
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];