[dnn/gemm]: fixed leading dimension in transposed variants
This commit is contained in:
@@ -8,12 +8,12 @@
|
||||
|
||||
|
||||
int main() {
|
||||
bool AT = false;
|
||||
bool AT = true;
|
||||
bool BT = false;
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// matrix multiplication parameters
|
||||
int32_t M = 1024, N = 1024, K = 1024;
|
||||
int32_t M = 64, N = 128, K = 128;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
@@ -33,14 +33,14 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp16", "fp16", 4, 4);
|
||||
gemm.enqueue(stream, {da, db, dc}, true);
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// gemm.cpu_ref<float>(rc, ha, hb);
|
||||
// for(size_t i = 0; i < M*N; i++)
|
||||
// if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
// exit(EXIT_FAILURE);
|
||||
// }
|
||||
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
|
||||
gemm.enqueue(stream, {da, db, dc}, false);
|
||||
stream->read(dc, true, 0, hc);
|
||||
gemm.cpu_ref<float>(rc, ha, hb);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
Reference in New Issue
Block a user