stuff
This commit is contained in:
@@ -74,12 +74,12 @@ int main() {
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
16, 4, 64, 16, 4, 128, 2, 2, 1, 2, 4, 4, 16, 4
|
||||
4, 2, 16, 4, 128, 2, 2, 1, 1, 8, 16, 8, 2
|
||||
};
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
jit.autotune("shift", src.c_str(), benchmark);
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
import numpy as np
|
||||
from time import time
|
||||
|
||||
data_files_path = tf.resource_loader.get_data_files_path()
|
||||
library_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||
@@ -42,23 +44,45 @@ def run_conv():
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
|
||||
|
||||
@ops.RegisterGradient('ShiftConv')
|
||||
def blocksparse_matmul_grad(op, dy):
|
||||
shift_h = op.get_attr('shift_h')
|
||||
shift_w = op.get_attr('shift_w')
|
||||
x = op.inputs[0]
|
||||
w = op.inputs[1]
|
||||
dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
|
||||
dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 16, 32, 32, 32
|
||||
R, S, F = 3, 3, 32
|
||||
B, C, H, W = 1, 16, 8, 8
|
||||
R, S, F = 3, 3, 16
|
||||
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
b = tf.placeholder(tf.float32, shape=[C, F])
|
||||
shift_h = tf.zeros(C, tf.int32)
|
||||
shift_w = tf.zeros(C, tf.int32)
|
||||
hshift_h = np.zeros(C, np.int32)
|
||||
hshift_w = np.zeros(C, np.int32)
|
||||
#hshift_h = np.random.randint(-R//2, R//2 + 1, size=C, dtype=np.int32)
|
||||
#hshift_w = np.random.randint(-S//2, R//2 + 1, size=C, dtype=np.int32)
|
||||
hshift_h = 0*np.ones(C, dtype=np.int32)
|
||||
hshift_w = 0*np.ones(C, dtype=np.int32)
|
||||
c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
|
||||
# Reference
|
||||
ha = np.random.rand(C, H, W, B)
|
||||
hb = np.random.rand(C, F)
|
||||
# Run
|
||||
ha = np.ones((C, H, W, B), dtype=np.int32)
|
||||
hb = np.ones((C, F), dtype=np.int32)
|
||||
sess = tf.InteractiveSession()
|
||||
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (C, H, W, B),
|
||||
extra_feed_dict={a: ha, b: hb})
|
||||
dx_t, dx_n = grads[0]
|
||||
dw_t, dw_n = grads[1]
|
||||
print(dw_t)
|
||||
print(dw_n)
|
||||
#print(np.max(dw_t - dw_n))
|
||||
#print(np.max(dx_t - dx_n))
|
||||
np.savetxt('theoretical.dat', dw_t, fmt='%4.2f')
|
||||
np.savetxt('numerical.dat', dw_n, fmt='%4.2f')
|
||||
# Run
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
#print(result)
|
||||
|
||||
run_shift()
|
||||
|
@@ -19,6 +19,15 @@
|
||||
using namespace tensorflow;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t*, int32_t*,
|
||||
triton::dnn::shift::type, bool> shift_key_t;
|
||||
|
||||
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
|
||||
static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_jit;
|
||||
static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_config;
|
||||
|
||||
template<triton::dnn::shift::type OP>
|
||||
class ShiftConvOp : public OpKernel {
|
||||
public:
|
||||
@@ -78,15 +87,27 @@ public:
|
||||
// shapes
|
||||
int64_t C, H, W, B, F;
|
||||
FillShapes(context, C, H, W, B, F, tf_a, tf_b);
|
||||
int64_t D = 1, T = 1;
|
||||
bool has_bias = false;
|
||||
// shift configuration
|
||||
int32_t* shift_h_data = h_shift_h_.flat<int32_t>().data();
|
||||
int32_t* shift_w_data = h_shift_w_.flat<int32_t>().data();
|
||||
std::vector<int32_t> shift_h(shift_h_data, shift_h_data + C);
|
||||
std::vector<int32_t> shift_w(shift_w_data, shift_w_data + C);
|
||||
triton::dnn::shift shift(B, C, 1, H, W, 1, R_, S_, F, shift_h, shift_w, "fp32", "fp32", OP, false);
|
||||
shift_key_t key = {B, C, 1, H, W, 1, R_, S_, F, shift_h_data, shift_w_data, OP, has_bias};
|
||||
// create configuration
|
||||
triton::dnn::shift* shift;
|
||||
if(m_config.find(key) == m_config.end())
|
||||
shift = m_config.emplace(key, new triton::dnn::shift(
|
||||
B, C, D, H, W, T, R_, S_, F,
|
||||
shift_h, shift_w, "fp32", "fp32", OP, has_bias))
|
||||
.first->second.get();
|
||||
else
|
||||
shift = m_config.at(key).get();
|
||||
|
||||
// shapes for c
|
||||
std::vector<int64> c_shapes;
|
||||
for(int32_t x: shift.c_shapes())
|
||||
for(int32_t x: shift->c_shapes())
|
||||
c_shapes.push_back(x);
|
||||
TensorShape out_shapes(c_shapes);
|
||||
Tensor* tf_c = nullptr;
|
||||
@@ -94,38 +115,58 @@ public:
|
||||
// return early if possible
|
||||
if (out_shapes.num_elements() == 0)
|
||||
return;
|
||||
// initialize default compute device
|
||||
triton::jit jit(ctx);
|
||||
// matrix multiplication parameters
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||
// benchmark a given matrix multiplication kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
shift.init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ shift.enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
return shift.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
triton::jit::tune_res_t best = jit.autotune("shift", src.c_str(), benchmark);
|
||||
// get JIT
|
||||
triton::jit* jit;
|
||||
bool autotune = false;
|
||||
if(m_jit.find(key) == m_jit.end()) {
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
shift->src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
return shift->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
|
||||
jit->add_module("shift", src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
shift->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_jit.at(key).get();
|
||||
// Run
|
||||
triton::driver::kernel* kernel = jit->get_function("shift");
|
||||
triton::jit::launch_information info = jit->get_launch_info("shift");
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
shift->enqueue(stream, kernel, &da, &db, &dc, TM, TN, nthreads);
|
||||
}
|
||||
|
||||
private:
|
||||
Tensor h_shift_h_;
|
||||
Tensor h_shift_w_;
|
||||
// triton::driver::buffer* d_shift_h_;
|
||||
// triton::driver::buffer* d_shift_w_;
|
||||
int R_;
|
||||
int S_;
|
||||
};
|
||||
@@ -136,5 +177,21 @@ REGISTER_OP("ShiftConv")
|
||||
.Input("b: float32")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Output("c: float32")
|
||||
;
|
||||
.Output("c: float32");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
|
||||
REGISTER_OP("ShiftConvDx")
|
||||
.Input("a: float32")
|
||||
.Input("b: float32")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Output("c: float32");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
|
||||
REGISTER_OP("ShiftConvDw")
|
||||
.Input("a: float32")
|
||||
.Input("b: float32")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Output("c: float32");
|
||||
|
||||
|
@@ -103,6 +103,7 @@ private:
|
||||
public:
|
||||
jit(driver::context* context);
|
||||
~jit();
|
||||
std::vector<unsigned> get_valid(const char *name, const char *src);
|
||||
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
|
||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
|
||||
|
@@ -70,16 +70,26 @@ shift::shift(int B, int C,
|
||||
}
|
||||
|
||||
void shift::build_deltas() {
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
|
||||
};
|
||||
h_deltas_.resize(MAX_C_);
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_deltas_[c] = offset(c);
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
if(ty_ == FPROP){
|
||||
// compute offset
|
||||
auto offset = [&](unsigned c) {
|
||||
return c*ld_a_[0] + shift_h_[c]*ld_a_[1] + shift_w_[c]*ld_a_[2];
|
||||
};
|
||||
// populate look-up table
|
||||
for(unsigned c = 0; c < TK_; c++)
|
||||
h_deltas_[c] = offset(c);
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[TK_ + c] = offset(c + TK_) - offset(c);
|
||||
}
|
||||
if(ty_ == BPROP){
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[c] = shift_h_[c]*ld_c_[1] + shift_w_[c]*ld_c_[2];
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
for(unsigned c = 0; c < C_; c++)
|
||||
h_deltas_[c] = shift_h_[c]*ld_b_[1] + shift_w_[c]*ld_b_[2];
|
||||
}
|
||||
}
|
||||
|
||||
size_t shift::a_size(){
|
||||
@@ -102,7 +112,7 @@ std::vector<int32_t> shift::c_shapes(){
|
||||
}
|
||||
|
||||
size_t shift::get_nflops() {
|
||||
return 2. * M_ * N_ * K_;
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
|
||||
@@ -114,15 +124,13 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads) {
|
||||
if(ty_ == WGRAD)
|
||||
std::swap(a, b);
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, B_*AH_*AW_);
|
||||
kernel->setArg(6, M_);
|
||||
kernel->setArg(7, N_);
|
||||
kernel->setArg(8, B_);
|
||||
kernel->setArg(9, AH_);
|
||||
@@ -177,7 +185,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
multiple_of(4) int32 lda, multiple_of(4) int32 ldb,
|
||||
int32 lda, int32 ldb,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
@@ -203,11 +211,13 @@ if(ty_ == FPROP){
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 shift[TK, TN] = 0;)";
|
||||
__constant__ int32* pd[TN] = delta + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];)";
|
||||
}
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + rxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = *pa;
|
||||
)" << b_ty_ << " b[" << BS << R"(] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
@@ -239,7 +249,7 @@ if(ty_ == WGRAD){
|
||||
int1 maskw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
int1 mask[TK, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
int32 inc[TK, TN] = mask ? 0 : shift;
|
||||
pb = pb + TK;
|
||||
pb = pb + TK)" << ldb0 << R"(;
|
||||
)" << b_ty_ << R"(* pbb[TK, TN] = pb + inc;
|
||||
@checkb b = *pbb;)";
|
||||
}
|
||||
@@ -259,14 +269,15 @@ else{
|
||||
if(ty_ == BPROP){
|
||||
os << R"(
|
||||
int32 rcwhc[TM] = rxc / ABS;
|
||||
int32 rcw[TM] = rcwhc % AW;
|
||||
int32 rcw[TM] = (rcwhc % AW);
|
||||
int32 rchc[TM] = rcwhc / AW;
|
||||
int32 rch[TM] = rchc % AH;
|
||||
int32 rch[TM] = (rchc % AH);
|
||||
int1 maskh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
|
||||
int1 maskw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
|
||||
int1 interior[TM, TN] = maskh[:, newaxis] && maskw[:, newaxis];
|
||||
fp32* shiftpc[TM, TN] = pc + 0;
|
||||
pc = interior ? shiftpc : pc;
|
||||
__constant__ int32* pd[TN] = delta + ryc;
|
||||
fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
pc = interior ? shift_pc : pc;
|
||||
@checkc __atomic_add(pc, C);
|
||||
)";
|
||||
}
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << sd::endl;
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -96,6 +96,46 @@ jit::jit(driver::context *context): driver_context_(context),
|
||||
|
||||
jit::~jit(){ }
|
||||
|
||||
std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
|
||||
// find metaparameters
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
// set parameters
|
||||
passes_wrapper passes(target_.get());
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.run(tt_module);
|
||||
auto mps = passes.tune.get_params(tt_module);
|
||||
// create parameter ranges
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
ranges.push_back(mp->get_space());
|
||||
// iterate over parameters
|
||||
std::vector<unsigned> result;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
if(!result.empty())
|
||||
return;
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
unsigned i = 0;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(params[i++]);
|
||||
passes.target_independent(tt_module);
|
||||
passes.tune.init(tt_module);
|
||||
passes.tune.check_constraints(errors);
|
||||
// for(auto e: errors)
|
||||
// for(auto x: e.second)
|
||||
// std::cout << x << std::endl;
|
||||
// std::cout << "-----" << std::endl;
|
||||
if(!errors.empty())
|
||||
return;
|
||||
result = params;
|
||||
});
|
||||
if(result.empty())
|
||||
throw std::runtime_error("couldn't find valid parameters");
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
|
||||
// find metaparameters
|
||||
auto ptt_module = make_triton_module(name, src);
|
||||
|
Reference in New Issue
Block a user