better benchmarking

This commit is contained in:
Philippe Tillet
2019-07-22 19:26:12 -07:00
parent ead368d1ed
commit c448876178
12 changed files with 159 additions and 48 deletions

View File

@@ -2,5 +2,5 @@ foreach(PROG dot conv shift)
add_executable(${PROG} ${PROG}.cpp)
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
include_directories(/usr/local/cuda/include/)
target_link_libraries(${PROG} triton)
target_link_libraries(${PROG} triton cublas)
endforeach(PROG)

View File

@@ -6,6 +6,7 @@
#include "triton/driver/stream.h"
#include "triton/dnn/gemm.h"
#include "triton/tools/bench.hpp"
#include "cuda.h"
template<class T>
void diff(const std::vector<T>& x, const std::vector<T>& y){
@@ -17,34 +18,63 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
std::cout << "Pass!" << std::endl;
}
double do_bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float T;
struct perf_t {
double triton;
double cublas;
};
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
std::vector<T> hc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context();
std::vector<NumericT> hc(M*N);
std::vector<NumericT> ha(M*K);
std::vector<NumericT> hb(K*N);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (T)rand()/RAND_MAX;
ha[i] = (NumericT)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (T)rand()/RAND_MAX;
hb[i] = (NumericT)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
double nanosec = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
int32_t lda = AT ? K : M;
int32_t ldb = BT ? N : K;
int32_t ldc = M;
cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
// &alpha, da, lda,
// db, ldb, &beta,
// dc, ldc, &fastest);
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
&alpha, da, lda,
db, ldb, &beta,
dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
// result
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
perf_t result;
result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
// clean-up
delete dc;
delete da;
delete db;
return dot.num_flops() / nanosec * 1e-3;
return result;
}
int main() {
@@ -61,21 +91,24 @@ int main() {
return oss.str();
}
double perf(triton::driver::context *context){
return do_bench(context, AT, BT, M, N, K);
perf_t perf(triton::driver::stream *stream){
return do_bench(stream, AT, BT, M, N, K);
}
};
// shapes to benchmark
std::vector<config_t> configs = {
{false, false, 4096, 4096, 4096},
{false, true, 4096, 4096, 4096},
{true, false, 4096, 4096, 4096},
{true, true, 4096, 4096, 4096}
// {false, false, 8192, 512, 512},
{false, true, 8192, 8192, 8192},
{false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// does the work
for(config_t c: configs){
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
perf_t perf = c.perf(stream);
std::cout << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
}
}

View File

@@ -1,6 +1,7 @@
#include <cstring>
#include <cstdio>
#include <sstream>
#include "cuda.h"
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
@@ -8,12 +9,20 @@
#include "triton/dnn/shift.h"
#include "triton/external/half.hpp"
double do_bench(triton::driver::context* context,
struct perf_t {
double triton;
double cublas;
};
perf_t do_bench(triton::driver::stream *stream,
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
triton::dnn::op_t op, triton::dnn::layout_t layout,
std::string numeric_t) {
typedef float NumericT;
// driver variables
triton::driver::context* context = stream->context();
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
@@ -44,7 +53,6 @@ double do_bench(triton::driver::context* context,
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*sizeof(NumericT));
triton::driver::stream* stream = triton::driver::stream::create(context);
// initialize host
srand(0);
for(size_t i = 0; i < ha.size(); i++)
@@ -58,8 +66,29 @@ double do_bench(triton::driver::context* context,
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
double nanosec = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc});}, stream);
return shift.num_flops() / nanosec * 1e-3;
// benchmark triton
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
cublasGemmAlgo_t fastest;
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(), &beta,
dc, shift.ldc(), &fastest);
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(),
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// result
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
perf_t result;
result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
delete da;
delete db;
delete dc;
return result;
}
int main() {
@@ -86,13 +115,15 @@ int main() {
return oss.str();
}
double perf(triton::driver::context *context){
return do_bench(context, R, S, B, F, H, W, C, op, layout, ty);
perf_t perf(triton::driver::stream *stream){
return do_bench(stream, R, S, B, F, H, W, C, op, layout, ty);
}
};
// shapes to benchmark
std::vector<config_t> configs;
std::vector<config_t> resnet18 = {
std::vector<config_t> resnet18 =
{
{128, 128, 32, 32, 3, 3, 128, 1, 1},
{128, 128, 32, 32, 3, 3, 128, 1, 1},
{128, 128, 32, 32, 3, 3, 256, 2, 2},
{128, 256, 16, 16, 3, 3, 256, 1, 1},
@@ -108,7 +139,11 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
for(config_t c: configs)
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
triton::driver::stream *stream = triton::driver::stream::create(context);
for(config_t c: configs){
std::string repr = c.repr();
perf_t perf = c.perf(stream);
std::cout << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
}
}

View File

@@ -99,9 +99,10 @@ inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
// simple parameter heuristics
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
size_t TM = 128;
size_t TN = 128;
return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
// size_t TM = 128;
// size_t TN = 128;
return {4, 8, 256, 8, 8, 64, 2, 2, 2, 2, 32, 32, 16, 1};
// return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
}
}

View File

@@ -73,6 +73,15 @@ public:
// accessors
size_t c_size();
std::vector<int32_t> c_shapes();
// equivalent GEMM
bool AT() const;
bool BT() const;
size_t M() const;
size_t N() const;
size_t K() const;
size_t lda() const;
size_t ldb() const;
size_t ldc() const;
// number of flops
size_t num_flops() const;
// source

View File

@@ -781,9 +781,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
distributed_tile *mask_tile;
if(mask)
mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
ptr->for_each([&](indices_t idx){
set_mask_insert_pt(idx);
StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx));
Value *ptr_value = ptr->get_value(idx);
Value *value_value = value->get_value(idx);
Instruction *store;
// if(mask){
// Value *pred_value = mask_tile->get_value(idx);
// value_value = builder.CreateVectorSplat(1, value_value);
// pred_value = builder.CreateVectorSplat(1, pred_value);
// Type *ptr_ty = PointerType::get(value_value->getType(), ptr_value->getType()->getPointerAddressSpace());
// ptr_value = builder.CreateBitCast(ptr_value, ptr_ty);
// store = builder.CreateMaskedStore(value_value, ptr_value, 1, pred_value);
// }
// else
store = new StoreInst(value_value, ptr_value);
builder.Insert(store);
});
}

View File

@@ -215,7 +215,7 @@ void tune::run(ir::module &mod) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
@@ -239,7 +239,7 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}

View File

@@ -36,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
base* clone = this->clone();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
std::ostringstream oss;
clone->triton_c_src(oss);
std::string src = oss.str();

View File

@@ -129,10 +129,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
c = dot()" + usea + ", " + useb + R"(, c);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
int1 checka[)" + AS + R"(] = k > bound;
int1 checkb[)" + BS + R"(] = k > bound;
@checka a = *pa;
@checkb b = *pb;
a = *pa;
b = *pb;
}
int32 rxc[TM] = ridx*TM + (0 ... TM);
int32 ryc[TN] = ridy*TN + (0 ... TN);

View File

@@ -180,6 +180,30 @@ size_t shift::num_flops() const {
return 2.*M_*N_*K_;
}
bool shift::AT() const
{ return AT_; }
bool shift::BT() const
{ return BT_; }
size_t shift::M() const
{ return M_; }
size_t shift::N() const
{ return N_; }
size_t shift::K() const
{ return K_; }
size_t shift::lda() const
{ return AT_ ? K_ : M_; }
size_t shift::ldb() const
{ return BT_ ? N_ : K_; }
size_t shift::ldc() const
{ return M_; }
bool shift::operator <(const base& other) const{
auto *y = dynamic_cast<const shift*>(&other);
if(!y)
@@ -265,10 +289,6 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(30, (int32_t)grid[2]);
if(locks_)
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
if(op_ == FPROP || op_ == BPROP){
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes);
}
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}

View File

@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -164,8 +164,8 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
auto mps = passes_0.tune.get_params(tt_module_0);
// iterate over parameters
tune_res_t best;
std::mutex mutex;
// update_best
std::mutex mutex;
auto update_best = [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;
unsigned i = 0;
@@ -211,9 +211,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf;
best.params = params;
}
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
};