preparing the field for tensor cores transposes

This commit is contained in:
Philippe Tillet
2019-07-17 13:20:33 -07:00
parent d2e116d057
commit bfa39b8992
9 changed files with 67 additions and 29 deletions

View File

@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// template
triton::dnn::gemm dot(M, N, K, false, true, "fp16", "fp16", 4, 4);
triton::dnn::gemm dot(M, N, K, false, false, "fp16", "fp16", 4, 4);
dot.enqueue(stream, {&da, &db, &dc});
}

View File

@@ -23,7 +23,7 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb).T
hresult = np.dot(ha.T, hb.T).T
dif = np.abs(result - hresult)
print(hresult)
print(result)