[triton/dnn/conv] merged optimizations branch
- Added forward/backward support for strided convolution - Added support for bias - Added support for reduction splitting
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/gemm.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
|
||||
int main() {
|
||||
@@ -52,8 +53,8 @@ int main() {
|
||||
triton::dnn::gemm::set_arg(kernel, da, db, dc, M, N, K, dlocks, grid[0], grid[1]);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->synchronize();
|
||||
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
[&](){ stream->synchronize(); }, *context->device());
|
||||
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
[&](){ stream->synchronize(); }, context->device());
|
||||
return 2.*M*N*K / ts * 1e-3;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user