basic split-k across warps working for GEMM

This commit is contained in:
Philippe Tillet
2019-08-05 19:33:28 -07:00
parent 899b2b72e1
commit d62e581ab3
12 changed files with 99 additions and 63 deletions

View File

@@ -23,7 +23,7 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb.T).T
hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)
@@ -131,6 +131,6 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))
#run_dot()
run_dot()
#run_shift()
run_batchnorm()
#run_batchnorm()