better benchmarking
This commit is contained in:
@@ -2,5 +2,5 @@ foreach(PROG dot conv shift)
|
||||
add_executable(${PROG} ${PROG}.cpp)
|
||||
set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG})
|
||||
include_directories(/usr/local/cuda/include/)
|
||||
target_link_libraries(${PROG} triton)
|
||||
target_link_libraries(${PROG} triton cublas)
|
||||
endforeach(PROG)
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/gemm.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "cuda.h"
|
||||
|
||||
template<class T>
|
||||
void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
@@ -17,34 +18,63 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
||||
double do_bench(triton::driver::context* context, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef float T;
|
||||
// Pair of throughput figures produced by do_bench for one configuration.
// Both fields are filled from the `tflops` helper in do_bench, i.e.
// num_flops / elapsed_nanoseconds * 1e-3.
// Members are zero-initialized so a default-constructed perf_t never exposes
// indeterminate values if one of the two measurements is skipped.
struct perf_t {
  // throughput measured for the Triton kernel
  double triton = 0.0;
  // throughput measured for the cuBLAS reference call
  double cublas = 0.0;
};
|
||||
|
||||
|
||||
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef float NumericT;
|
||||
std::string ty = "fp16";
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
std::vector<T> hc(M*N);
|
||||
std::vector<T> ha(M*K);
|
||||
std::vector<T> hb(K*N);
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
triton::driver::context* context = stream->context();
|
||||
std::vector<NumericT> hc(M*N);
|
||||
std::vector<NumericT> ha(M*K);
|
||||
std::vector<NumericT> hb(K*N);
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = (T)rand()/RAND_MAX;
|
||||
ha[i] = (NumericT)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = (T)rand()/RAND_MAX;
|
||||
hb[i] = (NumericT)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = 0;
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
stream->write(da, true, 0, ha);
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
|
||||
double nanosec = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
NumericT alpha = 1;
|
||||
NumericT beta = 0;
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = M;
|
||||
cublasGemmAlgo_t fastest;
|
||||
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
// &alpha, da, lda,
|
||||
// db, ldb, &beta,
|
||||
// dc, ldc, &fastest);
|
||||
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
&alpha, da, lda,
|
||||
db, ldb, &beta,
|
||||
dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
|
||||
// result
|
||||
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
|
||||
|
||||
perf_t result;
|
||||
result.cublas = tflops(cublas_ns);
|
||||
result.triton = tflops(triton_ns);
|
||||
// clean-up
|
||||
delete dc;
|
||||
delete da;
|
||||
delete db;
|
||||
return dot.num_flops() / nanosec * 1e-3;
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -61,21 +91,24 @@ int main() {
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
double perf(triton::driver::context *context){
|
||||
return do_bench(context, AT, BT, M, N, K);
|
||||
perf_t perf(triton::driver::stream *stream){
|
||||
return do_bench(stream, AT, BT, M, N, K);
|
||||
}
|
||||
};
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
{false, false, 4096, 4096, 4096},
|
||||
{false, true, 4096, 4096, 4096},
|
||||
{true, false, 4096, 4096, 4096},
|
||||
{true, true, 4096, 4096, 4096}
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 8192, 8192, 8192},
|
||||
{false, true, 32768, 256, 512}
|
||||
// {true, false, 8192, 512, 512},
|
||||
// {true, true, 8192, 512, 512}
|
||||
};
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
// does the work
|
||||
for(config_t c: configs){
|
||||
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
|
||||
perf_t perf = c.perf(stream);
|
||||
std::cout << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include "cuda.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
@@ -8,12 +9,20 @@
|
||||
#include "triton/dnn/shift.h"
|
||||
#include "triton/external/half.hpp"
|
||||
|
||||
double do_bench(triton::driver::context* context,
|
||||
// Pair of throughput figures produced by do_bench for one shift configuration.
// Both fields are filled from the `tflops` helper in do_bench, i.e.
// shift.num_flops() / elapsed_nanoseconds * 1e-3.
// Members are zero-initialized so a default-constructed perf_t never exposes
// indeterminate values if one of the two measurements is skipped.
struct perf_t {
  // throughput measured for the Triton shift kernel
  double triton = 0.0;
  // throughput measured for the cuBLAS GEMM used as reference
  double cublas = 0.0;
};
|
||||
|
||||
perf_t do_bench(triton::driver::stream *stream,
|
||||
int32_t R, int32_t S, int32_t B, int32_t F, int32_t H, int32_t W, int32_t C,
|
||||
triton::dnn::op_t op, triton::dnn::layout_t layout,
|
||||
std::string numeric_t) {
|
||||
typedef float NumericT;
|
||||
|
||||
// driver variables
|
||||
triton::driver::context* context = stream->context();
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
std::vector<int32_t> shift_w(C);
|
||||
@@ -44,7 +53,6 @@ double do_bench(triton::driver::context* context,
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*sizeof(NumericT));
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
// initialize host
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
@@ -58,8 +66,29 @@ double do_bench(triton::driver::context* context,
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
double nanosec = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc});}, stream);
|
||||
return shift.num_flops() / nanosec * 1e-3;
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
NumericT alpha = 1;
|
||||
NumericT beta = 0;
|
||||
cublasGemmAlgo_t fastest;
|
||||
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
&alpha, da, shift.lda(),
|
||||
db, shift.ldb(), &beta,
|
||||
dc, shift.ldc(), &fastest);
|
||||
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
&alpha, da, shift.lda(),
|
||||
db, shift.ldb(),
|
||||
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
|
||||
// result
|
||||
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
|
||||
perf_t result;
|
||||
result.cublas = tflops(cublas_ns);
|
||||
result.triton = tflops(triton_ns);
|
||||
delete da;
|
||||
delete db;
|
||||
delete dc;
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -86,13 +115,15 @@ int main() {
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
double perf(triton::driver::context *context){
|
||||
return do_bench(context, R, S, B, F, H, W, C, op, layout, ty);
|
||||
perf_t perf(triton::driver::stream *stream){
|
||||
return do_bench(stream, R, S, B, F, H, W, C, op, layout, ty);
|
||||
}
|
||||
};
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs;
|
||||
std::vector<config_t> resnet18 = {
|
||||
std::vector<config_t> resnet18 =
|
||||
{
|
||||
{128, 128, 32, 32, 3, 3, 128, 1, 1},
|
||||
{128, 128, 32, 32, 3, 3, 128, 1, 1},
|
||||
{128, 128, 32, 32, 3, 3, 256, 2, 2},
|
||||
{128, 256, 16, 16, 3, 3, 256, 1, 1},
|
||||
@@ -108,7 +139,11 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
for(config_t c: configs)
|
||||
std::cout << c.repr() << ", " << c.perf(context) << std::endl;
|
||||
triton::driver::stream *stream = triton::driver::stream::create(context);
|
||||
|
||||
for(config_t c: configs){
|
||||
std::string repr = c.repr();
|
||||
perf_t perf = c.perf(stream);
|
||||
std::cout << repr << ", " << perf.triton << ", " << perf.cublas << std::endl;
|
||||
}
|
||||
}
|
||||
|
@@ -99,9 +99,10 @@ inline std::vector<params_t> dot_search_space(bool AT, bool BT) {
|
||||
|
||||
// simple parameter heuristics
|
||||
inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
||||
size_t TM = 128;
|
||||
size_t TN = 128;
|
||||
return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
// size_t TM = 128;
|
||||
// size_t TN = 128;
|
||||
return {4, 8, 256, 8, 8, 64, 2, 2, 2, 2, 32, 32, 16, 1};
|
||||
// return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN});
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -73,6 +73,15 @@ public:
|
||||
// accessors
|
||||
size_t c_size();
|
||||
std::vector<int32_t> c_shapes();
|
||||
// equivalent GEMM
|
||||
bool AT() const;
|
||||
bool BT() const;
|
||||
size_t M() const;
|
||||
size_t N() const;
|
||||
size_t K() const;
|
||||
size_t lda() const;
|
||||
size_t ldb() const;
|
||||
size_t ldc() const;
|
||||
// number of flops
|
||||
size_t num_flops() const;
|
||||
// source
|
||||
|
@@ -781,9 +781,24 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
|
||||
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
|
||||
tile *value = tmap_.at(x->get_value_operand());
|
||||
distributed_tile *mask_tile;
|
||||
if(mask)
|
||||
mask_tile = (distributed_tile*)tmap_.at(ins->get_mask_pred());
|
||||
ptr->for_each([&](indices_t idx){
|
||||
set_mask_insert_pt(idx);
|
||||
StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx));
|
||||
Value *ptr_value = ptr->get_value(idx);
|
||||
Value *value_value = value->get_value(idx);
|
||||
Instruction *store;
|
||||
// if(mask){
|
||||
// Value *pred_value = mask_tile->get_value(idx);
|
||||
// value_value = builder.CreateVectorSplat(1, value_value);
|
||||
// pred_value = builder.CreateVectorSplat(1, pred_value);
|
||||
// Type *ptr_ty = PointerType::get(value_value->getType(), ptr_value->getType()->getPointerAddressSpace());
|
||||
// ptr_value = builder.CreateBitCast(ptr_value, ptr_ty);
|
||||
// store = builder.CreateMaskedStore(value_value, ptr_value, 1, pred_value);
|
||||
// }
|
||||
// else
|
||||
store = new StoreInst(value_value, ptr_value);
|
||||
builder.Insert(store);
|
||||
});
|
||||
}
|
||||
|
@@ -215,7 +215,7 @@ void tune::run(ir::module &mod) {
|
||||
node_t node = *nodes_.begin();
|
||||
if(fragments_[node] == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
@@ -239,7 +239,7 @@ void tune::run(ir::module &mod) {
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
|
@@ -36,7 +36,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
base* clone = this->clone();
|
||||
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
|
||||
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
|
||||
std::ostringstream oss;
|
||||
clone->triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
|
@@ -129,10 +129,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
c = dot()" + usea + ", " + useb + R"(, c);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
int1 checka[)" + AS + R"(] = k > bound;
|
||||
int1 checkb[)" + BS + R"(] = k > bound;
|
||||
@checka a = *pa;
|
||||
@checkb b = *pb;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int32 rxc[TM] = ridx*TM + (0 ... TM);
|
||||
int32 ryc[TN] = ridy*TN + (0 ... TN);
|
||||
|
@@ -180,6 +180,30 @@ size_t shift::num_flops() const {
|
||||
return 2.*M_*N_*K_;
|
||||
}
|
||||
|
||||
bool shift::AT() const
|
||||
{ return AT_; }
|
||||
|
||||
bool shift::BT() const
|
||||
{ return BT_; }
|
||||
|
||||
size_t shift::M() const
|
||||
{ return M_; }
|
||||
|
||||
size_t shift::N() const
|
||||
{ return N_; }
|
||||
|
||||
size_t shift::K() const
|
||||
{ return K_; }
|
||||
|
||||
size_t shift::lda() const
|
||||
{ return AT_ ? K_ : M_; }
|
||||
|
||||
size_t shift::ldb() const
|
||||
{ return BT_ ? N_ : K_; }
|
||||
|
||||
size_t shift::ldc() const
|
||||
{ return M_; }
|
||||
|
||||
bool shift::operator <(const base& other) const{
|
||||
auto *y = dynamic_cast<const shift*>(&other);
|
||||
if(!y)
|
||||
@@ -265,10 +289,6 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(30, (int32_t)grid[2]);
|
||||
if(locks_)
|
||||
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
|
||||
if(op_ == FPROP || op_ == BPROP){
|
||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||
((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes);
|
||||
}
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -164,8 +164,8 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
auto mps = passes_0.tune.get_params(tt_module_0);
|
||||
// iterate over parameters
|
||||
tune_res_t best;
|
||||
std::mutex mutex;
|
||||
// update_best
|
||||
std::mutex mutex;
|
||||
auto update_best = [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
unsigned i = 0;
|
||||
@@ -211,9 +211,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
best.perf = perf;
|
||||
best.params = params;
|
||||
}
|
||||
// for(size_t i = 0; i < params.size(); i++)
|
||||
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user