more work on heuristics

Philippe Tillet
2019-07-21 18:11:54 -07:00
parent 484e3871cf
commit b1d81a5802
17 changed files with 268 additions and 99 deletions

View File

@@ -6,19 +6,21 @@
#include "triton/dnn/gemm.h"
#include "triton/tools/bench.hpp"
template<class T>
void diff(const std::vector<T>& x, const std::vector<T>& y){
for(size_t i = 0; i < x.size(); i++)
if(std::isnan(x[i]) || std::abs(x[i] - y[i])/std::max(x[i], y[i]) > 1e-4){
std::cout << i << " " << x[i] << " " << y[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}
int main() {
bool AT = false;
bool BT = true;
double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float T;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters
int32_t M = 65536, N = 2048, K = 2048;
std::vector<T> hc(M*N);
std::vector<T> rc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
srand(0);
@@ -36,14 +38,35 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
gemm.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// gemm.cpu_ref<T>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
delete dc;
delete da;
delete db;
return result;
}
int main() {
struct config_t{
bool AT;
bool BT;
int32_t M;
int32_t N;
int32_t K;
};
// shapes to benchmark
std::vector<config_t> configs = {
{false, false, 4096, 4096, 4096},
{false, true, 4096, 4096, 4096},
{true, false, 4096, 4096, 4096},
{true, true, 4096, 4096, 4096}
};
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// does the work
for(config_t c: configs){
double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K);
double tflops = 2.*c.M*c.N*c.K / tns * 1e-3;
std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl;
}
}
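
For reference, a worked instance of the tflops computation above, with illustrative numbers (the 1e-3 factor converts flops-per-nanosecond into TFLOP/s):

#include <cstdio>
int main() {
  double M = 4096, N = 4096, K = 4096;       // one of the benchmarked shapes
  double tns = 1.0e7;                        // suppose bench() reports 10 ms, in ns
  double tflops = 2.*M*N*K / tns * 1e-3;     // flops/ns * 1e-3 == TFLOP/s
  std::printf("%.1f\n", tflops);             // prints ~13.7
}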

View File

@@ -8,31 +8,23 @@
#include "triton/dnn/shift.h"
#include "triton/external/half.hpp"
int main() {
double do_bench(triton::driver::context* context,
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
triton::dnn::shift::op_t op, triton::dnn::shift::layout_t layout,
std::string numeric_t) {
typedef float NumericT;
std::string numeric_t_str = "fp16";
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::FPROP;
// initialization
int32_t R = 3, S = 3;
int32_t B = 64, F = 2048;
int32_t H = 32, W = 32;
int32_t C = 2048;
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
for(int32_t c = 0; c < C; c++){
shift_h[c] = 0;
shift_w[c] = 0;
shift_h[c] = rand() % R - R / 2;
shift_w[c] = rand() % S - S / 2;
}
// configuration
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t_str, numeric_t_str,
numeric_t, numeric_t,
op, false, triton::dnn::shift::CHWN);
// host buffers
size_t a_size = B*C*H*W;
@@ -67,13 +59,19 @@ int main() {
stream->write(dc, true, 0, hc);
stream->synchronize();
shift.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++)
// if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
double tns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
std::cout << tns << std::endl;
return tns;
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// shapes
int32_t R = 3, S = 3;
int32_t B = 16, F = 4096;
int32_t H = 32, W = 32;
int32_t C = 4096;
// benchmark
do_bench(context, R, S, B, F, H, W, C, triton::dnn::shift::FPROP, triton::dnn::shift::CHWN, "fp16");
}

View File

@@ -36,7 +36,7 @@ torch::Tensor shift_common(
int32_t T, int32_t R, int32_t S, int32_t F,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w,
triton::dnn::shift::type ty, triton::dnn::shift::layout_t layout,
triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
bool autotune = false
) {

View File

@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// template
triton::dnn::gemm dot(M, N, K, false, false, "fp16", "fp16", 4, 4);
triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
dot.enqueue(stream, {&da, &db, &dc});
}

View File

@@ -19,7 +19,7 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
template<triton::dnn::shift::type OP>
template<triton::dnn::shift::op_t OP>
class ShiftConvOp : public OpKernel {
public:
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {

View File

@@ -31,6 +31,13 @@ namespace triton{
namespace dnn{
enum autotuning_t{
FULL_TUNING,
PARTIAL_TUNING,
NO_TUNING
};
typedef std::vector<unsigned> params_t;
class base {
friend class cmp_recompile;
@@ -53,6 +60,9 @@ private:
virtual size_t num_flops() const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;
public:
// constructor
@@ -62,7 +72,7 @@ public:
// clone
virtual base* clone() const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, bool autotune = false);
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
private:
std::string name_;
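
A minimal usage sketch of the three tuning modes (the stream and buffer names are assumed from the examples above):

// hypothetical call sites for the new autotuning_t argument
triton::dnn::dot dot(M, N, K, false, true, "fp16", "fp16", 8, 8);
dot.enqueue(stream, {da, db, dc});                             // default: PARTIAL_TUNING over search_space()
dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);   // exhaustive sweep of the metaparameter ranges
dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);     // single configuration from heuristics()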

View File

@@ -6,7 +6,7 @@
namespace triton{
namespace dnn{
class gemm: public base {
class dot: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
@@ -18,10 +18,12 @@ private:
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;
public:
gemm(int M, int N, int K, bool AT, bool BT,
dot(int M, int N, int K, bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);
@@ -46,13 +48,13 @@ public:
template<class T>
void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
gemm::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
else if(AT_ && !BT_)
gemm::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
else if(!AT_ && BT_)
gemm::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
else
gemm::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
}
private:

View File

@@ -38,7 +38,7 @@ namespace dnn{
class shift: public base {
public:
enum type {
enum op_t {
FPROP,
BPROP,
WGRAD
@@ -56,6 +56,7 @@ private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
std::vector<unsigned> default_params() const;
public:
@@ -65,7 +66,7 @@ public:
int stride_h, int stride_w,
const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false, layout_t layout = CHWN);
op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);
// look-up table
void build_delta_a();
@@ -165,7 +166,7 @@ private:
std::string b_ty_;
std::string c_ty_;
// convolution type
type op_;
op_t op_;
bool bias_;
// transpose
bool AT_;

View File

@@ -167,6 +167,8 @@ public:
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);

View File

@@ -108,7 +108,7 @@ public:
jit(driver::context* context, unsigned nthreads = 4);
~jit();
std::vector<unsigned> get_valid(const char *name, const char *src);
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark, const std::vector<std::vector<unsigned> > &targets = {});
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
driver::kernel* get_function(const char* name);

View File

@@ -2,6 +2,9 @@
#define TRITON_TOOLS_BENCH_HPP
#include <chrono>
#include <functional>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
namespace triton{
namespace tools{
@@ -24,14 +27,14 @@ private:
high_resolution_clock::time_point _start;
};
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, const triton::driver::device * device)
inline double bench(std::function<void()> const & op, driver::stream * stream)
{
const driver::device * device = stream->context()->device();
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
stream->synchronize();
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
@@ -39,7 +42,7 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
sync();
stream->synchronize();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
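
As a worked instance of the clock normalization above: if the SM clock is read at 1350 MHz against a 1500 MHz maximum (illustrative values), norm = 0.9, so a measured 10 ms sample is recorded as 9 ms, approximating the kernel's duration at full clock.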

View File

@@ -529,8 +529,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 2);
pack_size_1_ = std::min<unsigned>(num_rep_1, 2);
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;

View File

@@ -221,7 +221,7 @@ void tune::run(ir::module &mod) {
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
@@ -239,7 +239,7 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}

View File

@@ -7,8 +7,6 @@ namespace triton{
namespace dnn{
void base::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
@@ -22,7 +20,15 @@ void base::set_ld(const std::vector<int32_t>& shapes,
base::base(const std::string& name)
: name_(name) { }
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, bool autotune) {
std::vector<params_t> base::search_space() const {
return {};
}
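// note: dereferences begin() of the search space; a subclass relying on this
// default must override search_space() to return at least one configuration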
params_t base::heuristics() const {
return *search_space().begin();
}
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
namespace rt = triton::runtime;
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
driver::context* ctx = stream->context();
@@ -30,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
base* clone = this->clone();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
std::ostringstream oss;
clone->triton_c_src(oss);
std::string src = oss.str();
@@ -40,18 +46,21 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->enqueue_impl(stream, kernel, args, info);
stream->synchronize();
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
[&](){ stream->synchronize(); }, ctx->device());
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
clone->deinit_impl();
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
if(autotune != NO_TUNING) {
std::vector<params_t> space = {};
if(autotune == PARTIAL_TUNING)
space = search_space();
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
params_t params = heuristics();
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());

View File

@@ -6,7 +6,7 @@
namespace triton{
namespace dnn{
gemm::gemm(int M, int N, int K,
dot::dot(int M, int N, int K,
bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb)
@@ -18,13 +18,13 @@ gemm::gemm(int M, int N, int K,
}
size_t gemm::num_flops() const {
size_t dot::num_flops() const {
return 2.*M_*N_*K_;
}
// comparison for maps
bool gemm::operator<(const base& other) const {
auto *y = dynamic_cast<const gemm*>(&other);
bool dot::operator<(const base& other) const {
auto *y = dynamic_cast<const dot*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_, AT_, BT_,
@@ -34,18 +34,18 @@ bool gemm::operator<(const base& other) const {
}
// clone
base* gemm::clone() const {
return new gemm(*this);
base* dot::clone() const {
return new dot(*this);
}
void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
void dot::init_impl(driver::stream* stream, driver::cu_module *) {
std::vector<int32_t> hlocks(2048, 0);
if(locks_ == nullptr)
locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);
stream->write(locks_, false, 0, hlocks);
}
void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
runtime::launch_information info) {
driver::buffer *a = args[0], *b = args[1], *c = args[2];
@@ -75,7 +75,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}
void gemm::triton_c_src(std::ostream &os) const {
void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
@@ -100,8 +100,8 @@ void gemm::triton_c_src(std::ostream &os) const {
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {32, 64, 128, 256};
const tunable int32 TN = {32, 64, 128, 256};
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {32};
const tunable int32 GZ = {1};
@@ -145,5 +145,102 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
os << res;
}
// small search space for partial auto-tuning
std::vector<params_t> dot::search_space() const {
typedef std::vector<unsigned> params_t;
typedef std::tuple<size_t, size_t> key_t;
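// (TM, TN) tile sizes indexing, in order, the sixteen entries of each space below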
static std::vector<key_t> keys = {
{16, 16}, {16, 32}, {16, 64}, {16, 128},
{32, 16}, {32, 32}, {32, 64}, {32, 128},
{64, 16}, {64, 32}, {64, 64}, {64, 128},
{128, 16},{128, 32},{128, 64},{128, 128}
};
static std::vector<params_t> space_nn = {
{4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1},
{4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1},
{4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1},
{4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1},
{8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1},
{8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1},
{8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1},
{8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1},
{8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1},
};
static std::vector<params_t> space_nt = {
{4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1},
{4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1},
{4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1},
{8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1},
{4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1},
{8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1},
{8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1},
{8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1},
{8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1},
{16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1},
{8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1},
{8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1},
};
static std::vector<params_t> space_tn = {
{8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1},
{4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1},
{16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1},
{4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1},
{8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1},
{16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
};
static std::vector<params_t> space_tt = {
{4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1},
{16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1},
{4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1},
{32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1},
{8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1},
{32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1},
{32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1},
{32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1},
{32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}
};
if(!AT_ && !BT_)
return space_nn;
else if(!AT_ && BT_)
return space_nt;
else if(AT_ && !BT_)
return space_tn;
else
return space_tt;
}
// simple parameter heuristics
params_t dot::heuristics() const {
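// return the last entry, i.e. the largest (TM = TN = 128) tile configuration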
return search_space().back();
}
}
}

View File

@@ -13,7 +13,7 @@ shift::shift(int B, int C,
int stride_h, int stride_w,
const int32_t *shift_h, const int32_t *shift_w,
std::string a_ty, std::string b_ty,
type ty, bool bias,
op_t ty, bool bias,
layout_t layout)
: base("shift"),
B_(B), C_(C),
@@ -512,5 +512,15 @@ else{
os << result;
}
// simple parameter heuristics
std::vector<unsigned> shift::default_params() const {
typedef std::vector<unsigned> params_t;
// placeholder lookup table; shape-specific entries remain to be populated
std::map<std::tuple<op_t, size_t, size_t>, params_t> params = {
{{}, {}}
};
return params.begin()->second;
}
}
}

View File

@@ -31,7 +31,7 @@ extern triton::lang::translation_unit *ast_root;
namespace triton {
namespace runtime{
void loop_nest(std::vector<size_t> const & ranges,
void parallel_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f,
size_t nthreads){
size_t D = ranges.size();
@@ -55,7 +55,7 @@ void loop_nest(std::vector<size_t> const & ranges,
}
template<class T>
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
void parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
//Ranges to iterate over
std::vector<size_t> ranges;
for(auto const & x: iterates)
@@ -68,10 +68,14 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
f(x);
};
//Iterate
loop_nest(ranges, proxy, nthreads);
parallel_loop_nest(ranges, proxy, nthreads);
}
void parallel_for_each(std::vector<std::vector<unsigned>> const & iterates, std::function<void(std::vector<unsigned>)> const & f, size_t nthreads) {
ThreadPool pool(nthreads);
for(const std::vector<unsigned>& values: iterates)
pool.enqueue(f, values);
}
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
@@ -128,7 +132,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
ranges.push_back(mp->get_space());
// iterate over parameters
std::vector<unsigned> result;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
parallel_loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
if(!result.empty())
return;
std::map<ir::value*, std::vector<std::string>> errors;
@@ -148,7 +152,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> & targets) {
// find metaparameters
triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module_0 = make_triton_module(name, triton_context_, program);
@@ -157,15 +161,12 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
passes_wrapper passes_0(target_.get());
passes_0.target_independent(tt_module_0);
passes_0.tune.run(tt_module_0);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
auto mps = passes_0.tune.get_params(tt_module_0);
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// iterate over parameters
tune_res_t best;
std::mutex mutex;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
// update_best
auto update_best = [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;
unsigned i = 0;
{
@@ -200,10 +201,10 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
launch_information info;
llvm::LLVMContext llvm_context;
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
double perf;
{
std::lock_guard<std::mutex> lock(mutex);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
perf = benchmark(kernel.get(), info);
if(perf > best.perf){
@@ -214,8 +215,21 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
}, nthreads_);
std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
};
if(targets.empty()) {
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
parallel_loop_nest<unsigned>(ranges, update_best, nthreads_);
}
else {
parallel_for_each(targets, update_best, nthreads_);
}
// std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
return best;
}
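
A hedged sketch of driving the new targets parameter directly; context, src, and benchmark are assumed to be set up as in base::enqueue above:

// restrict autotuning to an explicit set of parameter vectors
std::vector<std::vector<unsigned>> targets = {
  {8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},    // rows borrowed from dot::search_space()
  {8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1}
};
triton::runtime::jit engine(context);                 // default nthreads = 4
auto best = engine.autotune("matmul", src.c_str(), benchmark, targets);
engine.add_module("matmul", src.c_str(), best.params);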