[dnn]: Now implementing all existing DNN routines using a common base template and auto-tuner

This commit is contained in:
Philippe Tillet
2019-07-09 17:30:58 -07:00
parent 88675fa01a
commit b7986baffa
19 changed files with 388 additions and 375 deletions

View File

@@ -10,7 +10,6 @@
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 16, NF = 128;
@@ -19,8 +18,12 @@ int main() {
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
// triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, upsample_d, upsample_h, upsample_w, ty);
// triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
upsample_d, upsample_h, upsample_w,
"fp32", "fp32", ty, 0);
// convolution configuration
std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size());
@@ -42,29 +45,7 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, da, db, dc, nullptr, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, nullptr);
return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
configuration.enqueue(stream, {da, db, dc, nullptr});
stream->read(dc, true, 0, hc);
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
for(size_t i = 0; i < hc.size(); i++){

View File

@@ -10,18 +10,14 @@
int main() {
bool AT = true;
bool BT = true;
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 1024, N = 1024, K = 1024;
int32_t M = 128, N = 128, K = 128;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
std::vector<int32_t> hlocks(2048);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
@@ -32,41 +28,15 @@ int main() {
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::buffer* dlocks = triton::driver::buffer::create(context, hlocks.size()*4);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
triton::dnn::gemm::init(stream, dlocks);
stream->synchronize();
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, da, db, dc, M, N, K, dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, context->device());
return 2.*M*N*K / ts * 1e-3;
};
// just-in-time compile source-code
std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 4, 4);
// jit.autotune("matmul",src.c_str(), benchmark);
jit.add_module("matmul", src.c_str(), {8, 16, 4, 2, 16, 8, 4, 2, 2, 4, 2, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
gemm.enqueue(stream, {da, db, dc});
stream->read(dc, true, 0, hc);
triton::dnn::gemm::cpu_ref<float>(AT, BT, rc, ha, hb, M, N, K);
gemm.cpu_ref<float>(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;

View File

@@ -29,7 +29,7 @@ int main() {
shift_w[c] = rand() % S - S/2;
}
// configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h.data(), shift_w.data(), numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
// host buffers
std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size());
@@ -53,7 +53,7 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
shift.enqueue(stream, da, db, dc);
shift.enqueue(stream, {da, db, dc});
// stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++)

View File

@@ -72,7 +72,7 @@ torch::Tensor shift_common(
if(m_shift_jit.find(key) == m_shift_jit.end()){
jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
configuration->get_src(oss);
configuration->triton_c_src(oss);
std::string src = oss.str();
// benchmark a given shiftolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
@@ -85,7 +85,7 @@ torch::Tensor shift_common(
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3;
return configuration->num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {

View File

@@ -59,15 +59,7 @@ public:
triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &ty, &tm, &tv, &tx, &tg, &tb, TM, info.num_threads);
batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
}
private:
@@ -135,15 +127,7 @@ public:
triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
std::ostringstream oss;
batchnorm.src(oss);
std::string src = oss.str();
triton::jit jit(ctx);
jit.add_module("batchnorm", src.c_str(), jit.get_valid("batchnorm", src.c_str()));
triton::driver::kernel* kernel = jit.get_function("batchnorm");
size_t TM = jit.get_int("TM");
triton::jit::launch_information info = jit.get_launch_info("batchnorm");
batchnorm.enqueue(stream, kernel, &tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv, TM, info.num_threads);
batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
}
private:

View File

@@ -50,56 +50,28 @@ public:
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
bool has_bias = false;
// get conv configuration
triton::dnn::conv configuration(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// Bind memory
// wrap buffers
triton::driver::cu_buffer a(ctx, (CUdeviceptr)tfa.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)tfb.flat<Eigen::half>().data(), false);
triton::driver::buffer* bias = nullptr;
// template
triton::dnn::conv conv(B, C,
D, H, W,
T, R, S,
NF,
stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w,
1, 1, 1,
"fp16", "fp16",
triton::dnn::conv::FPROP, has_bias);
// allocate output
auto c_shapes = configuration.c_shapes();
auto c_shapes = conv.c_shapes();
Tensor* tfc = nullptr;
TensorShape out_shape({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &tfc));
triton::driver::cu_buffer c(ctx, (CUdeviceptr)tfc->flat<float>().data(), false);
// benchmark a given convolution kernel
triton::jit jit(ctx);
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration.init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration.enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration.get_nflops() / ts * 1e-3;
};
std::ostringstream oss;
configuration.src(oss);
std::string src = oss.str();
triton::jit::tune_res_t best = jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), best.params);
// jit.add_module("conv", src.c_str(), {16, 2, 32, 32, 2, 64, 2, 2, 2, 2, 8, 2, 16, 4, 1});
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
std::cout << benchmark(kernel, info) << std::endl;
// enqueue
conv.enqueue(stream, {&a, &b, &c, bias});
}
};

View File

@@ -33,7 +33,6 @@ class DotOp : public OpKernel {
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& locks = context->input(2);
// get shapes
const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0);
@@ -45,40 +44,13 @@ class DotOp : public OpKernel {
// return early if possible
if (out_shape.num_elements() == 0)
return;
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
triton::dnn::gemm::set_arg(kernel, &da, &db, &dc, M, N, K, &dlocks, grid[0], grid[1]);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream->synchronize(); }, ctx->device());
return 2.*M*N*K / ts * 1e-3;
};
std::string src = triton::dnn::gemm::src(false, true, "fp16", "fp16", 1, 1);
// just-in-time compile source-code
jit.autotune("matmul", src.c_str(), benchmark);
// jit.add_module("matmul", src.c_str(), {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
// jit.add_module("matmul", src.c_str(), {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
// jit.add_module("matmul", src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 2, 2, 2, 2, 8, 8, 4, 2, 1}); //NN
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
std::cout << benchmark(kernel, info) << std::endl;
// template
triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
dot.enqueue(stream, {&da, &db, &dc});
}
private:
@@ -88,6 +60,5 @@ REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), DotOp);
REGISTER_OP("Dot")
.Input("a: float16")
.Input("b: float16")
.Input("locks: int32")
.Output("c: float32")
;

View File

@@ -127,5 +127,5 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))
run_shift()
#run_batchnorm()
#run_shift()
run_batchnorm()

View File

@@ -19,15 +19,6 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
int32_t, int32_t, int32_t, int32_t,
int32_t*, int32_t*,
triton::dnn::shift::type, bool> shift_key_t;
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
template<triton::dnn::shift::type OP>
class ShiftConvOp : public OpKernel {
public:
@@ -102,23 +93,15 @@ public:
// shift offsets
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
// create configuration
triton::dnn::shift* shift;
if(m_config.find(key) == m_config.end())
shift = m_config.emplace(key, new triton::dnn::shift(
B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_,
shift_h, shift_w,
"fp32", "fp32", OP, has_bias))
.first->second.get();
else
shift = m_config.at(key).get();
triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_,
shift_h_data, shift_w_data,
"fp32", "fp32", OP, has_bias);
// shapes for c
std::vector<int64> c_shapes;
for(int32_t x: shift->c_shapes())
for(int32_t x: shift.c_shapes())
c_shapes.push_back(x);
TensorShape out_shapes(c_shapes);
Tensor* tf_c = nullptr;
@@ -130,7 +113,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
shift->enqueue(stream, {&da, &db, &dc});
shift.enqueue(stream, {&da, &db, &dc});
}
private:

View File

@@ -29,7 +29,11 @@
namespace triton{
namespace dnn{
class base {
friend class cmp_recompile;
protected:
// leading dimensions
static void set_ld(const std::vector<int32_t>& shapes,
@@ -37,21 +41,24 @@ protected:
private:
// initialize
virtual void init_impl(driver::stream *stream, driver::cu_module *module) = 0;
virtual void init_impl(driver::stream *, driver::cu_module *){ }
// enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads) = 0;
const std::vector<unsigned>& ranges,
size_t nthreads) = 0;
// number of flops
virtual size_t num_flops() const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
public:
// constructor
base(const std::string& name);
// number of flops
virtual size_t get_nflops() const = 0;
// triton-c source
virtual void get_src(std::ostream &os) const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
virtual void triton_c_src(std::ostream &os) const = 0;
// clone
virtual base* clone() const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args);
@@ -59,6 +66,12 @@ private:
std::string name_;
};
struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
}
};
}
}

View File

@@ -28,23 +28,32 @@
#include <algorithm>
#include <numeric>
#include <cmath>
#include <typeinfo>
#include "triton/dnn/base.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
namespace triton{
namespace dnn{
class batchnorm_forward {
class batchnorm_forward: public base {
private:
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned> &ranges, size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;
public:
// constructor
batchnorm_forward(int C, int D, int H, int W, int B, std::string ty = "fp32");
// enqueue
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *y, driver::buffer *m, driver::buffer *v,
driver::buffer *x, driver::buffer *g, driver::buffer *b,
size_t TM, size_t nthreads);
// triton-c source code
void src(std::ostream &os);
batchnorm_forward(int C, int D, int H, int W, int B,
std::string ty = "fp32", float eps = 1e-5);
// triton-c source
void triton_c_src(std::ostream &os) const;
private:
int32_t C_;
@@ -58,18 +67,25 @@ private:
float rcpDHWB_;
};
class batchnorm_backward {
class batchnorm_backward: public base{
private:
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned> &ranges, size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;
public:
// constructor
batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32", float eps = 1e-5);
// enqueue
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
size_t TM, size_t nthreads);
// triton-c source code
void src(std::ostream &os);
batchnorm_backward(int C, int D, int H, int W, int B,
std::string ty = "fp32", float eps = 1e-5);
// triton-c source
void triton_c_src(std::ostream &os) const;
private:
int32_t C_;

View File

@@ -4,11 +4,12 @@
#include <numeric>
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
namespace triton{
namespace dnn{
class conv {
class conv: public base{
public:
enum type {
FPROP,
@@ -17,11 +18,29 @@ public:
};
private:
void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld);
// initialize
std::tuple<int32_t, int32_t, int32_t, int32_t>
unpack(int32_t ltrs, bool flip, int32_t EBD, int32_t EBH, int32_t EBW);
void build_b_deltas();
void build_a_deltas();
void build_masks();
void init_impl(driver::stream *, driver::cu_module *);
// enqueue
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// clone
base* clone() const;
public:
@@ -39,35 +58,17 @@ public:
size_t b_size();
size_t c_size();
std::vector<int32_t> c_shapes();
// initialize
void build_b_deltas();
void build_deltas();
void build_masks();
void init(driver::stream *stream, driver::cu_module *module);
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias);
void enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
driver::buffer *bias,
size_t TM, size_t TN, size_t GZ, size_t nthreads);
// utilities
size_t get_nflops();
// default params
std::vector<unsigned> default_params();
// source
void src(std::ostream &os);
// triton-c source code
void triton_c_src(std::ostream &os) const;
// cpu check
// cpu reference implementations
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_wgrad(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B);

View File

@@ -1,27 +1,44 @@
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/dnn/base.h"
#include <string>
namespace triton{
namespace dnn{
class gemm {
class gemm: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads);
// number of flops
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
public:
static void init(driver::stream* stream, driver::buffer* locks);
gemm(int M, int N, int K, bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);
static void set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
int32_t M, int32_t N, int32_t K,
driver::buffer *locks, int32_t grid_0, int32_t grid_1);
// triton-c source
void triton_c_src(std::ostream &os) const;
static std::vector<unsigned> default_params(bool AT, bool BT);
// clone
base* clone() const;
static std::string src(bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);
// default params
std::vector<unsigned> default_params();
// CPU reference implementation
template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
T acc = 0;
@@ -30,18 +47,29 @@ public:
c[m + n*M] = acc;
}
}
template<class T>
static void cpu_ref(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
if(AT && BT)
gemm::cpu_ref<T, true, true>(c, a, b, M, N, K);
else if(AT && !BT)
gemm::cpu_ref<T, true, false>(c, a, b, M, N, K);
else if(!AT && BT)
gemm::cpu_ref<T, false, true>(c, a, b, M, N, K);
void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
gemm::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
else if(AT_ && !BT_)
gemm::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
else if(!AT_ && BT_)
gemm::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
else
gemm::cpu_ref<T, false, false>(c, a, b, M, N, K);
gemm::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
}
private:
int32_t M_;
int32_t N_;
int32_t K_;
bool AT_;
bool BT_;
std::string a_ty_;
std::string b_ty_;
unsigned align_lda_;
unsigned align_ldb_;
driver::buffer *locks_;
};
}

View File

@@ -46,14 +46,12 @@ public:
};
private:
// leading dimensions
void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld);
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
size_t TM, size_t TN, size_t nthreads);
const std::vector<unsigned>& ranges,
size_t nthreads);
public:
@@ -61,7 +59,7 @@ public:
int D, int H, int W,
int T, int R, int S, int NF,
int stride_h, int stride_w,
const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w,
const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false);
@@ -74,11 +72,13 @@ public:
size_t c_size();
std::vector<int32_t> c_shapes();
// number of flops
size_t get_nflops() const;
size_t num_flops() const;
// source
void get_src(std::ostream &os) const;
void triton_c_src(std::ostream &os) const;
// comparison
bool operator<(const base& other) const;
// clone
base* clone() const;
// cpu reference
template<class IN_DTYPE, class OUT_DTYPE>
void cpu_ref(OUT_DTYPE* O,
@@ -143,8 +143,8 @@ private:
std::vector<int32_t> ld_b_;
std::vector<int32_t> ld_c_;
// shift values
std::vector<int32_t> shift_h_;
std::vector<int32_t> shift_w_;
const int32_t* shift_h_;
const int32_t* shift_w_;
// look-up tables
std::vector<int32_t> h_deltas_;
std::vector<int32_t> h_masks_;

View File

@@ -6,38 +6,43 @@
namespace triton{
namespace dnn{
struct cmp_recompile{
bool operator()(base* x, base* y) const{
return *x < *y;
}
};
// Fills `ld` with one leading dimension per entry of `shapes`:
// the last entry is 1 and every other entry is the product of all the
// extents that follow it, i.e. ld[i - 1] = shapes[i] * ld[i].
//   shapes: extents of each dimension
//   ld:     output, resized to shapes.size()
void base::set_ld(const std::vector<int32_t>& shapes,
                  std::vector<int32_t>& ld) {
  const size_t size = shapes.size();
  ld.resize(size);
  // an empty shape has no leading dimension; indexing ld[size - 1] here
  // would write past the end of an empty vector
  if(size == 0)
    return;
  ld[size - 1] = 1;
  // walk from the innermost dimension outwards
  for(size_t i = size - 1; i >= 1; i--)
    ld[i - 1] = shapes[i] * ld[i];
}
base::base(const std::string& name)
: name_(name) { }
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
static std::map<base*, std::unique_ptr<triton::jit>, cmp_recompile> m_jit;
bool autotune = false;
driver::context* ctx = stream->context();
triton::jit* jit;
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
jit = m_jit.emplace(this, new triton::jit(ctx)).first->second.get();
jit = m_jit.emplace(this->clone(), new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
get_src(oss);
triton_c_src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
init_impl(stream, (triton::driver::cu_module*)kernel->module());
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
enqueue_impl(stream, kernel, args, info.global_range_size, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, TM, TN, nthreads); },
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info.global_range_size, nthreads); },
[&](){ stream->synchronize(); }, ctx->device());
return get_nflops() / ts * 1e-3;
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
@@ -57,12 +62,9 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args) {
/* get launch parameters */
driver::kernel* kernel = jit->get_function(name_.c_str());
triton::jit::launch_information info = jit->get_launch_info(name_.c_str());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
/* launch */
enqueue_impl(stream, kernel, args, TM, TN, nthreads);
enqueue_impl(stream, kernel, args,
info.global_range_size, info.num_threads);
}
}

View File

@@ -29,17 +29,36 @@ namespace dnn{
* Forward
* --------------- */
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) {
// Forward batch-normalization template, registered under the kernel name
// "batchnorm".
//   C, D, H, W, B: tensor extents
//   ty:            element type string
//   eps:           epsilon value stored in eps_
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
  : base("batchnorm"),
    C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
{
  // product of every extent except C, and its reciprocal
  DHWB_ = D_*H_*W_*B_;
  rcpDHWB_ = 1.0f / DHWB_;
}
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *y, driver::buffer *m, driver::buffer *v,
driver::buffer *x, driver::buffer *g, driver::buffer *b,
size_t, size_t nthreads) {
// Work estimate used by the auto-tuner benchmark: C * (D*H*W*B).
// C_ is widened to size_t first so the product is not evaluated in 32-bit
// arithmetic, where it can overflow for large tensors.
size_t batchnorm_forward::num_flops() const {
  return static_cast<size_t>(C_) * DHWB_;
}
// Ordering used when templates are stored as map keys (see cmp_recompile).
// Two batchnorm_forward objects are ordered lexicographically on
// (C, D, H, W, B, ty).
// NOTE(review): eps_ is not part of the key -- confirm it does not affect the
// generated kernel, otherwise it should be added to both tuples.
bool batchnorm_forward::operator <(const base& other) const {
  auto *y = dynamic_cast<const batchnorm_forward*>(&other);
  // `other` is a different kind of template: order by dynamic type.
  // Returning a constant here would make both a < b and b < a hold, which is
  // not a strict weak ordering and corrupts ordered containers.
  if(!y)
    return typeid(*this).before(typeid(other));
  return std::tie(C_, D_, H_, W_, B_, ty_)
       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}
// Returns a heap-allocated copy of this template; the caller receives a raw
// pointer to it.
base* batchnorm_forward::clone() const {
  base* copy = new batchnorm_forward(*this);
  return copy;
}
void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>&,
size_t nthreads)
{
driver::buffer *y = args[0], *m = args[1], *v = args[2];
driver::buffer *x = args[3], *g = args[4], *b = args[5];
std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
kernel->setArg(0, y);
kernel->setArg(1, m);
@@ -53,7 +72,7 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
void batchnorm_forward::src(std::ostream &os) {
void batchnorm_forward::triton_c_src(std::ostream &os) const {
os <<
R"(
const tunable int32 TM = {32, 64, 128};
@@ -113,14 +132,32 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
* --------------- */
batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
: base("batchnorm"),
C_(C), D_(D), H_(H), W_(W), B_(B),
ty_(ty), eps_(eps)
{ }
void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
size_t, size_t nthreads) {
// Work estimate used by the auto-tuner benchmark: C * D * H * W * B.
// The first factor is widened to size_t so every multiplication is carried
// out in size_t rather than in 32-bit arithmetic, where the product can
// overflow for large tensors.
size_t batchnorm_backward::num_flops() const {
  return static_cast<size_t>(C_) * D_ * H_ * W_ * B_;
}
// Ordering used when templates are stored as map keys (see cmp_recompile).
// Two batchnorm_backward objects are ordered lexicographically on
// (C, D, H, W, B, ty).
// NOTE(review): eps_ is not part of the key -- confirm it does not affect the
// generated kernel, otherwise it should be added to both tuples.
bool batchnorm_backward::operator <(const base& other) const {
  auto *y = dynamic_cast<const batchnorm_backward*>(&other);
  // `other` is a different kind of template: order by dynamic type.
  // Returning a constant here would make both a < b and b < a hold, which is
  // not a strict weak ordering and corrupts ordered containers.
  if(!y)
    return typeid(*this).before(typeid(other));
  return std::tie(C_, D_, H_, W_, B_, ty_)
       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
}
// Returns a heap-allocated copy of this template; the caller receives a raw
// pointer to it.
base* batchnorm_backward::clone() const {
  base* copy = new batchnorm_backward(*this);
  return copy;
}
void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args,
const std::vector<unsigned> &, size_t nthreads) {
driver::buffer *dx = args[0], *dg = args[1], *db = args[2], *dy = args[3];
driver::buffer *x = args[4], *g = args[5], *m = args[6], *v = args[7];
std::array<size_t, 3> grid = {1, (size_t)C_, 1};
kernel->setArg(0, dx);
kernel->setArg(1, dg);
@@ -136,7 +173,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
void batchnorm_backward::src(std::ostream &os) {
void batchnorm_backward::triton_c_src(std::ostream &os) const {
os <<
R"(
const tunable int32 TM = {32, 64, 128};

View File

@@ -4,17 +4,6 @@
namespace triton{
namespace dnn{
void conv::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[4] = 1;
ld[3] = shapes[4]*ld[4];
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
}
conv::conv(int B, int NC,
int D, int H, int W,
int T, int R, int S, int NF,
@@ -23,7 +12,8 @@ conv::conv(int B, int NC,
int upsample_d, int upsample_h, int upsample_w,
std::string a_ty, std::string b_ty,
type ty, bool bias)
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
: base("conv"),
NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
@@ -93,7 +83,7 @@ conv::conv(int B, int NC,
Fs_ = K_;
TK_ = 8;
Luts_ = (TK_ + Fs_ - 1) / Fs_ * Fs_;
build_deltas();
build_a_deltas();
if(b_lut_)
build_b_deltas();
build_masks();
@@ -107,6 +97,28 @@ conv::conv(int B, int NC,
max_grid_1_ = 256;
}
// comparison for maps
// When `other` is not a conv the result is always true; otherwise the
// batch/channel/spatial sizes, filter sizes, paddings, strides, element
// types, convolution type and bias flag of both operands are compared
// lexicographically, in the order listed below.
// NOTE(review): upsample_d_/upsample_h_/upsample_w_ are not part of the key,
// so two convolutions differing only in upsampling compare as equivalent.
// Confirm that this is intended.
// NOTE(review): returning true whenever the cast fails means that, if the
// other operand's type applies the same rule, both a < b and b < a hold.
bool conv::operator<(const base& other) const {
  const conv *rhs = dynamic_cast<const conv*>(&other);
  if(rhs == nullptr)
    return true;
  const auto lhs_key = std::tie(NB_, NC_, AD_, AH_, AW_,
                                NF_, BD_, BH_, BW_,
                                pad_d_, pad_h_, pad_w_,
                                stride_d_, stride_h_, stride_w_,
                                a_ty_, b_ty_, ty_, bias_);
  const auto rhs_key = std::tie(rhs->NB_, rhs->NC_, rhs->AD_, rhs->AH_, rhs->AW_,
                                rhs->NF_, rhs->BD_, rhs->BH_, rhs->BW_,
                                rhs->pad_d_, rhs->pad_h_, rhs->pad_w_,
                                rhs->stride_d_, rhs->stride_h_, rhs->stride_w_,
                                rhs->a_ty_, rhs->b_ty_, rhs->ty_, rhs->bias_);
  return lhs_key < rhs_key;
}
// clone
// Copy-constructs a new conv from this object with `new` and returns it
// through a pointer to the base class.
base* conv::clone() const {
  conv *copy = new conv(*this);
  return copy;
}
// Returns the product of all entries of shapes_a_ (1 when it is empty).
// The accumulator type of std::accumulate is the type of the initial value,
// so the product is seeded with a size_t and multiplied as size_t; seeding
// with an int literal would form the whole product in int, which can
// overflow before the conversion to the return type.
size_t conv::a_size()
{
  return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
                         static_cast<size_t>(1), std::multiplies<size_t>());
}
@@ -176,7 +188,7 @@ void conv::build_b_deltas(){
}
}
void conv::build_deltas(){
void conv::build_a_deltas(){
h_a_deltas_.resize(Luts_ + upsample_d_*upsample_h_*upsample_w_*Luts_);
for(size_t i = 0; i < Luts_; ++i)
h_a_deltas_[i] = (((i + TK_) % Luts_) - i);
@@ -258,13 +270,15 @@ void conv::build_masks(){
h_masks_[i] = 0x0;
}
std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN)
{ return {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; }
// Computes the launch grid for tile sizes TM and TN: the first two
// components are M_ and N_ divided by the tile size, rounded up; the third
// component is always 1.
std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN){
  std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
  return grid;
}
// Returns 2 * M_ * N_ * K_. The product is evaluated in floating point
// (the leading factor is the double literal 2.) and then converted to size_t.
size_t conv::get_nflops()
{
  const auto flops = 2.*M_*N_*K_;
  return flops;
}
// Operation count reported for this convolution: 2 * M_ * N_ * K_,
// computed in floating point and converted to size_t on return.
size_t conv::num_flops() const{
  const auto flops = 2.*M_*N_*K_;
  return flops;
}
void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
void conv::init_impl(driver::stream *stream, triton::driver::cu_module* module) {
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
if(host.empty())
return nullptr;
@@ -349,9 +363,13 @@ void conv::set_arg(driver::kernel *kernel,
kernel->setArg(idx++, d_masks_);
}
void conv::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c, driver::buffer *bias,
size_t TM, size_t TN, size_t GZ, size_t nthreads) {
void conv::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads) {
driver::buffer *a = args[0], *b = args[1], *c = args[2], *bias = args[3];
unsigned TM = ranges[0], TN = ranges[1];
unsigned GZ = 1;
set_arg(kernel, a, b, c, bias);
std::array<size_t, 3> grid = {1};
grid[0] = (M_ + TM - 1)/TM;
@@ -411,6 +429,8 @@ std::vector<unsigned> conv::default_params() {
}
/* CPU reference implementation */
template<class IN_DTYPE, class OUT_DTYPE>
void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{
@@ -496,7 +516,9 @@ void conv::cpu_ref(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
cpu_wgrad(C, A, B);
}
void conv::src(std::ostream &os){
/* Triton-C source code */
void conv::triton_c_src(std::ostream &os) const {
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
@@ -526,7 +548,7 @@ void conv::src(std::ostream &os){
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {16};
const tunable int32 TK = {)" << TK_ << R"(};
const tunable int32 GZ = {1};
)";
if(is_a_deltas_cst)

View File

@@ -6,64 +6,103 @@
namespace triton{
namespace dnn{
gemm::gemm(int M, int N, int K,
bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb)
: base("matmul"),
M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
a_ty_(a_ty), b_ty_(b_ty),
align_lda_(alignment_lda), align_ldb_(alignment_ldb),
locks_(nullptr) {
// Writes 2048 zero-valued int32 entries from a host vector into `locks`
// at offset 0, using a non-blocking write on `stream`.
void gemm::init(driver::stream* stream, driver::buffer* locks) {
  std::vector<int32_t> zeros(2048, 0);
  stream->write(locks, false, 0, zeros);
}
void gemm::set_arg(driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
int32_t M, int32_t N, int32_t K,
driver::buffer *locks, int32_t grid_0, int32_t grid_1) {
// Operation count reported for this matrix product: 2 * M_ * N_ * K_.
// The product is formed in floating point (leading double literal 2.) and
// converted to size_t for the return value.
size_t gemm::num_flops() const {
  return static_cast<size_t>(2.*M_*N_*K_);
}
// comparison for maps
// When `other` is not a gemm the result is always true; otherwise the
// tuples (M_, N_, K_, AT_, BT_, a_ty_, b_ty_, align_lda_, align_ldb_) of
// both operands are compared lexicographically.
// NOTE(review): returning true whenever the cast fails means that, if the
// other operand's type applies the same rule, both a < b and b < a hold.
// Confirm that callers never mix operation types in one ordered container.
bool gemm::operator<(const base& other) const {
  const gemm *rhs = dynamic_cast<const gemm*>(&other);
  if(rhs == nullptr)
    return true;
  const auto lhs_key = std::tie(M_, N_, K_, AT_, BT_,
                                a_ty_, b_ty_, align_lda_, align_ldb_);
  const auto rhs_key = std::tie(rhs->M_, rhs->N_, rhs->K_, rhs->AT_, rhs->BT_,
                                rhs->a_ty_, rhs->b_ty_, rhs->align_lda_, rhs->align_ldb_);
  return lhs_key < rhs_key;
}
// clone
// Allocates a copy of this gemm with `new` (copy constructor) and returns
// it as a pointer to base. Note that the copy holds the same locks_ pointer
// value as this object, since the members are copied as-is.
base* gemm::clone() const {
  gemm *copy = new gemm(*this);
  return copy;
}
// Prepares the locks_ buffer: creates it on the stream's context on first
// use (2048 entries times 4 bytes), then writes 2048 zero-valued int32
// entries into it at offset 0 with a non-blocking write.
// The module argument is unused.
void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
  std::vector<int32_t> zeros(2048, 0);
  // allocate only once; later calls just reset the contents
  if(!locks_)
    locks_ = triton::driver::buffer::create(stream->context(), zeros.size()*4);
  stream->write(locks_, false, 0, zeros);
}
// Binds the kernel arguments for the matmul kernel and enqueues it on `stream`.
//   args     : device buffers; args[0], args[1], args[2] become kernel
//              arguments 0, 1 and 2
//   ranges   : ranges[0] and ranges[1] are read as the tile sizes TM and TN
//   nthreads : passed to stream->enqueue as {nthreads, 1, 1}
void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
const std::vector<unsigned>& ranges,
size_t nthreads) {
driver::buffer *a = args[0], *b = args[1], *c = args[2];
unsigned TM = ranges[0];
unsigned TN = ranges[1];
// round-up division: number of tiles of size TM (resp. TN) covering M_ (resp. N_)
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
// NOTE(review): arguments 3-9 are set twice below, first from M / N / K /
// locks and then from the members M_ / N_ / K_ / locks_. M, N, K and locks
// are not declared in this function, so the first group looks like the
// pre-refactoring side of a diff -- confirm and keep only one group.
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, locks);
// arguments 6, 7, 8 presumably correspond to lda, ldb, ldc of the generated
// matmul signature -- TODO confirm against triton_c_src
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
kernel->setArg(6, M_);
kernel->setArg(7, N_);
kernel->setArg(8, M_);
kernel->setArg(9, locks_);
kernel->setArg(10, grid_0);
kernel->setArg(11, grid_1);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
// Returns one of several fixed 14-entry parameter vectors, selected by the
// two transposition flags. The AT-only and AT+BT cases return the same
// vector. The meaning of the individual entries is not visible in this function.
// NOTE(review): two signatures (with and without bool parameters) and two
// sets of conditions (AT / BT vs. the members AT_ / BT_) are interleaved
// below, which does not parse as written -- this looks like both sides of a
// diff. Confirm which variant is current and remove the other.
std::vector<unsigned> gemm::default_params(bool AT, bool BT) {
if(AT && BT)
std::vector<unsigned> gemm::default_params() {
if(AT_ && BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(AT && !BT)
else if(AT_ && !BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(!AT && BT)
else if(!AT_ && BT_)
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
else
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
}
std::string gemm::src(bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned align_lda, unsigned align_ldb) {
void gemm::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
if(AT){
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT){
if(BT_){
std::swap(BS0, BS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -71,8 +110,8 @@ const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only )" + a_ty + R"( *A,
restrict read_only )" + b_ty + R"( *B,
void matmul(restrict read_only )" + a_ty_ + R"( *A,
restrict read_only )" + b_ty_ + R"( *B,
fp32 *C,
int32 M, int32 N, int32 K,
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
@@ -87,10 +126,10 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
)" + a_ty + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
)" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
@@ -108,10 +147,10 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
)" + a_ty + R"(* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
)" + b_ty + R"(* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
)" + a_ty + R"( a[TM, 1] = checka ? *pa : 0;
)" + b_ty + R"( b[TN, 1] = checkb ? *pb : 0;
)" + a_ty_ + R"(* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
)" + b_ty_ + R"(* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
)" + a_ty_ + R"( a[TM, 1] = checka ? *pa : 0;
)" + b_ty_ + R"( b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
@@ -136,7 +175,7 @@ void matmul(restrict read_only )" + a_ty + R"( *A,
__atomic_cas(plock, 1, 0);
}
)";
return res;
os << res;
}
}

View File

@@ -5,21 +5,13 @@
namespace triton{
namespace dnn{
// Fills `ld` with one stride per entry of `shapes`: the last entry gets
// stride 1 and every earlier stride is the following stride multiplied by
// the following extent (ld[i-1] = shapes[i] * ld[i]).
//   shapes : extents of each dimension
//   ld     : output; resized to shapes.size()
// An empty `shapes` yields an empty `ld`; without the guard, `size - 1`
// would wrap around and index past the end of the vector.
void shift::set_ld(const std::vector<int32_t>& shapes,
                   std::vector<int32_t>& ld) {
  const size_t size = shapes.size();
  ld.resize(size);
  if(size == 0)
    return;
  ld[size - 1] = 1;
  // walk from the innermost dimension outwards; the index stays unsigned so
  // it is never narrowed from size_t to int
  for(size_t i = size - 1; i >= 1; --i)
    ld[i - 1] = shapes[i] * ld[i];
}
shift::shift(int B, int C,
int D, int H, int W,
int T, int R, int S,
int F,
int stride_h, int stride_w,
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
const int32_t *shift_h, const int32_t *shift_w,
std::string a_ty, std::string b_ty,
type ty, bool bias)
: base("shift"),
@@ -73,8 +65,10 @@ shift::shift(int B, int C,
set_ld(shapes_a_, ld_a_);
set_ld(shapes_b_, ld_b_);
set_ld(shapes_c_, ld_c_);
// build LUTs
build_deltas();
}
// Returns a pointer to a new shift object that is copy-constructed from
// this one and allocated with `new`.
base* shift::clone() const {
  shift *copy = new shift(*this);
  return copy;
}
void shift::build_deltas() {
@@ -120,30 +114,29 @@ std::vector<int32_t> shift::c_shapes(){
return shapes_c_;
}
size_t shift::get_nflops() const {
// Operation count reported for this op: 2 * M_ * N_ * K_, evaluated in
// floating point (leading double literal 2.) and converted to size_t.
size_t shift::num_flops() const {
  return static_cast<size_t>(2.*M_*N_*K_);
}
// Less-than comparison against another operation. When `other` is a shift,
// the sizes (B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_), the shift tables,
// ty_ and bias_ of both operands are compared as tuples.
// NOTE(review): two variants of this body are interleaved below (return
// false vs. return true on a failed cast; the x_shift_h / y_shift_h locals
// vs. the shift_h_ / shift_w_ members inside std::tie) and the result does
// not parse as written -- this looks like both sides of a diff. Confirm which
// variant is current and remove the other.
bool shift::operator <(const base& other) const{
auto *y = dynamic_cast<const shift*>(&other);
if(!y)
return false;
const int32_t *x_shift_h = shift_h_.data(), *x_shift_w = shift_w_.data();
const int32_t *y_shift_h = y->shift_h_.data(), *y_shift_w = y->shift_w_.data();
return true;
return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_,
x_shift_h, x_shift_w, ty_, bias_)
shift_h_, shift_w_, ty_, bias_)
< std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_,
y_shift_h, y_shift_w, y->ty_, y->bias_);
y->shift_h_, y->shift_w_, y->ty_, y->bias_);
}
// Rebuilds the host-side delta table, looks up the module symbol named
// "delta" and copies h_deltas_ into it (size() * 4 bytes, offset 0,
// non-blocking write on `stream`).
void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
  build_deltas();
  triton::driver::cu_module *cu_mod = (triton::driver::cu_module*)module;
  triton::driver::buffer *delta_buf = cu_mod->symbol("delta");
  const auto num_bytes = h_deltas_.size()*4;
  stream->write(delta_buf, false, 0, num_bytes, h_deltas_.data());
}
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args,
size_t TM, size_t TN, size_t nthreads) {
const std::vector<unsigned> &ranges, size_t nthreads) {
int32_t lda = AT_ ? K_ : M_;
int32_t ldb = BT_ ? N_ : K_;
driver::buffer *a = args[0], *b = args[1], *c = args[2];
@@ -162,13 +155,14 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(12, AW_);
kernel->setArg(13, BH_);
kernel->setArg(14, BW_);
unsigned TM = ranges[0], TN = ranges[1];
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
if(ty_ == BPROP)
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
void shift::get_src(std::ostream &os) const {
void shift::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";