[dnn]: Now implementing all existing DNN routines using common base template and auto-tuner

This commit is contained in:
Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions

View File

@@ -10,7 +10,6 @@
int main() { int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::FPROP; triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization // initialization
int32_t B = 16, NF = 128; int32_t B = 16, NF = 128;
@@ -19,8 +18,12 @@ int main() {
int32_t pad_d = 0, pad_h = 0, pad_w = 0; int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1; int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0); // triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
// triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty); triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
upsample_d, upsample_h, upsample_w,
"fp32", "fp32", ty, 0);
// convolution configuration // convolution configuration
std::vector<float> hc(configuration.c_size()); std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size()); std::vector<float> rc(configuration.c_size());
@@ -42,29 +45,7 @@ int main() {
stream->write(db, true, 0, hb); stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
// benchmark a given convolution kernel configuration.enqueue(stream, {da, db, dc, nullptr});
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, nullptr);
return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc); stream->read(dc, true, 0, hc);
configuration.cpu_ref(rc.data(), ha.data(), hb.data()); configuration.cpu_ref(rc.data(), ha.data(), hb.data());
for(size_t i = 0; i < hc.size(); i++){ for(size_t i = 0; i < hc.size(); i++){

View File

@@ -10,18 +10,14 @@
int main() { int main() {
bool AT = true; bool AT = true;
bool BT = true; bool BT = true;
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// matrix multiplication parameters // matrix multiplication parameters
int32_t M = 1024, N = 1024, K = 1024; int32_t M = 128, N = 128, K = 128;
std::vector<float> hc(M*N); std::vector<float> hc(M*N);
std::vector<float> rc(M*N); std::vector<float> rc(M*N);
std::vector<float> ha(M*K); std::vector<float> ha(M*K);
std::vector<float> hb(K*N); std::vector<float> hb(K*N);
std::vector<int32_t> hlocks(2048);
srand(0); srand(0);
for(size_t i = 0; i < ha.size(); i++) for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX; ha[i] = (float)rand()/RAND_MAX;
@@ -32,41 +28,15 @@ int main() {
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4); triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4); triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::buffer* dlocks = triton::driver::buffer::create(context, hlocks.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context); triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha); stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb); stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
triton::dnn::gemm::init(stream, dlocks);
stream->synchronize(); stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
gemm.enqueue(stream, {da, db, dc});
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, da, db, dc, M, N, K, dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, context->device());
return 2.*M*N*K / ts * 1e-3;
};
// just-in-time compile source-code
std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 4, 4);
// jit.autotune("matmul",src.c_str(), benchmark);
jit.add_module("matmul", src.c_str(), {8, 16, 4, 2, 16, 8, 4, 2, 2, 4, 2, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
stream->read(dc, true, 0, hc); stream->read(dc, true, 0, hc);
triton::dnn::gemm::cpu_ref<float>(AT, BT, rc, ha, hb, M, N, K); gemm.cpu_ref<float>(rc, ha, hb);
for(size_t i = 0; i < M*N; i++) for(size_t i = 0; i < M*N; i++)
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;

View File

@@ -29,7 +29,7 @@ int main() {
shift_w[c] = rand() % S - S/2; shift_w[c] = rand() % S - S/2;
} }
// configuration // configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP); triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h.data(), shift_w.data(), numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
// host buffers // host buffers
std::vector<float> hc(shift.c_size()); std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size()); std::vector<float> rc(shift.c_size());
@@ -53,7 +53,7 @@ int main() {
stream->write(db, true, 0, hb); stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
shift.enqueue(stream, da, db, dc); shift.enqueue(stream, {da, db, dc});
// stream->read(dc, true, 0, hc); // stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data()); // shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++) // for(size_t i = 0; i < hc.size(); i++)

View File

@@ -72,7 +72,7 @@ torch::Tensor shift_common(
if(m_shift_jit.find(key) == m_shift_jit.end()){ if(m_shift_jit.find(key) == m_shift_jit.end()){
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get(); jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss; std::ostringstream oss;
configuration->get_src(oss); configuration->triton_c_src(oss);
std::string src = oss.str(); std::string src = oss.str();
// benchmark a given shiftolution kernel // benchmark a given shiftolution kernel
auto benchmark = [&](triton::driver::kernel* kernel, auto benchmark = [&](triton::driver::kernel* kernel,
@@ -85,7 +85,7 @@ torch::Tensor shift_common(
stream->synchronize(); stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); }, double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device()); [&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3; return configuration->num_flops() / ts * 1e-3;
}; };
// auto-tune and save result // auto-tune and save result
if(autotune) { if(autotune) {

View File

@@ -59,15 +59,7 @@ public:
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false); triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
// create config // create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32"); triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss; batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
} }
private: private:
@@ -135,15 +127,7 @@ public:
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false); triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config // create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32"); triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss; batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
} }
private: private:

View File

@@ -50,56 +50,28 @@ public:
int32_t stride_d = 1, stride_h = 1, stride_w = 1; int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0; int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false; bool has_bias = false;
// wrap buffers
// get conv configuration
triton::dnn::conv configuration(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// Bind memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false); triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false); triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::buffer* bias = nullptr; triton::driver::buffer* bias = nullptr;
// template
triton::dnn::conv conv(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// allocate output // allocate output
auto c_shapes = configuration.c_shapes(); auto c_shapes = conv.c_shapes();
Tensor* tfc = nullptr; Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}); TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc)); OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false); triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
// enqueue
// benchmark a given convolution kernel conv.enqueue(stream, {&a, &b, &c, bias});
triton::jit jit(ctx);
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << benchmark(kernel, info) << std::endl;
} }
}; };

View File

@@ -33,7 +33,6 @@ class DotOp : public OpKernel {
// get inputs // get inputs
const Tensor& a = context->input(0); const Tensor& a = context->input(0);
const Tensor& b = context->input(1); const Tensor& b = context->input(1);
const Tensor& locks = context->input(2);
// get shapes // get shapes
const int32_t M = a.dim_size(0); const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0); const int32_t N = b.dim_size(0);
@@ -45,40 +44,13 @@ class DotOp : public OpKernel {
// return early if possible // return early if possible
if (out_shape.num_elements() == 0) if (out_shape.num_elements() == 0)
return; return;
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters // matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false); triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false); // template
// benchmark a given matrix multiplication kernel triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
auto benchmark = [&](triton::driver::kernel* kernel, dot.enqueue(stream, {&da, &db, &dc});
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, &da, &db, &dc, M, N, K, &dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, ctx->device());
return 2.*M*N*K / ts * 1e-3;
};
std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
// just-in-time compile source-code
jit.autotune("matmul", src.c_str(), benchmark);
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << benchmark(kernel, info) << std::endl;
} }
private: private:
@@ -88,6 +60,5 @@ REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), DotOp);
REGISTER_OP("Dot") REGISTER_OP("Dot")
.Input("a: float16") .Input("a: float16")
.Input("b: float16") .Input("b: float16")
.Input("locks: int32")
.Output("c: float32") .Output("c: float32")
; ;

View File

@@ -127,5 +127,5 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n))) print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n))) print(np.max(np.abs(db_t - db_n)))
run_shift() #run_shift()
#run_batchnorm() run_batchnorm()

View File

@@ -19,15 +19,6 @@
using namespace tensorflow; using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice; using GPUDevice = Eigen::GpuDevice;
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t, int32_t,
int32_t*, int32_t*,
triton::dnn::shift::type, bool> shift_key_t;
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
template<triton::dnn::shift::type OP> template<triton::dnn::shift::type OP>
class ShiftConvOp : public OpKernel { class ShiftConvOp : public OpKernel {
public: public:
@@ -102,23 +93,15 @@ public:
// shift offsets // shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data(); int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data(); int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
// create configuration // create configuration
triton::dnn::shift* shift; triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
if(m_config.find(key) == m_config.end()) stride_h_, stride_w_,
shift = m_config.emplace(key, new triton::dnn::shift( shift_h_data, shift_w_data,
B, C, D, H, W, T, R_, S_, F, "fp32", "fp32", OP, has_bias);
stride_h_, stride_w_,
shift_h, shift_w,
"fp32", "fp32", OP, has_bias))
.first->second.get();
else
shift = m_config.at(key).get();
// shapes for c // shapes for c
std::vector<int64> c_shapes; std::vector<int64> c_shapes;
for(int32_t x: shift->c_shapes()) for(int32_t x: shift.c_shapes())
c_shapes.push_back(x); c_shapes.push_back(x);
TensorShape out_shapes(c_shapes); TensorShape out_shapes(c_shapes);
Tensor* tf_c = nullptr; Tensor* tf_c = nullptr;
@@ -130,7 +113,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false); triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
shift->enqueue(stream, {&da, &db, &dc}); shift.enqueue(stream, {&da, &db, &dc});
} }
private: private:

View File

@@ -29,7 +29,11 @@
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
class base { class base {
friend class cmp_recompile;
protected: protected:
// leading dimensions // leading dimensions
static void set_ld(const std::vector<int32_t>& shapes, static void set_ld(const std::vector<int32_t>& shapes,
@@ -37,21 +41,24 @@ protected:
private: private:
// initialize // initialize
virtual void init_impl(driver::stream *stream, driver::cu_module *module) = 0; virtual void init_impl(driver::stream *, driver::cu_module *){ }
// enqueue // enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel, virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args, std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads) = 0; const std::vector<unsigned>& ranges,
size_t nthreads) = 0;
// number of flops
virtual size_t num_flops() const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
public: public:
// constructor // constructor
base(const std::string& name); base(const std::string& name);
// number of flops
virtual size_t get_nflops() const = 0;
// triton-c source // triton-c source
virtual void get_src(std::ostream &os) const = 0; virtual void triton_c_src(std::ostream &os) const = 0;
// comparison for maps // clone
virtual bool operator<(const base& other) const = 0; virtual base* clone() const = 0;
// enqueue // enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args); void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
@@ -59,6 +66,12 @@ private:
std::string name_; std::string name_;
}; };
struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
}
};
} }
} }

View File

@@ -28,23 +28,32 @@
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <cmath> #include <cmath>
#include "triton/dnn/base.h"
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
class batchnorm_forward { class batchnorm_forward: public base {
private:
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned> &ranges, size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;
public: public:
// constructor // constructor
batchnorm_forward(int C, int D, int H, int W, int B, std::string ty = "fp32"); batchnorm_forward(int C, int D, int H, int W, int B,
// enqueue std::string ty = "fp32", float eps = 1e-5);
void enqueue(driver::stream *stream, driver::kernel *kernel, // triton-c source
driver::buffer *y, driver::buffer *m, driver::buffer *v, void triton_c_src(std::ostream &os) const;
driver::buffer *x, driver::buffer *g, driver::buffer *b,
size_t TM, size_t nthreads);
// triton-c source code
void src(std::ostream &os);
private: private:
int32_t C_; int32_t C_;
@@ -58,18 +67,25 @@ private:
float rcpDHWB_; float rcpDHWB_;
}; };
class batchnorm_backward { class batchnorm_backward: public base{
private:
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned> &ranges, size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;
public: public:
// constructor // constructor
batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32", float eps = 1e-5); batchnorm_backward(int C, int D, int H, int W, int B,
// enqueue std::string ty = "fp32", float eps = 1e-5);
void enqueue(driver::stream *stream, driver::kernel *kernel, // triton-c source
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy, void triton_c_src(std::ostream &os) const;
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
size_t TM, size_t nthreads);
// triton-c source code
void src(std::ostream &os);
private: private:
int32_t C_; int32_t C_;

View File

@@ -4,11 +4,12 @@
#include <numeric> #include <numeric>
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
class conv { class conv: public base{
public: public:
enum type { enum type {
FPROP, FPROP,
@@ -17,11 +18,29 @@ public:
}; };
private: private:
void set_ld(const std::vector<int32_t>& shapes, // initialize
std::vector<int32_t>& ld);
std::tuple<int32_t, int32_t, int32_t, int32_t> std::tuple<int32_t, int32_t, int32_t, int32_t>
unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW); unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW);
void build_b_deltas();
void build_a_deltas();
void build_masks();
void init_impl(driver::stream *, driver::cu_module *);
// enqueue
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;
public: public:
@@ -39,35 +58,17 @@ public:
size_t b_size(); size_t b_size();
size_t c_size(); size_t c_size();
std::vector<int32_t> c_shapes(); std::vector<int32_t> c_shapes();
// default params
// initialize
void build_b_deltas();
void build_deltas();
void build_masks();
void init(driver::stream *stream, driver::cu_module *module);
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias);
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias,
size_t TM, size_t TN, size_t GZ, size_t nthreads);
// utilities
size_t get_nflops();
std::vector<unsigned> default_params(); std::vector<unsigned> default_params();
// source // triton-c source code
void src(std::ostream &os); void triton_c_src(std::ostream &os) const;
// cpu check // cpu reference implementations
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B); void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B); void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B); void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);

View File

@@ -1,27 +1,44 @@
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
#include <string> #include <string>
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
class gemm { class gemm: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
public: public:
static void init(driver::stream* stream, driver::buffer* locks); gemm(int M, int N, int K, bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);
static void set_arg(driver::kernel *kernel, // triton-c source
driver::buffer *a, driver::buffer *b, driver::buffer *c, void triton_c_src(std::ostream &os) const;
int32_t M, int32_t N, int32_t K,
driver::buffer *locks, int32_t grid_0, int32_t grid_1);
static std::vector<unsigned> default_params(bool AT, bool BT); // clone
base* clone() const;
static std::string src(bool AT, bool BT, // default params
std::string a_ty, std::string b_ty, std::vector<unsigned> default_params();
unsigned alignment_lda, unsigned alignment_ldb);
// CPU reference implementation
template<class T, bool AT, bool BT> template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++) for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){ for(size_t n = 0; n < N; n++){
T acc = 0; T acc = 0;
@@ -30,18 +47,29 @@ public:
c[m + n*M] = acc; c[m + n*M] = acc;
} }
} }
template<class T> template<class T>
static void cpu_ref(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) { void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT && BT) if(AT_ && BT_)
gemm::cpu_ref<T, true, true>(c, a, b, M, N, K); gemm::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
else if(AT && !BT) else if(AT_ && !BT_)
gemm::cpu_ref<T, true, false>(c, a, b, M, N, K); gemm::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
else if(!AT && BT) else if(!AT_ && BT_)
gemm::cpu_ref<T, false, true>(c, a, b, M, N, K); gemm::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
else else
gemm::cpu_ref<T, false, false>(c, a, b, M, N, K); gemm::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
} }
private:
int32_t M_;
int32_t N_;
int32_t K_;
bool AT_;
bool BT_;
std::string a_ty_;
std::string b_ty_;
unsigned align_lda_;
unsigned align_ldb_;
driver::buffer *locks_;
}; };
} }

View File

@@ -46,14 +46,12 @@ public:
}; };
private: private:
// leading dimensions
void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld);
// initialize and enqueue // initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module); void init_impl(driver::stream *stream, driver::cu_module *module);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel, void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args, std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads); const std::vector<unsigned>& ranges,
size_t nthreads);
public: public:
@@ -61,7 +59,7 @@ public:
int D, int H, int W, int D, int H, int W,
int T, int R, int S, int NF, int T, int R, int S, int NF,
int stride_h, int stride_w, int stride_h, int stride_w,
const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w, const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32", std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false); type ty = FPROP, bool bias = false);
@@ -74,11 +72,13 @@ public:
size_t c_size(); size_t c_size();
std::vector<int32_t> c_shapes(); std::vector<int32_t> c_shapes();
// number of flops // number of flops
size_t get_nflops() const; size_t num_flops() const;
// source // source
void get_src(std::ostream &os) const; void triton_c_src(std::ostream &os) const;
// comparison // comparison
bool operator<(const base& other) const; bool operator<(const base& other) const;
// clone
base* clone() const;
// cpu reference // cpu reference
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* O, void cpu_ref(OUT_DTYPE* O,
@@ -143,8 +143,8 @@ private:
std::vector<int32_t> ld_b_; std::vector<int32_t> ld_b_;
std::vector<int32_t> ld_c_; std::vector<int32_t> ld_c_;
// shift values // shift values
std::vector<int32_t> shift_h_; const int32_t* shift_h_;
std::vector<int32_t> shift_w_; const int32_t* shift_w_;
// look-up tables // look-up tables
std::vector<int32_t> h_deltas_; std::vector<int32_t> h_deltas_;
std::vector<int32_t> h_masks_; std::vector<int32_t> h_masks_;

View File

@@ -6,38 +6,43 @@
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
} void base::set_ld(const std::vector<int32_t>& shapes,
}; std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[size - 1] = 1;
for(int i = size - 1; i >= 1; i--)
ld[i - 1] = shapes[i] * ld[i];
}
base::base(const std::string& name) base::base(const std::string& name)
: name_(name) { } : name_(name) { }
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) { void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit; static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
bool autotune = false; bool autotune = false;
driver::context* ctx = stream->context(); driver::context* ctx = stream->context();
triton::jit* jit; triton::jit* jit;
/* the current template has not already been compiled */ /* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) { if(m_jit.find(this) == m_jit.end()) {
jit = m_jit.emplace(this, new triton::jit(ctx)).first->second.get(); jit = m_jit.emplace(this->clone(), new triton::jit(ctx)).first->second.get();
std::ostringstream oss; std::ostringstream oss;
get_src(oss); triton_c_src(oss);
std::string src = oss.str(); std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel, auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) { triton::jit::launch_information info) {
// launch info // launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads; unsigned nthreads = info.num_threads;
init_impl(stream, (triton::driver::cu_module*)kernel->module()); init_impl(stream, (triton::driver::cu_module*)kernel->module());
enqueue_impl(stream, kernel, args, TM, TN, nthreads); enqueue_impl(stream, kernel, args, info.global_range_size, nthreads);
stream->synchronize(); stream->synchronize();
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); }, double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info.global_range_size, nthreads); },
[&](){ stream->synchronize(); }, ctx->device()); [&](){ stream->synchronize(); }, ctx->device());
return get_nflops() / ts * 1e-3; return num_flops() / ts * 1e-3;
}; };
// auto-tune and save result // auto-tune and save result
if(autotune) { if(autotune) {
@@ -57,12 +62,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
/* get launch parameters */ /* get launch parameters */
driver::kernel* kernel = jit->get_function(name_.c_str()); driver::kernel* kernel = jit->get_function(name_.c_str());
triton::jit::launch_information info = jit->get_launch_info(name_.c_str()); triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
/* launch */ /* launch */
enqueue_impl(stream, kernel, args, TM, TN, nthreads); enqueue_impl(stream, kernel, args,
info.global_range_size, info.num_threads);
} }
} }

View File

@@ -29,17 +29,36 @@ namespace dnn{
* Forward * Forward
* --------------- */ * --------------- */
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty) batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) { : base("batchnorm"),
C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
DHWB_ = D_*H_*W_*B_; DHWB_ = D_*H_*W_*B_;
rcpDHWB_ = (float)1 / DHWB_; rcpDHWB_ = (float)1 / DHWB_;
} }
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel, size_t batchnorm_forward::num_flops() const {
driver::buffer *y, driver::buffer *m, driver::buffer *v, return C_*DHWB_;
driver::buffer *x, driver::buffer *g, driver::buffer *b, }
size_t, size_t nthreads) {
bool batchnorm_forward::operator <(const base& other) const {
auto *y = dynamic_cast<const batchnorm_forward*>(&other);
if(!y)
return true;
return std::tie(C_, D_, H_, W_, B_, ty_)
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}
base* batchnorm_forward::clone() const {
return new batchnorm_forward(*this);
}
void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>&,
size_t nthreads)
{
driver::buffer *y = args[0], *m = args[1], *v = args[2];
driver::buffer *x = args[3], *g = args[4], *b = args[5];
std::array<size_t, 3> grid = {(size_t)C_, 1, 1}; std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
kernel->setArg(0, y); kernel->setArg(0, y);
kernel->setArg(1, m); kernel->setArg(1, m);
@@ -53,7 +72,7 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
void batchnorm_forward::src(std::ostream &os) { void batchnorm_forward::triton_c_src(std::ostream &os) const {
os << os <<
R"( R"(
const tunable int32 TM = {32, 64, 128}; const tunable int32 TM = {32, 64, 128};
@@ -113,14 +132,32 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
* --------------- */ * --------------- */
batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps) batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) : base("batchnorm"),
C_(C), D_(D), H_(H), W_(W), B_(B),
ty_(ty), eps_(eps)
{ } { }
void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel, size_t batchnorm_backward::num_flops() const {
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy, return C_*D_*H_*W_*B_;
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v, }
size_t, size_t nthreads) {
bool batchnorm_backward::operator <(const base& other) const {
auto *y = dynamic_cast<const batchnorm_backward*>(&other);
if(!y)
return true;
return std::tie(C_, D_, H_, W_, B_, ty_)
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}
base* batchnorm_backward::clone() const {
return new batchnorm_backward(*this);
}
void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args,
const std::vector<unsigned> &, size_t nthreads) {
driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
std::array<size_t, 3> grid = {1, (size_t)C_, 1}; std::array<size_t, 3> grid = {1, (size_t)C_, 1};
kernel->setArg(0, dx); kernel->setArg(0, dx);
kernel->setArg(1, dg); kernel->setArg(1, dg);
@@ -136,7 +173,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
void batchnorm_backward::src(std::ostream &os) { void batchnorm_backward::triton_c_src(std::ostream &os) const {
os << os <<
R"( R"(
const tunable int32 TM = {32, 64, 128}; const tunable int32 TM = {32, 64, 128};

View File

@@ -4,17 +4,6 @@
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
void conv::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[4] = 1;
ld[3] = shapes[4]*ld[4];
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
}
conv::conv(int B, int NC, conv::conv(int B, int NC,
int D, int H, int W, int D, int H, int W,
int T, int R, int S, int NF, int T, int R, int S, int NF,
@@ -23,7 +12,8 @@ conv::conv(int B, int NC,
int upsample_d, int upsample_h, int upsample_w, int upsample_d, int upsample_h, int upsample_w,
std::string a_ty, std::string b_ty, std::string a_ty, std::string b_ty,
type ty, bool bias) type ty, bool bias)
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF), : base("conv"),
NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w), stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w), pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w), upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
@@ -93,7 +83,7 @@ conv::conv(int B, int NC,
Fs_ = K_; Fs_ = K_;
TK_ = 8; TK_ = 8;
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_; Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
build_deltas(); build_a_deltas();
if(b_lut_) if(b_lut_)
build_b_deltas(); build_b_deltas();
build_masks(); build_masks();
@@ -107,6 +97,28 @@ conv::conv(int B, int NC,
max_grid_1_ = 256; max_grid_1_ = 256;
} }
// comparison for maps
bool conv::operator<(const base& other) const {
auto *y = dynamic_cast<const conv*>(&other);
if(!y)
return true;
return std::tie(NB_, NC_, AD_, AH_, AW_,
NF_, BD_, BH_, BW_,
pad_d_, pad_h_, pad_w_,
stride_d_, stride_h_, stride_w_,
a_ty_, b_ty_, ty_, bias_)
< std::tie(y->NB_, y->NC_, y->AD_, y->AH_, y->AW_,
y->NF_, y->BD_, y->BH_, y->BW_,
y->pad_d_, y->pad_h_, y->pad_w_,
y->stride_d_, y->stride_h_, y->stride_w_,
y->a_ty_, y->b_ty_, y->ty_, y->bias_);
}
// clone
base* conv::clone() const {
return new conv(*this);
}
size_t conv::a_size() size_t conv::a_size()
{ return std::accumulate(shapes_a_.begin(), shapes_a_.end(), { return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
1, std::multiplies<int>()); } 1, std::multiplies<int>()); }
@@ -176,7 +188,7 @@ void conv::build_b_deltas(){
} }
} }
void conv::build_deltas(){ void conv::build_a_deltas(){
h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_); h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
for(size_t i = 0; i < Luts_; ++i) for(size_t i = 0; i < Luts_; ++i)
h_a_deltas_[i] = (((i + TK_) % Luts_) - i); h_a_deltas_[i] = (((i + TK_) % Luts_) - i);
@@ -258,13 +270,15 @@ void conv::build_masks(){
h_masks_[i] = 0x0; h_masks_[i] = 0x0;
} }
std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN) std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN){
{ return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; } return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
}
size_t conv::get_nflops() size_t conv::num_flops() const{
{ return 2.*M_*N_*K_; } return 2.*M_*N_*K_;
}
void conv::init(driver::stream *stream, triton::driver::cu_module* module) { void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module) {
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{ auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
if(host.empty()) if(host.empty())
return nullptr; return nullptr;
@@ -349,9 +363,13 @@ void conv::set_arg(driver::kernel *kernel,
kernel->setArg(idx++, d_masks_); kernel->setArg(idx++, d_masks_);
} }
void conv::enqueue(driver::stream *stream, driver::kernel *kernel, void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias, std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t GZ, size_t nthreads) { const std::vector<unsigned>& ranges,
size_t nthreads) {
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
unsigned TM = ranges[0], TN = ranges[1];
unsigned GZ = 1;
set_arg(kernel, a, b, c, bias); set_arg(kernel, a, b, c, bias);
std::array<size_t, 3> grid = {1}; std::array<size_t, 3> grid = {1};
grid[0] = (M_ + TM - 1)/TM; grid[0] = (M_ + TM - 1)/TM;
@@ -411,6 +429,8 @@ std::vector<unsigned> conv::default_params() {
} }
/* CPU reference implementation */
template<class IN_DTYPE, class OUT_DTYPE> template<class IN_DTYPE, class OUT_DTYPE>
void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B) void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{ {
@@ -496,7 +516,9 @@ void conv::cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
cpu_wgrad(C, A, B); cpu_wgrad(C, A, B);
} }
void conv::src(std::ostream &os){ /* Triton-C source code */
void conv::triton_c_src(std::ostream &os) const {
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]"; std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]"; std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]"; std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
@@ -526,7 +548,7 @@ void conv::src(std::ostream &os){
R"( R"(
const tunable int32 TM = {16, 32, 64}; const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64}; const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {16}; const tunable int32 TK = {)" << TK_ << R"(};
const tunable int32 GZ = {1}; const tunable int32 GZ = {1};
)"; )";
if(is_a_deltas_cst) if(is_a_deltas_cst)

View File

@@ -6,64 +6,103 @@
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
gemm::gemm(int M, int N, int K,
bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb)
: base("matmul"),
M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
a_ty_(a_ty), b_ty_(b_ty),
align_lda_(alignment_lda), align_ldb_(alignment_ldb),
locks_(nullptr) {
void gemm::init(driver::stream* stream, driver::buffer* locks) {
std::vector<int32_t> hlocks(2048, 0);
stream->write(locks, false, 0, hlocks);
} }
void gemm::set_arg(driver::kernel *kernel, size_t gemm::num_flops() const {
driver::buffer *a, driver::buffer *b, driver::buffer *c, return 2.*M_*N_*K_;
int32_t M, int32_t N, int32_t K, }
driver::buffer *locks, int32_t grid_0, int32_t grid_1) {
// comparison for maps
bool gemm::operator<(const base& other) const {
auto *y = dynamic_cast<const gemm*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_, AT_, BT_,
a_ty_, b_ty_, align_lda_, align_ldb_)
< std::tie(y->M_, y->N_, y->K_, y->AT_, y->BT_,
y->a_ty_, y->b_ty_, y->align_lda_, y->align_ldb_);
}
// clone
base* gemm::clone() const {
return new gemm(*this);
}
void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
std::vector<int32_t> hlocks(2048, 0);
if(locks_ == nullptr)
locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);
stream->write(locks_, false, 0, hlocks);
}
void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads) {
driver::buffer *a = args[0], *b = args[1], *c = args[2];
unsigned TM = ranges[0];
unsigned TN = ranges[1];
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
kernel->setArg(0, a); kernel->setArg(0, a);
kernel->setArg(1, b); kernel->setArg(1, b);
kernel->setArg(2, c); kernel->setArg(2, c);
kernel->setArg(3, M); kernel->setArg(3, M_);
kernel->setArg(4, N); kernel->setArg(4, N_);
kernel->setArg(5, K); kernel->setArg(5, K_);
kernel->setArg(6, M); kernel->setArg(6, M_);
kernel->setArg(7, N); kernel->setArg(7, N_);
kernel->setArg(8, M); kernel->setArg(8, M_);
kernel->setArg(9, locks); kernel->setArg(9, locks_);
kernel->setArg(10, grid_0); kernel->setArg(10, grid_0);
kernel->setArg(11, grid_1); kernel->setArg(11, grid_1);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
std::vector<unsigned> gemm::default_params(bool AT, bool BT) { std::vector<unsigned> gemm::default_params() {
if(AT && BT) if(AT_ && BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1}; return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(AT && !BT) else if(AT_ && !BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1}; return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(!AT && BT) else if(!AT_ && BT_)
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1}; return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
else else
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1}; return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
} }
std::string gemm::src(bool AT, bool BT, void gemm::triton_c_src(std::ostream &os) const {
std::string a_ty, std::string b_ty,
unsigned align_lda, unsigned align_ldb) {
std::string AS0 = "TM", AS1 = "TK"; std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN"; std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = ""; std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb"; std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a"; std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b"; std::string useb = BT_ ? "trans(b)" : "b";
if(AT){ if(AT_){
std::swap(AS0, AS1); std::swap(AS0, AS1);
std::swap(bca0, bca1); std::swap(bca0, bca1);
std::swap(lda0, lda1); std::swap(lda0, lda1);
} }
if(BT){ if(BT_){
std::swap(BS0, BS1); std::swap(BS0, BS1);
std::swap(bcb0, bcb1); std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1); std::swap(ldb0, ldb1);
} }
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")"; std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")"; std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res = std::string res =
R"( R"(
const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TM = {16, 32, 64, 128};
@@ -71,8 +110,8 @@ const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8}; const tunable int32 TK = {8};
const tunable int32 GZ = {1}; const tunable int32 GZ = {1};
void matmul(restrict read_only )" + a_ty + R"( *A, void matmul(restrict read_only )" + a_ty_ + R"( *A,
restrict read_only )" + b_ty + R"( *B, restrict read_only )" + b_ty_ + R"( *B,
fp32 *C, fp32 *C,
int32 M, int32 N, int32 K, int32 M, int32 N, int32 K,
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc, )" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
@@ -87,10 +126,10 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
int32 rem = K % GZ; int32 rem = K % GZ;
K = select(rz < rem, div - 1, div); K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
)" + a_ty + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(; )" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; )" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa; )" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb; )" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK; last_a = last_a / TK * TK;
@@ -108,10 +147,10 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
for(int32 k = bound; k > 0; k = k - 1){ for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M; int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N; int1 checkb[TN, 1] = ryc[:, newaxis] < N;
)" + a_ty + R"(* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(; )" + a_ty_ + R"(* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
)" + b_ty + R"(* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(; )" + b_ty_ + R"(* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
)" + a_ty + R"( a[TM, 1] = checka ? *pa : 0; )" + a_ty_ + R"( a[TM, 1] = checka ? *pa : 0;
)" + b_ty + R"( b[TN, 1] = checkb ? *pb : 0; )" + b_ty_ + R"( b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c); c = dot(a, trans(b), c);
} }
int32 ridx = get_range_id(0); int32 ridx = get_range_id(0);
@@ -136,7 +175,7 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
__atomic_cas(plock, 1, 0); __atomic_cas(plock, 1, 0);
} }
)"; )";
return res; os << res;
} }
} }

View File

@@ -5,21 +5,13 @@
namespace triton{ namespace triton{
namespace dnn{ namespace dnn{
void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[size - 1] = 1;
for(int i = size - 1; i >= 1; i--)
ld[i - 1] = shapes[i] * ld[i];
}
shift::shift(int B, int C, shift::shift(int B, int C,
int D, int H, int W, int D, int H, int W,
int T, int R, int S, int T, int R, int S,
int F, int F,
int stride_h, int stride_w, int stride_h, int stride_w,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w, const int32_t *shift_h, const int32_t *shift_w,
std::string a_ty, std::string b_ty, std::string a_ty, std::string b_ty,
type ty, bool bias) type ty, bool bias)
: base("shift"), : base("shift"),
@@ -73,8 +65,10 @@ shift::shift(int B, int C,
set_ld(shapes_a_, ld_a_); set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_); set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_); set_ld(shapes_c_, ld_c_);
// build LUTs }
build_deltas();
base* shift::clone() const {
return new shift(*this);
} }
void shift::build_deltas() { void shift::build_deltas() {
@@ -120,30 +114,29 @@ std::vector<int32_t> shift::c_shapes(){
return shapes_c_; return shapes_c_;
} }
size_t shift::get_nflops() const { size_t shift::num_flops() const {
return 2.*M_*N_*K_; return 2.*M_*N_*K_;
} }
bool shift::operator <(const base& other) const{ bool shift::operator <(const base& other) const{
auto *y = dynamic_cast<const shift*>(&other); auto *y = dynamic_cast<const shift*>(&other);
if(!y) if(!y)
return false; return true;
const int32_t *x_shift_h = shift_h_.data(), *x_shift_w = shift_w_.data();
const int32_t *y_shift_h = y->shift_h_.data(), *y_shift_w = y->shift_w_.data();
return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_, return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
x_shift_h, x_shift_w, ty_, bias_) shift_h_, shift_w_, ty_, bias_)
< std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_, < std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
y_shift_h, y_shift_w, y->ty_, y->bias_); y->shift_h_, y->shift_w_, y->ty_, y->bias_);
} }
void shift::init_impl(driver::stream *stream, driver::cu_module *module) { void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
build_deltas();
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta"); triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data()); stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
} }
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args, std::vector<driver::buffer *> args,
size_t TM, size_t TN, size_t nthreads) { const std::vector<unsigned> &ranges, size_t nthreads) {
int32_t lda = AT_ ? K_ : M_; int32_t lda = AT_ ? K_ : M_;
int32_t ldb = BT_ ? N_ : K_; int32_t ldb = BT_ ? N_ : K_;
driver::buffer *a = args[0], *b = args[1], *c = args[2]; driver::buffer *a = args[0], *b = args[1], *c = args[2];
@@ -162,13 +155,14 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(12, AW_); kernel->setArg(12, AW_);
kernel->setArg(13, BH_); kernel->setArg(13, BH_);
kernel->setArg(14, BW_); kernel->setArg(14, BW_);
unsigned TM = ranges[0], TN = ranges[1];
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
if(ty_ == BPROP) if(ty_ == BPROP)
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4); ((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
} }
void shift::get_src(std::ostream &os) const { void shift::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK"; std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN"; std::string BS0 = "TK", BS1 = "TN";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";