fixed simple FP16 test

Author: Philippe Tillet
Date:   2019-08-06 17:14:16 -07:00
Commit: 0e201e18ff (parent: 6c39cdbace)

5 changed files with 33 additions and 21 deletions


@@ -49,8 +49,8 @@ class DotOp : public OpKernel {
   triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
   triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
   // template
-  triton::dnn::dot dot(M, N, K, false, true, "half", "half", 8, 8, 8);
-  dot.enqueue(stream, {&da, &db, &dc});
+  triton::dnn::dot dot(M, N, K, false, false, "half", "half", "float", 8, 8, 8);
+  dot.enqueue(stream, {&da, &db, &dc}, triton::dnn::autotuning_t::NO_TUNING);
 }
 private:

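Reading the hunk above: the template construction flips (AT, BT) from (false, true) to (false, false), adds "float" alongside the two "half" input types, and pins the launch to NO_TUNING. Taking the extra "float" argument as an accumulator type is an assumption on my part; a minimal numpy sketch of that fp16-in/fp32-accumulate reading (sizes and layouts hypothetical, not taken from the test):

    import numpy as np

    # Hypothetical sizes; the real M, N, K come from the op's inputs.
    M, N, K = 64, 64, 64
    ha = np.random.rand(K, M).astype(np.float16)  # fp16 operand A
    hb = np.random.rand(N, K).astype(np.float16)  # fp16 operand B

    # "half", "half", "float" read as: keep the inputs in fp16 but
    # upcast before the K-reduction, so partial sums accumulate at
    # fp32 precision instead of rounding every partial sum to fp16.
    hc = np.dot(ha.astype(np.float32).T, hb.astype(np.float32).T).T
    print(hc.dtype)  # float32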

@@ -23,7 +23,7 @@ def run_dot():
     result = sess.run([c], feed_dict = {a: ha,
                                         b: hb})[0]
     # Test
-    hresult = np.dot(ha.T, hb).T
+    hresult = np.dot(ha.T, hb.T).T
     dif = np.abs(result - hresult)
     np.savetxt('dif.dat', dif, '%2.4f')
     print(hresult)
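The reference fix mirrors the flag change in the kernel hunk: with BT flipped to false, hb presumably enters the product transposed, so the numpy oracle picks up a matching hb.T. The expression itself is just hb @ ha written through the identity (XY)^T = Y^T X^T; a quick sanity sketch with hypothetical shapes:

    import numpy as np

    # (ha.T @ hb.T).T == hb @ ha, since (X @ Y).T == Y.T @ X.T.
    ha = np.random.rand(8, 4)   # hypothetical shape
    hb = np.random.rand(6, 8)   # hypothetical shape
    lhs = np.dot(ha.T, hb.T).T  # the reference computed in the test
    rhs = np.dot(hb, ha)        # algebraically the same product
    assert np.allclose(lhs, rhs)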