[dnn/conv]: now using look-up table for wgrad computation as well
This commit is contained in:
@@ -38,12 +38,6 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
// look-up table
|
||||
std::vector<int> h_delta, h_masks;
|
||||
if(ty != triton::dnn::conv::WGRAD){
|
||||
configuration.build_deltas(h_delta);
|
||||
configuration.build_masks(h_masks);
|
||||
}
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
@@ -51,12 +45,7 @@ int main() {
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
|
||||
if(ty != triton::dnn::conv::WGRAD){
|
||||
triton::driver::buffer* delta = jit.get_buffer("delta");
|
||||
triton::driver::buffer* masks = jit.get_buffer("masks");
|
||||
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
|
||||
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
|
||||
}
|
||||
configuration.init(stream, jit);
|
||||
stream->synchronize();
|
||||
configuration.set_arg(kernel, da, db, dc);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
@@ -66,7 +55,7 @@ int main() {
|
||||
return configuration.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
std::string src = configuration.src();
|
||||
// jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.add_module("conv", src.c_str(), configuration.default_params());
|
||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||
@@ -74,7 +63,7 @@ int main() {
|
||||
stream->read(dc, true, 0, hc);
|
||||
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
for(size_t i = 0; i < hc.size(); i++){
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
@@ -68,7 +68,7 @@ int main() {
|
||||
stream->read(dc, true, 0, hc);
|
||||
simple_gemm<float>(AT, BT, rc, ha, hb, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include <numeric>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/jit.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
@@ -46,6 +47,9 @@ public:
|
||||
// swap b and c for wgrad
|
||||
if(ty_ == WGRAD){
|
||||
shapes_b_.swap(shapes_c_);
|
||||
std::swap(BD_, CD_);
|
||||
std::swap(BH_, CH_);
|
||||
std::swap(BW_, CW_);
|
||||
}
|
||||
// leading dimensions
|
||||
auto set_ld = [](const std::vector<int32_t>& shapes,
|
||||
@@ -62,6 +66,8 @@ public:
|
||||
set_ld(shapes_b_, ld_b_);
|
||||
set_ld(shapes_c_, ld_c_);
|
||||
// equivalent matmul
|
||||
b_trans_ = ty_ != BPROP;
|
||||
b_lut_ = ty_ == WGRAD;
|
||||
if(ty_ == WGRAD){
|
||||
M_ = shapes_c_[0]*shapes_c_[1]*shapes_c_[2]*shapes_c_[3];
|
||||
N_ = shapes_c_[4];
|
||||
@@ -73,11 +79,20 @@ public:
|
||||
K_ = shapes_b_[0]*shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
}
|
||||
// look-up table info
|
||||
Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
if(ty_ == BPROP)
|
||||
Fs_ *= shapes_b_[4];
|
||||
if(ty_ == FPROP)
|
||||
Fs_ = shapes_b_[1]*shapes_b_[2]*shapes_b_[3];
|
||||
else
|
||||
Fs_ = K_;
|
||||
TK_ = 8;
|
||||
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
|
||||
build_deltas();
|
||||
build_masks();
|
||||
size_t cst_size = h_b_deltas_.size()*4;
|
||||
is_b_deltas_cst_ = cst_size < 65536;
|
||||
cst_size += h_a_deltas_.size()*4;
|
||||
is_a_deltas_cst = cst_size < 65536;
|
||||
cst_size += h_masks_.size()*4;
|
||||
is_mask_cst_ = cst_size < 65536;
|
||||
}
|
||||
|
||||
size_t a_size() {
|
||||
@@ -99,14 +114,14 @@ public:
|
||||
return shapes_c_;
|
||||
}
|
||||
|
||||
void build_deltas(std::vector<int>& deltas){
|
||||
if(ty_ == WGRAD)
|
||||
throw std::runtime_error("no look-up table necessary for wgrad");
|
||||
deltas.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
|
||||
void build_deltas(){
|
||||
h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
|
||||
if(b_lut_)
|
||||
h_b_deltas_.resize(Luts_);
|
||||
|
||||
auto unpack = [&](int32_t ltrs){
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / Fs_;
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % Fs_;
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_);
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_);
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs % BW_;
|
||||
int32_t t = tr / BH_;
|
||||
@@ -119,7 +134,7 @@ public:
|
||||
};
|
||||
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
deltas[i] = (((i + TK_) % Luts_) - i);
|
||||
h_a_deltas_[i] = (((i + TK_) % Luts_) - i);
|
||||
|
||||
size_t Ds0 = Luts_;
|
||||
size_t Ds1 = upsample_w_;
|
||||
@@ -128,7 +143,7 @@ public:
|
||||
for(size_t pd = 0; pd < Ds3; ++pd)
|
||||
for(size_t ph = 0; ph < Ds2; ++ph)
|
||||
for(size_t pw = 0; pw < Ds1; ++pw){
|
||||
int32_t* deltas_ptr = &deltas[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
int32_t* deltas_ptr = &h_a_deltas_[Luts_ + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
// cumulative increments
|
||||
for(size_t i = 0; i < Ds0; ++i) {
|
||||
// unpack
|
||||
@@ -145,18 +160,31 @@ public:
|
||||
int32_t rdiff = (nextr + ph)/upsample_h_ - (r + ph)/upsample_h_;
|
||||
int32_t sdiff = (nexts + pw)/upsample_w_ - (s + pw)/upsample_w_;
|
||||
// delta pointers
|
||||
deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
|
||||
if(ty_ == WGRAD)
|
||||
deltas_ptr[i] = cdiff*ld_a_[0] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
|
||||
else
|
||||
deltas_ptr[i] = cdiff*ld_a_[1] + tdiff*ld_a_[2] + rdiff*ld_a_[3] + sdiff*ld_a_[4];
|
||||
}
|
||||
}
|
||||
|
||||
if(ty_ == WGRAD){
|
||||
for(size_t i = 0; i < Ds0; ++i) {
|
||||
int32_t c, t, r, s;
|
||||
int32_t nextc, nextt, nextr, nexts;
|
||||
std::tie(c, t, r, s) = unpack(i);
|
||||
std::tie(nextc, nextt, nextr, nexts) = unpack(i + TK_);
|
||||
int32_t cdiff = nextc - c, tdiff = nextt - t, rdiff = nextr - r, sdiff = nexts - s;
|
||||
h_b_deltas_[i] = cdiff*ld_b_[0] + tdiff*ld_b_[2] + rdiff*ld_b_[3] + sdiff*ld_b_[4];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void build_masks(std::vector<int>& masks){
|
||||
if(ty_ == WGRAD)
|
||||
throw std::runtime_error("no look-up table necessary for wgrad");
|
||||
masks.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
void build_masks(){
|
||||
h_masks_.resize(Luts_ + (2*pad_h_+1)*(2*pad_w_+1)*(2*pad_d_+1)*Luts_);
|
||||
|
||||
auto unpack = [&](int32_t ltrs){
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / Fs_;
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % Fs_;
|
||||
int32_t l = (ty_ == BPROP) ? ltrs % NF_ : ltrs / (BD_*BH_*BW_);
|
||||
int32_t trs = (ty_ == BPROP) ? ltrs / NF_ : ltrs % (BD_*BH_*BW_);
|
||||
int32_t tr = trs / BW_;
|
||||
int32_t s = trs % BW_;
|
||||
int32_t t = tr / BH_;
|
||||
@@ -174,7 +202,7 @@ public:
|
||||
for(size_t pd = 0; pd < Ms3; ++pd)
|
||||
for(size_t ph = 0; ph < Ms2; ++ph)
|
||||
for(size_t pw = 0; pw < Ms1; ++pw){
|
||||
int32_t* masks_ptr = &masks[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
|
||||
int32_t* masks_ptr = &h_masks_[Luts_ + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
|
||||
for(size_t i = 0; i < Ms0; ++i){
|
||||
int32_t l, t, r, s;
|
||||
int32_t mask = 0x0;
|
||||
@@ -189,7 +217,7 @@ public:
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < Luts_; ++i)
|
||||
masks[i] = 0x0;
|
||||
h_masks_[i] = 0x0;
|
||||
}
|
||||
|
||||
std::array<size_t, 3> get_grid(size_t TM, size_t TN){
|
||||
@@ -200,6 +228,27 @@ public:
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
void init(driver::stream *stream, triton::jit &jit) {
|
||||
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
|
||||
if(host.empty())
|
||||
return nullptr;
|
||||
size_t nbytes = host.size()*4;
|
||||
// get buffer
|
||||
triton::driver::buffer* buffer;
|
||||
if(is_cst)
|
||||
buffer = jit.get_buffer(name);
|
||||
else
|
||||
buffer = triton::driver::buffer::create(stream->context(), nbytes);
|
||||
// copy
|
||||
stream->write(buffer, false, 0, nbytes, host.data());
|
||||
return buffer;
|
||||
};
|
||||
|
||||
d_a_deltas_ = init_lut(is_a_deltas_cst, "delta", h_a_deltas_);
|
||||
d_b_deltas_ = init_lut(is_b_deltas_cst_, "b_delta", h_b_deltas_);
|
||||
d_masks_ = init_lut(is_mask_cst_, "masks", h_masks_);
|
||||
}
|
||||
|
||||
void set_arg(driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c)
|
||||
{
|
||||
@@ -211,70 +260,107 @@ public:
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, AH_);
|
||||
kernel->setArg(7, AW_);
|
||||
kernel->setArg(8, BH_);
|
||||
kernel->setArg(9, BW_);
|
||||
kernel->setArg(10, CH_);
|
||||
kernel->setArg(11, CW_);
|
||||
// A arguments
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(8, CH_);
|
||||
kernel->setArg(9, CW_);
|
||||
kernel->setArg(10, BH_);
|
||||
kernel->setArg(11, BW_);
|
||||
kernel->setArg(12, ld_a_[1]);
|
||||
kernel->setArg(13, ld_a_[0]);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(8, BH_);
|
||||
kernel->setArg(9, BW_);
|
||||
kernel->setArg(10, CH_);
|
||||
kernel->setArg(11, CW_);
|
||||
kernel->setArg(12, ld_a_[0]);
|
||||
kernel->setArg(13, ld_a_[1]);
|
||||
}
|
||||
kernel->setArg(12, ld_a_[0]);
|
||||
kernel->setArg(13, ld_a_[1]);
|
||||
kernel->setArg(14, ld_a_[2]);
|
||||
kernel->setArg(15, ld_a_[3]);
|
||||
kernel->setArg(16, ld_a_[4]);
|
||||
kernel->setArg(17, ld_b_[0]);
|
||||
kernel->setArg(18, ld_b_[1]);
|
||||
kernel->setArg(19, ld_b_[2]);
|
||||
kernel->setArg(20, ld_b_[3]);
|
||||
kernel->setArg(21, ld_b_[4]);
|
||||
kernel->setArg(22, ld_c_[0]);
|
||||
kernel->setArg(23, ld_c_[1]);
|
||||
kernel->setArg(24, ld_c_[2]);
|
||||
kernel->setArg(25, ld_c_[3]);
|
||||
kernel->setArg(26, ld_c_[4]);
|
||||
// B arguments
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(17, ld_b_[0]);
|
||||
kernel->setArg(18, ld_b_[2]);
|
||||
kernel->setArg(19, ld_b_[3]);
|
||||
kernel->setArg(20, ld_b_[4]);
|
||||
kernel->setArg(21, ld_b_[1]);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(17, ld_b_[0]);
|
||||
kernel->setArg(18, ld_b_[1]);
|
||||
kernel->setArg(19, ld_b_[2]);
|
||||
kernel->setArg(20, ld_b_[3]);
|
||||
kernel->setArg(21, ld_b_[4]);
|
||||
}
|
||||
// C arguments
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(22, ld_c_[0]);
|
||||
kernel->setArg(23, ld_c_[4]);
|
||||
kernel->setArg(24, ld_c_[1]);
|
||||
kernel->setArg(25, ld_c_[2]);
|
||||
kernel->setArg(26, ld_c_[3]);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(22, ld_c_[0]);
|
||||
kernel->setArg(23, ld_c_[1]);
|
||||
kernel->setArg(24, ld_c_[2]);
|
||||
kernel->setArg(25, ld_c_[3]);
|
||||
kernel->setArg(26, ld_c_[4]);
|
||||
}
|
||||
kernel->setArg(27, pad_h_);
|
||||
kernel->setArg(28, pad_w_);
|
||||
size_t idx = 29;
|
||||
if(!is_a_deltas_cst)
|
||||
kernel->setArg(idx++, d_a_deltas_);
|
||||
if(!is_b_deltas_cst_)
|
||||
kernel->setArg(idx++, d_b_deltas_);
|
||||
if(!is_mask_cst_)
|
||||
kernel->setArg(idx++, d_masks_);
|
||||
}
|
||||
|
||||
std::vector<unsigned> default_params() {
|
||||
if(ty_ == FPROP)
|
||||
if(ty_==FPROP)
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4};
|
||||
else if(ty_ == BPROP)
|
||||
return {32, 2, 64, 32, 64, 32, 4, 2, 2, 4, 2, 8, 4, 2};
|
||||
else
|
||||
return {8, 2, 16, 8, 2, 16, 8, 2, 8, 8};
|
||||
else if(ty_ == WGRAD)
|
||||
return {32, 2, 64, 32, 2, 64, 16, 8, 2, 2, 4, 2, 8};
|
||||
}
|
||||
|
||||
|
||||
std::string xprop() {
|
||||
bool trans_b = ty_ == FPROP;
|
||||
std::string BS = trans_b ?"[TN,TK]" : "[TK, TN]";
|
||||
std::string bcb0 = trans_b ?"[:, newaxis]" : "[newaxis, :]";
|
||||
std::string bcb1 = trans_b ?"[newaxis, :]" : "[:, newaxis]";
|
||||
std::string ldb0 = trans_b ?"*ldb_s" : "";
|
||||
std::string ldb1 = trans_b ?"" : "*ldb_c";
|
||||
std::string useb = trans_b ?"trans(b)" : "b";
|
||||
std::string flipr = trans_b?"" : "BH - 1 -";
|
||||
std::string flips = trans_b?"" : "BW - 1 -";
|
||||
std::string ax = trans_b?"crs" : "rsc";
|
||||
std::vector<std::string> redax = {"BH", "BW", "N"};
|
||||
if(trans_b)
|
||||
std::string src() {
|
||||
bool is_wgrad = ty_ == WGRAD;
|
||||
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
|
||||
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
|
||||
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
|
||||
std::string ldb0 = b_trans_ ? "*ldb_s" : "";
|
||||
std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
|
||||
std::string useb = b_trans_ ? "trans(b)" : "b";
|
||||
std::string flipr = b_trans_ ? "" : "BH - 1 -";
|
||||
std::string flips = b_trans_ ? "" : "BW - 1 -";
|
||||
std::string ax = b_trans_ ? "crs" : "rsc";
|
||||
std::vector<std::string> redax;
|
||||
if(b_trans_)
|
||||
redax = {"C", "BH", "BW"};
|
||||
else
|
||||
redax = {"BH", "BW", "N"};
|
||||
std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
|
||||
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
|
||||
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
|
||||
std::string masks_mem = is_mask_cst_? "__constant__" : "";
|
||||
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[1024];
|
||||
__constant__ int32* masks = alloc_const int32[4096];
|
||||
)";
|
||||
if(is_a_deltas_cst)
|
||||
res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
|
||||
if(is_wgrad && is_b_deltas_cst_)
|
||||
res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
|
||||
if(is_mask_cst_)
|
||||
res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
|
||||
res += R"(
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
@@ -286,13 +372,20 @@ public:
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
|
||||
int32 pad_h, int32 pad_w){
|
||||
int32 pad_h, int32 pad_w)";
|
||||
if(!is_a_deltas_cst)
|
||||
res += ", int32* delta\n";
|
||||
if(is_wgrad && !is_b_deltas_cst_)
|
||||
res += ", int32* b_delta\n";
|
||||
if(!is_mask_cst_)
|
||||
res += ", int32* masks\n";
|
||||
res += R"(){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 rb0[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 Fs = )" + std::to_string(Fs_) + R"(;
|
||||
int32 ldlut = )" + std::to_string(Fs_) + R"(;
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW - pad_w;
|
||||
int32 rab[TM] = rabh / CH;
|
||||
@@ -305,16 +398,31 @@ public:
|
||||
rar = )" + flipr + R"( rar;
|
||||
ras = )" + flips + R"( ras;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
|
||||
fp32* pb)" + BS + " = b + rkb" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
__constant__ int32* pincd[TK] = delta + rka;
|
||||
__constant__ int32* pd[TK] = delta + Fs + rka;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
|
||||
if(ty_ == WGRAD){
|
||||
res += R"(
|
||||
int32 rbcr[TK] = rkb / BW;
|
||||
int32 rbs[TK] = rkb % BW;
|
||||
int32 rbc[TK] = rbcr / BH;
|
||||
int32 rbr[TK] = rbcr % BH;
|
||||
int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
|
||||
)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
|
||||
int32 db[TK] = *pdb;)";
|
||||
}
|
||||
else{
|
||||
res += R"(
|
||||
int32 rb1[TK] = rkb;)";
|
||||
}
|
||||
res += R"(
|
||||
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
|
||||
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
|
||||
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
|
||||
__constant__ int32* pm[TM] = masks + Fs + maskw*Fs + maskh*Fs*(2*pad_w + 1);
|
||||
__constant__ int32* pincm[TM] = delta;
|
||||
)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
|
||||
)" + a_delta_mem + R"( int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
int32 checka1[TK] = 1 << rka;
|
||||
@@ -324,9 +432,15 @@ public:
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, )" + useb + R"(, C);
|
||||
pa = pa + d[newaxis, :];
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
pb = pb + )" + inc_pb + R"(;
|
||||
b = *pb;
|
||||
pd = pd + incd;
|
||||
pd = pd + incd;)";
|
||||
if(ty_ == WGRAD){
|
||||
res += R"(
|
||||
pdb = pdb + incd;
|
||||
db = *pdb;)";
|
||||
}
|
||||
res += R"(
|
||||
pincd = pincd + incd;
|
||||
d = *pd;
|
||||
incd = *pincd;
|
||||
@@ -342,86 +456,17 @@ public:
|
||||
int32 rc1[TN] = get_global_range[TN](1);
|
||||
int32 rcn[TM] = rxc / (CH*CW);
|
||||
int32 rcpq[TM] = rxc % (CH*CW);
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq;
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
|
||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = rc1 < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
// where A is N,C,AH,AW
|
||||
// B is N,K,BH,BW
|
||||
// C is C,CH,CW,K
|
||||
std::string wgrad() {
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
|
||||
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
|
||||
int32 pad_h, int32 pad_w){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rk[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 racr[TM] = rxa / CW;
|
||||
int32 raw_base[TM] = rxa % CW - pad_w;
|
||||
int32 rac[TM] = racr / CH;
|
||||
int32 rah_base[TM] = racr % CH - pad_h;
|
||||
fp32* pa_base[TM, TK] = a + rac[:, newaxis]*lda_c;
|
||||
fp32* pb_base[TN, TK] = b + ryb[:, newaxis]*ldb_k;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
int32 rknp[TK] = rk / BW;
|
||||
int32 rkq[TK] = rk % BW;
|
||||
int32 rkn[TK] = rknp / BH;
|
||||
int32 rkp[TK] = rknp % BH;
|
||||
int32 rah[TM, TK] = rah_base[:, newaxis] + rkp[newaxis, :];
|
||||
int32 raw[TM, TK] = raw_base[:, newaxis] + rkq[newaxis, :];
|
||||
int1 checka[TM, TK] = (rah >= 0) && (rah < AH) && (raw >= 0) && (raw < AW);
|
||||
fp32* pa[TM, TK] = pa_base + rah*lda_h + raw*lda_w + rkn*lda_n;
|
||||
fp32* pb[TN, TK] = pb_base + rkp*ldb_p + rkq*ldb_q + rkn*ldb_n;
|
||||
fp32 A[TM, TK] = checka ? *pa : 0;
|
||||
fp32 B[TN, TK] = *pb;
|
||||
C = dot(A, trans(B), C);
|
||||
rk = rk + TK;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
int32 rccr[TM] = rxc / CW;
|
||||
int32 rcs[TM] = rxa % CW;
|
||||
int32 rcc[TM] = racr / CH;
|
||||
int32 rcr[TM] = racr % CH;
|
||||
int32 rc0[TM] = rcc*ldc_c + rcr*ldc_r + rcs*ldc_s;
|
||||
fp32* pc[TM, TN] = c + rc0[:, newaxis] + ryc[newaxis, :]*ldc_k;
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string src() {
|
||||
if(ty_ == FPROP || ty_ == BPROP)
|
||||
return xprop();
|
||||
else
|
||||
return wgrad();
|
||||
}
|
||||
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
|
||||
{
|
||||
@@ -552,9 +597,20 @@ private:
|
||||
// memory stride for C
|
||||
std::vector<int32_t> shapes_c_;
|
||||
std::vector<int32_t> ld_c_;
|
||||
// constant memory
|
||||
std::vector<int32_t> h_a_deltas_;
|
||||
std::vector<int32_t> h_b_deltas_;
|
||||
std::vector<int32_t> h_masks_;
|
||||
driver::buffer* d_a_deltas_;
|
||||
driver::buffer* d_b_deltas_;
|
||||
driver::buffer* d_masks_;
|
||||
bool is_a_deltas_cst;
|
||||
bool is_b_deltas_cst_;
|
||||
bool is_mask_cst_;
|
||||
// type
|
||||
type ty_;
|
||||
bool is_bprop_;
|
||||
bool b_trans_;
|
||||
bool b_lut_;
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user