[dnn] added Triton-C derivative computations in conv
This commit is contained in:
@@ -10,7 +10,7 @@ int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 24, W = 240;
|
||||
@@ -77,7 +77,7 @@ int main() {
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
Reference in New Issue
Block a user