[dnn/blocksparse/dot] prototype version seems to pass basic test
This commit is contained in:
@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
NumericT alpha = 1;
|
||||
NumericT beta = 0;
|
||||
cublasGemmAlgo_t fastest;
|
||||
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
&alpha, da, shift.lda(),
|
||||
db, shift.ldb(), &beta,
|
||||
dc, shift.ldc(), &fastest);
|
||||
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
&alpha, da, shift.lda(),
|
||||
db, shift.ldb(),
|
||||
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
|
||||
// NumericT alpha = 1;
|
||||
// NumericT beta = 0;
|
||||
// cublasGemmAlgo_t fastest;
|
||||
// cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
// &alpha, da, shift.lda(),
|
||||
// db, shift.ldb(), &beta,
|
||||
// dc, shift.ldc(), &fastest);
|
||||
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
|
||||
// &alpha, da, shift.lda(),
|
||||
// db, shift.ldb(),
|
||||
// &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
|
||||
// result
|
||||
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
|
||||
perf_t result;
|
||||
result.cublas = tflops(cublas_ns);
|
||||
// result.cublas = tflops(cublas_ns);
|
||||
result.triton = tflops(triton_ns);
|
||||
delete da;
|
||||
delete db;
|
||||
@@ -133,8 +133,9 @@ int main() {
|
||||
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
|
||||
};
|
||||
for(config_t c: resnet18){
|
||||
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
|
||||
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
|
||||
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
|
||||
}
|
||||
}
|
||||
|
||||
// initialize default compute device
|
||||
|
Reference in New Issue
Block a user