more work on heuristics
@@ -6,19 +6,21 @@
#include "triton/dnn/gemm.h"
#include "triton/tools/bench.hpp"

template<class T>
void diff(const std::vector<T>& x, const std::vector<T>& y){
for(size_t i = 0; i < x.size(); i++)
if(std::isnan(x[i]) || std::abs(x[i] - y[i])/std::max(x[i], y[i]) > 1e-4){
std::cout << i << " " << x[i] << " " << y[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}

int main() {
bool AT = false;
bool BT = true;
double bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float T;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters
int32_t M = 65536, N = 2048, K = 2048;
std::vector<T> hc(M*N);
std::vector<T> rc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
srand(0);
@@ -36,14 +38,35 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
gemm.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// gemm.cpu_ref<T>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
double result = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
delete dc;
delete da;
delete db;
return result;
}

int main() {
struct config_t{
bool AT;
bool BT;
int32_t M;
int32_t N;
int32_t K;
};
// shapes to benchmark
std::vector<config_t> configs = {
{false, false, 4096, 4096, 4096},
{false, true, 4096, 4096, 4096},
{true, false, 4096, 4096, 4096},
{true, true, 4096, 4096, 4096}
};
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// does the work
for(config_t c: configs){
double tns = bench(context, c.AT, c.BT, c.M, c.N, c.K);
double tflops = 2.*c.M*c.N*c.K / tns * 1e-3;
std::cout << c.AT << ", " << c.BT << ", " << c.M << ", " << c.N << ", " << c.K << ", " << tflops << std::endl;
}
}

@@ -8,31 +8,23 @@
#include "triton/dnn/shift.h"
#include "triton/external/half.hpp"

int main() {
double do_bench(triton::driver::context* context,
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
triton::dnn::shift::op_t op, triton::dnn::shift::layout_t layout,
std::string numeric_t) {
typedef float NumericT;
std::string numeric_t_str = "fp16";

// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::FPROP;

// initialization
int32_t R = 3, S = 3;
int32_t B = 64, F = 2048;
int32_t H = 32, W = 32;
int32_t C = 2048;

// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
for(int32_t c = 0; c < C; c++){
shift_h[c] = 0;
shift_w[c] = 0;
shift_h[c] = rand() % R - R / 2;
shift_w[c] = rand() % S - S / 2;
}
// configuration
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t_str, numeric_t_str,
numeric_t, numeric_t,
op, false, triton::dnn::shift::CHWN);
// host buffers
size_t a_size = B*C*H*W;
@@ -67,13 +59,19 @@ int main() {
stream->write(dc, true, 0, hc);
stream->synchronize();
shift.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// shift.cpu_ref(rc.data(), ha.data(), hb.data());
// for(size_t i = 0; i < hc.size(); i++)
// if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
double tns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, true);}, stream);
std::cout << tns << std::endl;
}

int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// shapes
int32_t R = 3, S = 3;
int32_t B = 16, F = 4096;
int32_t H = 32, W = 32;
int32_t C = 4096;
// benchmark
do_bench(context, R, S, B, F, H, W, C, triton::dnn::shift::FPROP, triton::dnn::shift::CHWN, "fp16");

}

@@ -36,7 +36,7 @@ torch::Tensor shift_common(
int32_t T, int32_t R, int32_t S, int32_t F,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w,
triton::dnn::shift::type ty, triton::dnn::shift::layout_t layout,
triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
bool autotune = false
) {

@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// template
triton::dnn::gemm dot(M, N, K, false, false, "fp16", "fp16", 4, 4);
triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
dot.enqueue(stream, {&da, &db, &dc});
}

@@ -19,7 +19,7 @@
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;

template<triton::dnn::shift::type OP>
template<triton::dnn::shift::op_t OP>
class ShiftConvOp : public OpKernel {
public:
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context), layout_(triton::dnn::shift::NCHW) {

@@ -31,6 +31,13 @@ namespace triton{
namespace dnn{


enum autotuning_t{
FULL_TUNING,
PARTIAL_TUNING,
NO_TUNING
};

typedef std::vector<unsigned> params_t;

class base {
friend class cmp_recompile;
@@ -53,6 +60,9 @@ private:
virtual size_t num_flops() const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;

public:
// constructor
@@ -62,7 +72,7 @@ public:
// clone
virtual base* clone() const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, bool autotune = false);
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);

private:
std::string name_;

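The `autotuning_t` knob above replaces the old boolean. A minimal caller-side sketch of the three modes, assuming the declarations in this commit (`triton/dnn/gemm.h` still declares `triton::dnn::dot` at this point, as the first example above shows; buffer and stream setup is elided):

```cpp
#include "triton/dnn/gemm.h" // declares triton::dnn::dot as of this commit

// Hypothetical helper: how a call site now selects a tuning policy.
void run_dot(triton::driver::stream* stream,
             triton::driver::buffer* da,
             triton::driver::buffer* db,
             triton::driver::buffer* dc) {
  triton::dnn::dot dot(4096, 4096, 4096, false, true, "fp16", "fp16", 8, 8);
  // FULL_TUNING:    exhaustive sweep over every valid metaparameter set
  // PARTIAL_TUNING: sweep only dot::search_space() (the new default)
  // NO_TUNING:      compile once with dot::heuristics()
  dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);
}
```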
@@ -6,7 +6,7 @@
namespace triton{
namespace dnn{

class gemm: public base {
class dot: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
@@ -18,10 +18,12 @@ private:
size_t num_flops() const;
// comparison for maps
bool operator<(const base& other) const;

// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;

public:
gemm(int M, int N, int K, bool AT, bool BT,
dot(int M, int N, int K, bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb);

@@ -46,13 +48,13 @@ public:
template<class T>
void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
gemm::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, true, true>(c, a, b, M_, N_, K_);
else if(AT_ && !BT_)
gemm::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, true, false>(c, a, b, M_, N_, K_);
else if(!AT_ && BT_)
gemm::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, false, true>(c, a, b, M_, N_, K_);
else
gemm::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
dot::cpu_ref<T, false, false>(c, a, b, M_, N_, K_);
}

private:

@@ -38,7 +38,7 @@ namespace dnn{
class shift: public base {

public:
enum type {
enum op_t {
FPROP,
BPROP,
WGRAD
@@ -56,6 +56,7 @@ private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
std::vector<unsigned> default_params() const;

public:

@@ -65,7 +66,7 @@ public:
int stride_h, int stride_w,
const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32",
type ty = FPROP, bool bias = false, layout_t layout = CHWN);
op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);

// look-up table
void build_delta_a();
@@ -165,7 +166,7 @@ private:
std::string b_ty_;
std::string c_ty_;
// convolution type
type op_;
op_t op_;
bool bias_;
// transpose
bool AT_;

@@ -167,6 +167,8 @@ public:
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);


// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);

@@ -108,7 +108,7 @@ public:
jit(driver::context* context, unsigned nthreads = 4);
~jit();
std::vector<unsigned> get_valid(const char *name, const char *src);
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark);
tune_res_t autotune(const char* name, const char* src, benchmark_t benchmark, const std::vector<std::vector<unsigned> > &targets = {});
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
driver::kernel* get_function(const char* name);

@@ -2,6 +2,9 @@
#define TRITON_TOOLS_BENCH_HPP

#include <chrono>
#include <functional>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"

namespace triton{
namespace tools{
@@ -24,14 +27,14 @@ private:
high_resolution_clock::time_point _start;
};

template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, const triton::driver::device * device)
inline double bench(std::function<void()> const & op, driver::stream * stream)
{
const driver::device * device = stream->context()->device();
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
stream->synchronize();
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
@@ -39,7 +42,7 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
sync();
stream->synchronize();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}

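The new `bench()` signature folds the old `sync` callback and `device` argument into the stream. A sketch of a call site, assuming the result is nanoseconds per run (the examples above name it `tns` and convert via `2.*M*N*K / tns * 1e-3`):

```cpp
#include <functional>
#include "triton/tools/bench.hpp"

// Hypothetical helper: times op() and converts to TFLOPS, mirroring the
// conversion used by the benchmarks in this commit.
double tflops(triton::driver::stream* stream,
              const std::function<void()>& op, double flops) {
  double tns = triton::tools::bench(op, stream); // warm-up + ~1ms of samples
  return flops / tns * 1e-3;                     // flops/ns -> TFLOPS
}
```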
@@ -529,8 +529,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 2);
pack_size_1_ = std::min<unsigned>(num_rep_1, 2);
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;

@@ -221,7 +221,7 @@ void tune::run(ir::module &mod) {
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
@@ -239,7 +239,7 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}

@@ -7,8 +7,6 @@ namespace triton{
namespace dnn{




void base::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
@@ -22,7 +20,15 @@ void base::set_ld(const std::vector<int32_t>& shapes,
base::base(const std::string& name)
: name_(name) { }

void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, bool autotune) {
std::vector<params_t> base::search_space() const {
return {};
}

params_t base::heuristics() const {
return *search_space().begin();
}

void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
namespace rt = triton::runtime;
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
driver::context* ctx = stream->context();
@@ -30,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
base* clone = this->clone();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
std::ostringstream oss;
clone->triton_c_src(oss);
std::string src = oss.str();
@@ -40,18 +46,21 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->enqueue_impl(stream, kernel, args, info);
stream->synchronize();
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
[&](){ stream->synchronize(); }, ctx->device());
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
clone->deinit_impl();
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark);
if(autotune != NO_TUNING) {
std::vector<params_t> space = {};
if(autotune == PARTIAL_TUNING)
space = search_space();
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
params_t params = heuristics();
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());

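The `m_jit` cache above is keyed by `cmp_recompile`, whose definition is outside this hunk; the `friend class cmp_recompile` declaration in `base` suggests it simply forwards to the virtual `operator<`. An assumed sketch:

```cpp
// Assumed shape of the comparator keying m_jit (not shown in this diff).
// Two operations that compare equivalent under operator< (e.g. same shapes,
// transposition flags, and types for dot) reuse one compiled kernel.
struct cmp_recompile {
  bool operator()(triton::dnn::base* x, triton::dnn::base* y) const {
    return *x < *y; // virtual dispatch to e.g. dot::operator<
  }
};
```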
lib/dnn/gemm.cpp
@@ -6,7 +6,7 @@
namespace triton{
namespace dnn{

gemm::gemm(int M, int N, int K,
dot::dot(int M, int N, int K,
bool AT, bool BT,
std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb)
@@ -18,13 +18,13 @@ gemm::gemm(int M, int N, int K,

}

size_t gemm::num_flops() const {
size_t dot::num_flops() const {
return 2.*M_*N_*K_;
}

// comparison for maps
bool gemm::operator<(const base& other) const {
auto *y = dynamic_cast<const gemm*>(&other);
bool dot::operator<(const base& other) const {
auto *y = dynamic_cast<const dot*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_, AT_, BT_,
@@ -34,18 +34,18 @@ bool gemm::operator<(const base& other) const {
}

// clone
base* gemm::clone() const {
return new gemm(*this);
base* dot::clone() const {
return new dot(*this);
}

void gemm::init_impl(driver::stream* stream, driver::cu_module *) {
void dot::init_impl(driver::stream* stream, driver::cu_module *) {
std::vector<int32_t> hlocks(2048, 0);
if(locks_ == nullptr)
locks_ = triton::driver::buffer::create(stream->context(), hlocks.size()*4);
stream->write(locks_, false, 0, hlocks);
}

void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
runtime::launch_information info) {
driver::buffer *a = args[0], *b = args[1], *c = args[2];
@@ -75,7 +75,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}

void gemm::triton_c_src(std::ostream &os) const {
void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
@@ -100,8 +100,8 @@ void gemm::triton_c_src(std::ostream &os) const {
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {32, 64, 128, 256};
const tunable int32 TN = {32, 64, 128, 256};
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {32};
const tunable int32 GZ = {1};

@@ -145,5 +145,102 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
os << res;
}

// small search space for partial auto-tuning
std::vector<params_t> dot::search_space() const {
typedef std::vector<unsigned> params_t;
typedef std::tuple<size_t, size_t> key_t;
static std::vector<key_t> keys = {
{16, 16}, {16, 32}, {16, 64}, {16, 128},
{32, 16}, {32, 32}, {32, 64}, {32, 128},
{64, 16}, {64, 32}, {64, 64}, {64, 128},
{128, 16},{128, 32},{128, 64},{128, 128}
};
static std::vector<params_t> space_nn = {
{4, 4, 16, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{2, 8, 16, 8, 32, 2, 2, 1, 1, 16, 32, 4, 8, 1},
{4, 4, 16, 4, 64, 2, 2, 1, 1, 8, 32, 8, 4, 1},
{4, 4, 16, 16, 128, 2, 2, 1, 2, 16, 32, 4, 8, 1},
{4, 8, 32, 8, 16, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{4, 8, 32, 8, 32, 2, 2, 1, 1, 8, 32, 4, 8, 1},
{8, 4, 32, 8, 64, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 4, 32, 16, 128, 2, 2, 1, 4, 16, 32, 8, 4, 1},
{8, 8, 64, 4, 16, 2, 2, 1, 1, 4, 32, 8, 4, 1},
{8, 8, 64, 8, 32, 2, 2, 1, 1, 4, 32, 4, 8, 1},
{8, 8, 64, 16, 64, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{16, 4, 64, 16, 128, 2, 2, 2, 2, 8, 32, 8, 4, 1},
{8, 8, 128, 8, 16, 2, 2, 2, 1, 8, 32, 8, 4, 1},
{8, 8, 128, 16, 32, 2, 2, 2, 1, 8, 32, 4, 8, 1},
{8, 8, 128, 32, 64, 2, 2, 2, 2, 16, 32, 4, 8, 1},
{8, 8, 128, 32, 128, 2, 2, 1, 4, 16, 32, 4, 8, 1},
};
static std::vector<params_t> space_nt = {
{4, 4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 32, 16, 1},
{4, 4, 16, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{4, 4, 16, 8, 8, 64, 2, 2, 1, 4, 32, 32, 16, 1},
{4, 4, 16, 32, 4, 128, 2, 2, 1, 2, 16, 32, 2, 1},
{8, 4, 32, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{4, 8, 32, 4, 8, 32, 2, 2, 1, 1, 8, 32, 8, 1},
{16, 8, 128, 4, 4, 64, 2, 2, 1, 4, 8, 32, 32, 1},
{4, 8, 32, 8, 8, 128, 2, 2, 1, 2, 16, 32, 8, 1},
{8, 8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 32, 16, 1},
{8, 8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 32, 8, 1},
{8, 8, 64, 8, 8, 64, 2, 2, 1, 2, 8, 32, 8, 1},
{8, 8, 64, 16, 8, 128, 2, 2, 1, 4, 16, 32, 8, 1},
{8, 8, 128, 2, 8, 16, 2, 2, 2, 1, 8, 32, 32, 1},
{16, 8, 128, 4, 8, 32, 2, 2, 2, 1, 4, 32, 16, 1},
{8, 8, 128, 8, 8, 64, 2, 2, 2, 2, 16, 32, 16, 1},
{8, 8, 128, 8, 8, 128, 2, 2, 4, 1, 16, 32, 16, 1},
};
static std::vector<params_t> space_tn = {
{8, 16, 16, 16, 2, 2, 1, 1, 4, 8, 32, 2, 8, 1},
{4, 16, 8, 32, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{4, 16, 4, 64, 2, 2, 1, 1, 8, 4, 32, 8, 4, 1},
{16, 16, 16, 128, 2, 2, 1, 2, 4, 8, 32, 4, 8, 1},
{4, 32, 8, 16, 2, 2, 1, 1, 8, 4, 32, 4, 8, 1},
{8, 32, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 32, 8, 64, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{32, 32, 64, 128, 2, 2, 2, 2, 4, 8, 32, 2, 8, 1},
{8, 64, 8, 16, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{8, 64, 8, 32, 2, 2, 1, 1, 4, 8, 32, 4, 8, 1},
{16, 64, 16, 64, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 64, 16, 128, 2, 2, 2, 2, 4, 8, 32, 8, 4, 1},
{16, 128, 16, 16, 2, 2, 2, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 32, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 64, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
{32, 128, 32, 128, 2, 2, 4, 1, 4, 8, 32, 4, 8, 1},
};
static std::vector<params_t> space_tt = {
{4, 16, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 16, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 16, 4, 8, 64, 2, 2, 1, 4, 8, 4, 32, 32, 1},
{16, 16, 8, 4, 128, 2, 2, 1, 2, 4, 8, 32, 8, 1},
{4, 32, 2, 8, 16, 2, 2, 1, 1, 8, 4, 32, 16, 1},
{8, 32, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 4, 8, 64, 2, 2, 2, 1, 4, 8, 32, 16, 1},
{32, 32, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{8, 64, 2, 8, 16, 2, 2, 1, 1, 4, 8, 32, 16, 1},
{8, 64, 4, 8, 32, 2, 2, 1, 1, 4, 8, 32, 8, 1},
{16, 64, 8, 8, 64, 2, 2, 2, 1, 4, 8, 32, 8, 1},
{32, 64, 8, 8, 128, 2, 2, 1, 4, 4, 8, 32, 16, 1},
{16, 128, 2, 8, 16, 2, 2, 2, 1, 4, 8, 32, 32, 1},
{32, 128, 8, 4, 32, 2, 2, 4, 1, 4, 8, 32, 16, 1},
{32, 128, 16, 4, 64, 2, 2, 4, 1, 4, 8, 32, 8, 1},
{32, 128, 8, 8, 128, 2, 2, 4, 1, 4, 8, 32, 16, 1}
};
if(!AT_ && !BT_)
return space_nn;
else if(!AT_ && BT_)
return space_nt;
else if(AT_ && !BT_)
return space_tn;
else
return space_tt;
}

// simple parameter heuristics
params_t dot::heuristics() const {
return search_space().back();
}

}
}

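`heuristics()` above returns `search_space().back()`; if the (currently unused) `keys` table reflects the row order, that is the row tuned for the largest (128x128) tile shape, whereas the base-class default takes the first row instead. A standalone illustration of the selection logic, with names of my own:

```cpp
#include <vector>
typedef std::vector<unsigned> params_t;

// Mirrors dot::search_space()/dot::heuristics(): pick the space matching
// the transposition flags, then take its last entry as the no-tuning choice.
params_t pick_heuristic(bool AT, bool BT,
                        const std::vector<params_t>& nn,
                        const std::vector<params_t>& nt,
                        const std::vector<params_t>& tn,
                        const std::vector<params_t>& tt) {
  const std::vector<params_t>& space = !AT ? (!BT ? nn : nt)
                                           : (!BT ? tn : tt);
  return space.back();
}
```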
@@ -13,7 +13,7 @@ shift::shift(int B, int C,
int stride_h, int stride_w,
const int32_t *shift_h, const int32_t *shift_w,
std::string a_ty, std::string b_ty,
type ty, bool bias,
op_t ty, bool bias,
layout_t layout)
: base("shift"),
B_(B), C_(C),
@@ -512,5 +512,15 @@ else{
os << result;
}


// simple parameter heuristics
std::vector<unsigned> shift::default_params() const {
typedef std::vector<unsigned> params_t;
std::map<std::tuple<op_t, size_t, size_t>, params_t> params = {
{{}, {}}
};
}


}
}

@@ -31,7 +31,7 @@ extern triton::lang::translation_unit *ast_root;
namespace triton {
namespace runtime{

void loop_nest(std::vector<size_t> const & ranges,
void parallel_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f,
size_t nthreads){
size_t D = ranges.size();
@@ -55,7 +55,7 @@ void loop_nest(std::vector<size_t> const & ranges,
}

template<class T>
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
void parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
//Ranges to iterate over
std::vector<size_t> ranges;
for(auto const & x: iterates)
@@ -68,10 +68,14 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
f(x);
};
//Iterate
loop_nest(ranges, proxy, nthreads);
parallel_loop_nest(ranges, proxy, nthreads);
}


void parallel_for_each(std::vector<std::vector<unsigned>> const & iterates, std::function<void(std::vector<unsigned>)> const & f, size_t nthreads) {
ThreadPool pool(nthreads);
for(const std::vector<unsigned>& values: iterates)
pool.enqueue(f, values);
}


std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes, llvm::LLVMContext& llvm_context, launch_information& info) {
@@ -128,7 +132,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {
ranges.push_back(mp->get_space());
// iterate over parameters
std::vector<unsigned> result;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
parallel_loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
if(!result.empty())
return;
std::map<ir::value*, std::vector<std::string>> errors;
@@ -148,7 +152,7 @@ std::vector<unsigned> jit::get_valid(const char *name, const char *src) {



jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t benchmark, const std::vector<std::vector<unsigned>> & targets) {
// find metaparameters
triton::lang::translation_unit* program = parse_program(name, src);
auto ptt_module_0 = make_triton_module(name, triton_context_, program);
@@ -157,15 +161,12 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
passes_wrapper passes_0(target_.get());
passes_0.target_independent(tt_module_0);
passes_0.tune.run(tt_module_0);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
auto mps = passes_0.tune.get_params(tt_module_0);
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// iterate over parameters
tune_res_t best;
std::mutex mutex;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
// update_best
auto update_best = [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;
unsigned i = 0;
{
@@ -200,10 +201,10 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
launch_information info;
llvm::LLVMContext llvm_context;
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
double perf;
{
std::lock_guard<std::mutex> lock(mutex);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
perf = benchmark(kernel.get(), info);
if(perf > best.perf){
@@ -214,8 +215,21 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
}, nthreads_);
std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
};


if(targets.empty()) {
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
parallel_loop_nest<unsigned>(ranges, update_best, nthreads_);
}
else {
parallel_for_each(targets, update_best, nthreads_);
}

// std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
return best;
}

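`parallel_for_each` above fans each explicit target out to the repo's `ThreadPool` helper; `jit::autotune` uses it whenever a non-empty `targets` list is supplied (e.g. a `search_space()` under `PARTIAL_TUNING`), and falls back to the exhaustive `parallel_loop_nest` sweep otherwise. A self-contained sketch of the same fan-out pattern using `std::async` instead of that helper:

```cpp
#include <functional>
#include <future>
#include <vector>

// Same dispatch pattern as parallel_for_each above, minus the thread pool;
// the trailing loop joins all tasks before returning.
void parallel_for_each_sketch(
    const std::vector<std::vector<unsigned>>& iterates,
    const std::function<void(std::vector<unsigned>)>& f) {
  std::vector<std::future<void>> tasks;
  for(const std::vector<unsigned>& values: iterates)
    tasks.push_back(std::async(std::launch::async, f, values));
  for(auto& t: tasks)
    t.get(); // join
}
```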