[dnn] added Triton-C derivative computations in conv

This commit is contained in:
Philippe Tillet
2019-05-13 00:38:26 -04:00
parent f6fe9492e4
commit 5941501f70
5 changed files with 246 additions and 281 deletions

View File

@@ -10,7 +10,7 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240;
@@ -77,7 +77,7 @@ int main() {
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
}
}
std::cout << "Pass!" << std::endl;
}