[dnn/blocksparse/dot] prototype version seems to pass basic tests

This commit is contained in:
Philippe Tillet
2019-07-27 21:21:36 -07:00
parent 2a377bc8b1
commit 17cb2db356
18 changed files with 402 additions and 205 deletions

View File

@@ -67,23 +67,23 @@ perf_t do_bench(triton::driver::stream *stream,
stream->write(dc, true, 0, hc);
stream->synchronize();
// benchmark triton
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { shift.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
NumericT alpha = 1;
NumericT beta = 0;
cublasGemmAlgo_t fastest;
cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(), &beta,
dc, shift.ldc(), &fastest);
double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
&alpha, da, shift.lda(),
db, shift.ldb(),
&beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// NumericT alpha = 1;
// NumericT beta = 0;
// cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
// &alpha, da, shift.lda(),
// db, shift.ldb(), &beta,
// dc, shift.ldc(), &fastest);
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, shift.AT(), shift.BT(), shift.M(), shift.N(), shift.K(),
// &alpha, da, shift.lda(),
// db, shift.ldb(),
// &beta, dc, shift.ldc(), nullptr, fastest); }, stream);
// result
auto tflops = [&](double nanosec) { return shift.num_flops() / nanosec * 1e-3; };
perf_t result;
result.cublas = tflops(cublas_ns);
// result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
delete da;
delete db;
@@ -133,8 +133,9 @@ int main() {
{128, 1024, 8, 8, 3, 3, 1024, 1, 1}
};
for(config_t c: resnet18){
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD})
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
}
}
// initialize default compute device