diff --git a/examples/isaac-tools.cpp b/examples/isaac-tools.cpp index 48c1821e1..e1eac0dd1 100755 --- a/examples/isaac-tools.cpp +++ b/examples/isaac-tools.cpp @@ -260,7 +260,7 @@ void benchmark_gemm(Metric const & metric, sc::driver::Context& ctx, sc::driver: sc::driver::Buffer B(ctx, K*N*dtsize); std::vector times; - times.push_back(bench([&](){ sc::GEMM(device, stream, dtype, dtype, AT, BT, M, N, K, 0, lda, 0, ldb, 0, ldc, alpha, A, B, beta, C, 1., 1., 1., NULL, (sc::templates::GEMM*)generator); }, [&](){ stream.synchronize(); }, device)); + times.push_back(bench([&](){ sc::GEMM(device, stream, dtype, dtype, AT, BT, M, N, K, 0, lda, 0, ldb, 0, ldc, alpha, A, B, beta, C, 1., 1., 1., NULL, (sc::templates::GEMM*)generator, 10); }, [&](){ stream.synchronize(); }, device)); if(sc::driver::dispatch::cublasinit()){ cublasGemmAlgo_t fastest; sc::driver::cublasGemm(dtype, stream, cuAT, cuBT, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, &fastest); @@ -294,7 +294,7 @@ void benchmark_conv(Metric const & metric, sc::driver::Context& ctx, sc::driver: sc::driver::Buffer F(ctx, K*C/vect_c*T*R*S*sc::size_of(in_dtype)); std::vector times; - times.push_back(bench([&](){ sc::CONV(device, stream, in_dtype, out_dtype, N, K, M, P, Q, C, T, R, S, D, H, W, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, upsample_d, upsample_h, upsample_w, I, F, &O, 1, NULL, activation, 0., 1., 1., {1.}, 1., sc::NoResidual, Zk, crop_z_m0, crop_z_m1, crop_z_p0, crop_z_p1, crop_z_q0, crop_z_q1, NULL, (sc::templates::Conv*)generator); }, [&](){ stream.synchronize(); }, device)); + times.push_back(bench([&](){ sc::CONV(device, stream, in_dtype, out_dtype, N, K, M, P, Q, C, T, R, S, D, H, W, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, upsample_d, upsample_h, upsample_w, I, F, &O, 1, NULL, activation, 0., 1., 1., {1.}, 1., sc::NoResidual, Zk, crop_z_m0, crop_z_m1, crop_z_p0, crop_z_p1, crop_z_q0, crop_z_q1, NULL, (sc::templates::Conv*)generator, 10); }, [&](){ stream.synchronize(); }, device)); // if(sc::driver::dispatch::cudnninit()) // times.push_back(bench([&](){ sc::driver::cudnnConv(out_dtype, stream, D, H, W, N, K, M, P, Q, C, T, R, S, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, alpha, I, F, beta, O); }, [&](){ stream.synchronize(); }, device)); print_results(times, {str(N), str(K), str(M), str(P), str(Q), str(C), str(T), str(R), str(S)}, metric.cmp(), [&](double tsec){ return metric.conv(M, P, Q, K, N, C, T, R, S, tsec);});