[TEST][DOT] There seems to be a bug in casting tiles before ternary.

Reverting for now
This commit is contained in:
Philippe Tillet
2019-10-25 17:00:53 -04:00
parent b615af2e7e
commit 8bd87fa19d
4 changed files with 8 additions and 6 deletions

View File

@@ -34,7 +34,7 @@ int main() {
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c << std::flush;
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}

View File

@@ -147,9 +147,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = static_cast<T>((float)rand()/RAND_MAX);
ha[i] = 1;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = static_cast<T>((float)rand()/RAND_MAX);
hb[i] = 1;
// copy buffer
stream->write(&*da, true, 0, ha);
stream->write(&*db, true, 0, hb);

View File

@@ -25,8 +25,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
c += USEA @ USEB;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = ((bool[SHAPE_A]) k > TK) ? *pa : 0;
b = ((bool[SHAPE_B]) k > TK) ? *pb : 0;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
// epilogue
TYPE* pc[TM, TN] = C + rm[:, newaxis] + rn[newaxis, :] * ldc;

View File

@@ -16,7 +16,7 @@ int main() {
for(int nwarps: std::vector<int>{4})
for(bool AT: std::array<bool, 2>{false, true})
for(bool BT: std::array<bool, 2>{false, true}){
configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
}
// test
dtype_t dtype;