[dnn]: All existing DNN routines are now implemented using a common base template and auto-tuner
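In outline: every DNN routine (conv, gemm, shift, batchnorm) now derives from triton::dnn::base and only supplies its Triton-C source, flop count, an ordering for cache lookups, and launch logic; compilation, caching, and auto-tuning are handled once in base::enqueue. Condensed from include/triton/dnn/base.h as introduced below:

class base {
private:
  // optional per-module initialization (LUT uploads, lock buffers, ...)
  virtual void init_impl(driver::stream*, driver::cu_module*) { }
  // bind arguments and launch; `ranges` carries the tuned tile sizes
  virtual void enqueue_impl(driver::stream* stream, driver::kernel* kernel,
                            std::vector<driver::buffer*> args,
                            const std::vector<unsigned>& ranges,
                            size_t nthreads) = 0;
  virtual size_t num_flops() const = 0;
  virtual bool operator<(const base& other) const = 0;  // cache-key ordering
public:
  base(const std::string& name);
  virtual void triton_c_src(std::ostream& os) const = 0;
  virtual base* clone() const = 0;  // owned copy used as the cache key
  // compile once per distinct configuration, then launch
  void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
};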
@@ -10,7 +10,6 @@
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 16, NF = 128;
@@ -19,8 +18,12 @@ int main() {
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
// triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
// triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
upsample_d, upsample_h, upsample_w,
"fp32", "fp32", ty, 0);
// convolution configuration
std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size());
@@ -42,29 +45,7 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, nullptr);
return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
configuration.enqueue(stream, {da, db, dc, nullptr});
stream->read(dc, true, 0, hc);
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
for(size_t i = 0; i < hc.size(); i++){
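The check loop above is truncated in this hunk. Mirroring the matmul example below, a plausible completion (the abort on mismatch is an assumption, not in the diff) would be:

for(size_t i = 0; i < hc.size(); i++){
  if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
    std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
    exit(EXIT_FAILURE); // assumption: abort on first mismatch
  }
}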
@@ -10,18 +10,14 @@
int main() {
bool AT = true;
bool BT = true;

// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);

// matrix multiplication parameters
int32_t M = 1024, N = 1024, K = 1024;
int32_t M = 128, N = 128, K = 128;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
std::vector<int32_t> hlocks(2048);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
@@ -32,41 +28,15 @@ int main() {
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::buffer* dlocks = triton::driver::buffer::create(context, hlocks.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
triton::dnn::gemm::init(stream, dlocks);
stream->synchronize();

// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, da, db, dc, M, N, K, dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, context->device());
return 2.*M*N*K / ts * 1e-3;
};

// just-in-time compile source-code
std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 4, 4);
// jit.autotune("matmul",src.c_str(), benchmark);
jit.add_module("matmul", src.c_str(), {8, 16, 4, 2, 16, 8, 4, 2, 2, 4, 2, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
gemm.enqueue(stream, {da, db, dc});
stream->read(dc, true, 0, hc);
triton::dnn::gemm::cpu_ref<float>(AT, BT, rc, ha, hb, M, N, K);
gemm.cpu_ref<float>(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
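A note on units in the benchmark lambda: its return value is printed as TFLOPS, which works out only if triton::tools::bench reports the runtime in nanoseconds. That unit is an assumption here; the diff itself does not state it.

// 2*M*N*K floating-point operations, ts assumed to be in nanoseconds:
// flops/ts is FLOP per ns, numerically equal to GFLOP/s;
// multiplying by 1e-3 therefore yields TFLOP/s.
double tflops = 2.*M*N*K / ts * 1e-3;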
@@ -29,7 +29,7 @@ int main() {
shift_w[c] = rand() % S - S/2;
}
// configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h.data(), shift_w.data(), numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
// host buffers
std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size());
@@ -53,7 +53,7 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
shift.enqueue(stream, da, db, dc);
shift.enqueue(stream, {da, db, dc});
// stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++)

@@ -72,7 +72,7 @@ torch::Tensor shift_common(
if(m_shift_jit.find(key) == m_shift_jit.end()){
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
configuration->get_src(oss);
configuration->triton_c_src(oss);
std::string src = oss.str();
// benchmark a given shift-convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
@@ -85,7 +85,7 @@ torch::Tensor shift_common(
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3;
return configuration->num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
@@ -59,15 +59,7 @@ public:
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
}

private:
@@ -135,15 +127,7 @@ public:
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
}

private:
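Both batchnorm ops collapse from a compile-query-launch sequence into a single call; the base class introduced later in this commit owns the JIT and its cache. The new call pattern, taken directly from the hunks above:

// Old: build source, add_module, get_function, get_int("TM"), then enqueue.
// New: construct the configuration and enqueue; compilation is cached inside.
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});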
@@ -50,56 +50,28 @@ public:
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false;

// get conv configuration
triton::dnn::conv configuration(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);

// Bind memory
// wrap buffers
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::buffer* bias = nullptr;

// template
triton::dnn::conv conv(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// allocate output
auto c_shapes = configuration.c_shapes();
auto c_shapes = conv.c_shapes();
Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);

// benchmark a given convolution kernel
triton::jit jit(ctx);
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration.get_nflops() / ts * 1e-3;
};

std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();

triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << benchmark(kernel, info) << std::endl;
// enqueue
conv.enqueue(stream, {&a, &b, &c, bias});
}
};
@@ -33,7 +33,6 @@ class DotOp : public OpKernel {
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& locks = context->input(2);
// get shapes
const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0);
@@ -45,40 +44,13 @@ class DotOp : public OpKernel {
// return early if possible
if (out_shape.num_elements() == 0)
return;
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, &da, &db, &dc, M, N, K, &dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, ctx->device());
return 2.*M*N*K / ts * 1e-3;
};
std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
// just-in-time compile source-code
jit.autotune("matmul", src.c_str(), benchmark);
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << benchmark(kernel, info) << std::endl;
// template
triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
dot.enqueue(stream, {&da, &db, &dc});
}

private:
@@ -88,6 +60,5 @@ REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), DotOp);
REGISTER_OP("Dot")
.Input("a: float16")
.Input("b: float16")
.Input("locks: int32")
.Output("c: float32")
;
@@ -127,5 +127,5 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))

run_shift()
#run_batchnorm()
#run_shift()
run_batchnorm()
@@ -19,15 +19,6 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;

typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t, int32_t,
int32_t*, int32_t*,
triton::dnn::shift::type, bool> shift_key_t;

static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;

template<triton::dnn::shift::type OP>
class ShiftConvOp : public OpKernel {
public:
@@ -102,23 +93,15 @@ public:
// shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
// create configuration
triton::dnn::shift* shift;
if(m_config.find(key) == m_config.end())
shift = m_config.emplace(key, new triton::dnn::shift(
B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_,
shift_h, shift_w,
"fp32", "fp32", OP, has_bias))
.first->second.get();
else
shift = m_config.at(key).get();
triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_,
shift_h_data, shift_w_data,
"fp32", "fp32", OP, has_bias);

// shapes for c
std::vector<int64> c_shapes;
for(int32_t x: shift->c_shapes())
for(int32_t x: shift.c_shapes())
c_shapes.push_back(x);
TensorShape out_shapes(c_shapes);
Tensor* tf_c = nullptr;
@@ -130,7 +113,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
shift->enqueue(stream, {&da, &db, &dc});
shift.enqueue(stream, {&da, &db, &dc});
}

private:
@@ -29,7 +29,11 @@
namespace triton{
namespace dnn{

class base {
friend class cmp_recompile;

protected:
// leading dimensions
static void set_ld(const std::vector<int32_t>& shapes,
@@ -37,21 +41,24 @@ protected:

private:
// initialize
virtual void init_impl(driver::stream *stream, driver::cu_module *module) = 0;
virtual void init_impl(driver::stream *, driver::cu_module *){ }
// enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads) = 0;
const std::vector<unsigned>& ranges,
size_t nthreads) = 0;
// number of flops
virtual size_t num_flops() const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;

public:
// constructor
base(const std::string& name);
// number of flops
virtual size_t get_nflops() const = 0;
// triton-c source
virtual void get_src(std::ostream &os) const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
virtual void triton_c_src(std::ostream &os) const = 0;
// clone
virtual base* clone() const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);

@@ -59,6 +66,12 @@ private:
std::string name_;
};

struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
}
};

}
}
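cmp_recompile lets a std::map of base pointers deduplicate templates by configuration rather than by address, so one JIT compilation is shared by all identical configurations. A minimal sketch of the idea, with the jit type stubbed out for illustration:

#include <map>
#include <memory>

struct jit_stub { /* compiled module would live here */ };

std::map<triton::dnn::base*, std::unique_ptr<jit_stub>,
         triton::dnn::cmp_recompile> cache;

jit_stub* get_or_compile(triton::dnn::base* cfg) {
  auto it = cache.find(cfg);   // compares *x < *y via the virtual operator<
  if(it != cache.end())
    return it->second.get();   // same configuration: reuse the compilation
  // clone() keeps the key alive independently of the caller's object
  return cache.emplace(cfg->clone(), new jit_stub()).first->second.get();
}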
@@ -28,23 +28,32 @@
#include <algorithm>
#include <numeric>
#include <cmath>
#include "triton/dnn/base.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"

namespace triton{
namespace dnn{

class batchnorm_forward {
class batchnorm_forward: public base {
private:
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned> &ranges, size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;

public:
// constructor
batchnorm_forward(int C, int D, int H, int W, int B, std::string ty = "fp32");
// enqueue
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *y, driver::buffer *m, driver::buffer *v,
driver::buffer *x, driver::buffer *g, driver::buffer *b,
size_t TM, size_t nthreads);
// triton-c source code
void src(std::ostream &os);
batchnorm_forward(int C, int D, int H, int W, int B,
std::string ty = "fp32", float eps = 1e-5);
// triton-c source
void triton_c_src(std::ostream &os) const;

private:
int32_t C_;
@@ -58,18 +67,25 @@ private:
float rcpDHWB_;
};

class batchnorm_backward {
class batchnorm_backward: public base{
private:
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned> &ranges, size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;

public:
// constructor
batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32", float eps = 1e-5);
// enqueue
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
size_t TM, size_t nthreads);
// triton-c source code
void src(std::ostream &os);

batchnorm_backward(int C, int D, int H, int W, int B,
std::string ty = "fp32", float eps = 1e-5);
// triton-c source
void triton_c_src(std::ostream &os) const;

private:
int32_t C_;
@@ -4,11 +4,12 @@
#include <numeric>
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/base.h"

namespace triton{
namespace dnn{

class conv {
class conv: public base{
public:
enum type {
FPROP,
@@ -17,11 +18,29 @@ public:
};

private:
void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld);

// initialize
std::tuple<int32_t, int32_t, int32_t, int32_t>
unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW);
void build_b_deltas();
void build_a_deltas();
void build_masks();
void init_impl(driver::stream *, driver::cu_module *);

// enqueue
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;

public:

@@ -39,35 +58,17 @@ public:
size_t b_size();
size_t c_size();
std::vector<int32_t> c_shapes();

// initialize
void build_b_deltas();
void build_deltas();
void build_masks();
void init(driver::stream *stream, driver::cu_module *module);
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias);
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias,
size_t TM, size_t TN, size_t GZ, size_t nthreads);

// utilities
size_t get_nflops();
// default params
std::vector<unsigned> default_params();

// source
void src(std::ostream &os);
// triton-c source code
void triton_c_src(std::ostream &os) const;

// cpu check
// cpu reference implementations
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);

template<class IN_DTYPE, class OUT_DTYPE>
void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);

template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
@@ -1,27 +1,44 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
#include <string>

namespace triton{
namespace dnn{

class gemm {
class gemm: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;

public:
static void init(driver::stream* stream, driver::buffer* locks);
gemm(int M, int N, int K, bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);

static void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
int32_t M, int32_t N, int32_t K,
driver::buffer *locks, int32_t grid_0, int32_t grid_1);
// triton-c source
void triton_c_src(std::ostream &os) const;

static std::vector<unsigned> default_params(bool AT, bool BT);
// clone
base* clone() const;

static std::string src(bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);
// default params
std::vector<unsigned> default_params();

// CPU reference implementation
template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
T acc = 0;
@@ -30,18 +47,29 @@ public:
c[m + n*M] = acc;
}
}

template<class T>
static void cpu_ref(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
if(AT && BT)
gemm::cpu_ref<T, true, true>(c, a, b, M, N, K);
else if(AT && !BT)
gemm::cpu_ref<T, true, false>(c, a, b, M, N, K);
else if(!AT && BT)
gemm::cpu_ref<T, false, true>(c, a, b, M, N, K);
void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
gemm::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
else if(AT_ && !BT_)
gemm::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
else if(!AT_ && BT_)
gemm::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
else
gemm::cpu_ref<T, false, false>(c, a, b, M, N, K);
gemm::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
}

private:
int32_t M_;
int32_t N_;
int32_t K_;
bool AT_;
bool BT_;
std::string a_ty_;
std::string b_ty_;
unsigned align_lda_;
unsigned align_ldb_;
driver::buffer *locks_;
};

}
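The reference check becomes a member that reads the dimensions and transpose flags stored at construction, so call sites can no longer pass mismatched M/N/K. Usage as it appears in the updated matmul example above (buffer setup elided):

triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
gemm.enqueue(stream, {da, db, dc});   // device result written to dc
stream->read(dc, true, 0, hc);
gemm.cpu_ref<float>(rc, ha, hb);      // dispatches on the stored AT_/BT_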
@@ -46,14 +46,12 @@ public:
};

private:
// leading dimensions
void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld);
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads);
const std::vector<unsigned>& ranges,
size_t nthreads);

public:

@@ -61,7 +59,7 @@ public:
int D, int H, int W,
int T, int R, int S, int NF,
int stride_h, int stride_w,
const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w,
const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false);

@@ -74,11 +72,13 @@ public:
size_t c_size();
std::vector<int32_t> c_shapes();
// number of flops
size_t get_nflops() const;
size_t num_flops() const;
// source
void get_src(std::ostream &os) const;
void triton_c_src(std::ostream &os) const;
// comparison
bool operator<(const base& other) const;
// clone
base* clone() const;
// cpu reference
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* O,
@@ -143,8 +143,8 @@ private:
std::vector<int32_t> ld_b_;
std::vector<int32_t> ld_c_;
// shift values
std::vector<int32_t> shift_h_;
std::vector<int32_t> shift_w_;
const int32_t* shift_h_;
const int32_t* shift_w_;
// look-up tables
std::vector<int32_t> h_deltas_;
std::vector<int32_t> h_masks_;
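Note the shift offsets change from owned std::vectors to raw const int32_t* views: construction no longer copies C offsets per call, but the pointed-to storage must now outlive the shift object (and any clone() kept as a cache key). A sketch of the resulting contract, with caller-owned storage:

std::vector<int32_t> shift_h(C), shift_w(C);   // owned by the caller
// ... fill in per-channel offsets ...
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1,
                         shift_h.data(), shift_w.data(),  // borrowed views
                         "fp32", "fp32", triton::dnn::shift::BPROP);
// Assumption based on the raw-pointer members above: freeing shift_h or
// shift_w while `shift` (or a cached clone) is alive leaves dangling pointers.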
@@ -6,38 +6,43 @@
namespace triton{
namespace dnn{

struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
}
};

void base::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[size - 1] = 1;
for(int i = size - 1; i >= 1; i--)
ld[i - 1] = shapes[i] * ld[i];
}

base::base(const std::string& name)
: name_(name) { }

void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
bool autotune = false;
driver::context* ctx = stream->context();
triton::jit* jit;
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
jit = m_jit.emplace(this, new triton::jit(ctx)).first->second.get();
jit = m_jit.emplace(this->clone(), new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
get_src(oss);
triton_c_src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
init_impl(stream, (triton::driver::cu_module*)kernel->module());
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
enqueue_impl(stream, kernel, args, info.global_range_size, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); },
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info.global_range_size, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return get_nflops() / ts * 1e-3;
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
@@ -57,12 +62,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
/* get launch parameters */
driver::kernel* kernel = jit->get_function(name_.c_str());
triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;

/* launch */
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
enqueue_impl(stream, kernel, args,
info.global_range_size, info.num_threads);
}

}
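base::set_ld computes dense row-major strides from a shape vector: innermost stride 1, each outer stride the product of the inner extents. A standalone restatement with a worked example (the real function is a protected static member):

#include <cstdint>
#include <vector>

void set_ld(const std::vector<int32_t>& shapes, std::vector<int32_t>& ld) {
  size_t size = shapes.size();
  ld.resize(size);
  ld[size - 1] = 1;                  // innermost dimension is contiguous
  for(int i = size - 1; i >= 1; i--)
    ld[i - 1] = shapes[i] * ld[i];   // stride = product of inner extents
}
// For shapes {16, 64, 32, 32} (e.g. N, C, H, W) this yields
// ld == {65536, 1024, 32, 1}.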
@@ -29,17 +29,36 @@ namespace dnn{
* Forward
* --------------- */

batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) {
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
: base("batchnorm"),
C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
DHWB_ = D_*H_*W_*B_;
rcpDHWB_ = (float)1 / DHWB_;
}

void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *y, driver::buffer *m, driver::buffer *v,
driver::buffer *x, driver::buffer *g, driver::buffer *b,
size_t, size_t nthreads) {
size_t batchnorm_forward::num_flops() const {
return C_*DHWB_;
}

bool batchnorm_forward::operator <(const base& other) const {
auto *y = dynamic_cast<const batchnorm_forward*>(&other);
if(!y)
return true;
return std::tie(C_, D_, H_, W_, B_, ty_)
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}

base* batchnorm_forward::clone() const {
return new batchnorm_forward(*this);
}

void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>&,
size_t nthreads)
{
driver::buffer *y = args[0], *m = args[1], *v = args[2];
driver::buffer *x = args[3], *g = args[4], *b = args[5];
std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
kernel->setArg(0, y);
kernel->setArg(1, m);
@@ -53,7 +72,7 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}

void batchnorm_forward::src(std::ostream &os) {
void batchnorm_forward::triton_c_src(std::ostream &os) const {
os <<
R"(
const tunable int32 TM = {32, 64, 128};
@@ -113,14 +132,32 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
* --------------- */

batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
: base("batchnorm"),
C_(C), D_(D), H_(H), W_(W), B_(B),
ty_(ty), eps_(eps)
{ }

void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
size_t, size_t nthreads) {
size_t batchnorm_backward::num_flops() const {
return C_*D_*H_*W_*B_;
}

bool batchnorm_backward::operator <(const base& other) const {
auto *y = dynamic_cast<const batchnorm_backward*>(&other);
if(!y)
return true;
return std::tie(C_, D_, H_, W_, B_, ty_)
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}

base* batchnorm_backward::clone() const {
return new batchnorm_backward(*this);
}

void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args,
const std::vector<unsigned> &, size_t nthreads) {
driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
kernel->setArg(0, dx);
kernel->setArg(1, dg);
@@ -136,7 +173,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}

void batchnorm_backward::src(std::ostream &os) {
void batchnorm_backward::triton_c_src(std::ostream &os) const {
os <<
R"(
const tunable int32 TM = {32, 64, 128};
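One caveat about these operator< implementations (and the analogous ones elsewhere in this commit): when the dynamic types differ, both orderings return true, so x < y and y < x can hold simultaneously. That breaks the strict weak ordering std::map requires, and the cache in base::enqueue does mix subclasses in one map. A type-aware variant, shown here as a hypothetical hardening rather than what the commit does:

#include <typeindex>

bool batchnorm_forward::operator<(const base& other) const {
  auto *y = dynamic_cast<const batchnorm_forward*>(&other);
  if(!y)  // different subclass: order consistently by dynamic type instead
    return std::type_index(typeid(*this)) < std::type_index(typeid(other));
  return std::tie(C_, D_, H_, W_, B_, ty_)
       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}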
@@ -4,17 +4,6 @@
namespace triton{
namespace dnn{

void conv::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[4] = 1;
ld[3] = shapes[4]*ld[4];
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
}

conv::conv(int B, int NC,
int D, int H, int W,
int T, int R, int S, int NF,
@@ -23,7 +12,8 @@ conv::conv(int B, int NC,
int upsample_d, int upsample_h, int upsample_w,
std::string a_ty, std::string b_ty,
type ty, bool bias)
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
: base("conv"),
NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
@@ -93,7 +83,7 @@ conv::conv(int B, int NC,
Fs_ = K_;
TK_ = 8;
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
build_deltas();
build_a_deltas();
if(b_lut_)
build_b_deltas();
build_masks();
@@ -107,6 +97,28 @@ conv::conv(int B, int NC,
max_grid_1_ = 256;
}

// comparison for maps
bool conv::operator<(const base& other) const {
auto *y = dynamic_cast<const conv*>(&other);
if(!y)
return true;
return std::tie(NB_, NC_, AD_, AH_, AW_,
NF_, BD_, BH_, BW_,
pad_d_, pad_h_, pad_w_,
stride_d_, stride_h_, stride_w_,
a_ty_, b_ty_, ty_, bias_)
< std::tie(y->NB_, y->NC_, y->AD_, y->AH_, y->AW_,
y->NF_, y->BD_, y->BH_, y->BW_,
y->pad_d_, y->pad_h_, y->pad_w_,
y->stride_d_, y->stride_h_, y->stride_w_,
y->a_ty_, y->b_ty_, y->ty_, y->bias_);
}

// clone
base* conv::clone() const {
return new conv(*this);
}

size_t conv::a_size()
{ return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
1, std::multiplies<int>()); }
@@ -176,7 +188,7 @@ void conv::build_b_deltas(){
}
}

void conv::build_deltas(){
void conv::build_a_deltas(){
h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
for(size_t i = 0; i < Luts_; ++i)
h_a_deltas_[i] = (((i + TK_) % Luts_) - i);
@@ -258,13 +270,15 @@ void conv::build_masks(){
h_masks_[i] = 0x0;
}

std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN)
{ return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; }
std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN){
return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
}

size_t conv::get_nflops()
{ return 2.*M_*N_*K_; }
size_t conv::num_flops() const{
return 2.*M_*N_*K_;
}

void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module) {
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
if(host.empty())
return nullptr;
@@ -349,9 +363,13 @@ void conv::set_arg(driver::kernel *kernel,
kernel->setArg(idx++, d_masks_);
}

void conv::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias,
size_t TM, size_t TN, size_t GZ, size_t nthreads) {
void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads) {
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
unsigned TM = ranges[0], TN = ranges[1];
unsigned GZ = 1;
set_arg(kernel, a, b, c, bias);
std::array<size_t, 3> grid = {1};
grid[0] = (M_ + TM - 1)/TM;
@@ -411,6 +429,8 @@ std::vector<unsigned> conv::default_params() {
}

/* CPU reference implementation */

template<class IN_DTYPE, class OUT_DTYPE>
void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{
@@ -496,7 +516,9 @@ void conv::cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
cpu_wgrad(C, A, B);
}

void conv::src(std::ostream &os){
/* Triton-C source code */

void conv::triton_c_src(std::ostream &os) const {
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
@@ -526,7 +548,7 @@ void conv::src(std::ostream &os){
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {16};
const tunable int32 TK = {)" << TK_ << R"(};
const tunable int32 GZ = {1};
)";
if(is_a_deltas_cst)
lib/dnn/gemm.cpp
@@ -6,64 +6,103 @@
namespace triton{
namespace dnn{

gemm::gemm(int M, int N, int K,
bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb)
: base("matmul"),
M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
a_ty_(a_ty), b_ty_(b_ty),
align_lda_(alignment_lda), align_ldb_(alignment_ldb),
locks_(nullptr) {

void gemm::init(driver::stream* stream, driver::buffer* locks) {
std::vector<int32_t> hlocks(2048, 0);
stream->write(locks, false, 0, hlocks);
}

void gemm::set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
int32_t M, int32_t N, int32_t K,
driver::buffer *locks, int32_t grid_0, int32_t grid_1) {
size_t gemm::num_flops() const {
return 2.*M_*N_*K_;
}

// comparison for maps
bool gemm::operator<(const base& other) const {
auto *y = dynamic_cast<const gemm*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_, AT_, BT_,
a_ty_, b_ty_, align_lda_, align_ldb_)
< std::tie(y->M_, y->N_, y->K_, y->AT_, y->BT_,
y->a_ty_, y->b_ty_, y->align_lda_, y->align_ldb_);
}

// clone
base* gemm::clone() const {
return new gemm(*this);
}

void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
std::vector<int32_t> hlocks(2048, 0);
if(locks_ == nullptr)
locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);
stream->write(locks_, false, 0, hlocks);
}

void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads) {
driver::buffer *a = args[0], *b = args[1], *c = args[2];
unsigned TM = ranges[0];
unsigned TN = ranges[1];
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, locks);
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
kernel->setArg(6, M_);
kernel->setArg(7, N_);
kernel->setArg(8, M_);
kernel->setArg(9, locks_);
kernel->setArg(10, grid_0);
kernel->setArg(11, grid_1);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}

std::vector<unsigned> gemm::default_params(bool AT, bool BT) {
if(AT && BT)
std::vector<unsigned> gemm::default_params() {
if(AT_ && BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(AT && !BT)
else if(AT_ && !BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(!AT && BT)
else if(!AT_ && BT_)
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
else
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
}

std::string gemm::src(bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned align_lda, unsigned align_ldb) {
void gemm::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
if(AT){
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT){
if(BT_){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -71,8 +110,8 @@ const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};

void matmul(restrict read_only )" + a_ty + R"( *A,
restrict read_only )" + b_ty + R"( *B,
void matmul(restrict read_only )" + a_ty_ + R"( *A,
restrict read_only )" + b_ty_ + R"( *B,
fp32 *C,
int32 M, int32 N, int32 K,
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"( int32 ldb, int32 ldc,
@@ -87,10 +126,10 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
)" + a_ty + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
)" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
@@ -108,10 +147,10 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
)" + a_ty + R"(* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
)" + b_ty + R"(* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
)" + a_ty + R"( a[TM, 1] = checka ? *pa : 0;
)" + b_ty + R"( b[TN, 1] = checkb ? *pb : 0;
)" + a_ty_ + R"(* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
)" + b_ty_ + R"(* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
)" + a_ty_ + R"( a[TM, 1] = checka ? *pa : 0;
)" + b_ty_ + R"( b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
@@ -136,7 +175,7 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
__atomic_cas(plock, 1, 0);
}
)";
return res;
os << res;
}

}
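With init and enqueue folded into the base interface, a host-side matrix multiplication reduces to construct-and-enqueue; the lock buffer is now allocated lazily inside init_impl instead of being created and passed in by the caller. Following the updated DotOp above (a sketch; buffer wrapping elided):

triton::dnn::gemm dot(M, N, K, /*AT=*/false, /*BT=*/true,
                      "fp16", "fp16", /*align_lda=*/4, /*align_ldb=*/4);
dot.enqueue(stream, {&da, &db, &dc});  // first call JIT-compiles and caches
dot.enqueue(stream, {&da, &db, &dc});  // later calls reuse the cached module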
@@ -5,21 +5,13 @@
namespace triton{
namespace dnn{

void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[size - 1] = 1;
for(int i = size - 1; i >= 1; i--)
ld[i - 1] = shapes[i] * ld[i];
}

shift::shift(int B, int C,
int D, int H, int W,
int T, int R, int S,
int F,
int stride_h, int stride_w,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
const int32_t *shift_h, const int32_t *shift_w,
std::string a_ty, std::string b_ty,
type ty, bool bias)
: base("shift"),
@@ -73,8 +65,10 @@ shift::shift(int B, int C,
set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_);
// build LUTs
build_deltas();
}

base* shift::clone() const {
return new shift(*this);
}

void shift::build_deltas() {
@@ -120,30 +114,29 @@ std::vector<int32_t> shift::c_shapes(){
return shapes_c_;
}

size_t shift::get_nflops() const {
size_t shift::num_flops() const {
return 2.*M_*N_*K_;
}

bool shift::operator <(const base& other) const{
auto *y = dynamic_cast<const shift*>(&other);
if(!y)
return false;
const int32_t *x_shift_h = shift_h_.data(), *x_shift_w = shift_w_.data();
const int32_t *y_shift_h = y->shift_h_.data(), *y_shift_w = y->shift_w_.data();
return true;
return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
x_shift_h, x_shift_w, ty_, bias_)
shift_h_, shift_w_, ty_, bias_)
< std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
y_shift_h, y_shift_w, y->ty_, y->bias_);
y->shift_h_, y->shift_w_, y->ty_, y->bias_);
}

void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
build_deltas();
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
}

void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args,
size_t TM, size_t TN, size_t nthreads) {
const std::vector<unsigned> &ranges, size_t nthreads) {
int32_t lda = AT_ ? K_ : M_;
int32_t ldb = BT_ ? N_ : K_;
driver::buffer *a = args[0], *b = args[1], *c = args[2];
@@ -162,13 +155,14 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(12, AW_);
kernel->setArg(13, BH_);
kernel->setArg(14, BW_);
unsigned TM = ranges[0], TN = ranges[1];
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
if(ty_ == BPROP)
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}

void shift::get_src(std::ostream &os) const {
void shift::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
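A detail worth flagging in shift::enqueue_impl above: for BPROP the output buffer is zeroed before launch. The natural reading, stated here as an assumption rather than something the diff spells out, is that the backward kernel accumulates partial results into C, so it must start from zero:

// M_*N_ fp32 outputs, 4 bytes each; cleared because (assumption) the
// BPROP kernel adds into C instead of overwriting it.
if(ty_ == BPROP)
  ((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);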