basic split-k across warps working for GEMM
This commit is contained in:
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
|
||||
triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
|
||||
// template
|
||||
triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
|
||||
triton::dnn::dot dot(M, N, K, false, true, "half", "half", 8, 8, 8);
|
||||
dot.enqueue(stream, {&da, &db, &dc});
|
||||
}
|
||||
|
||||
|
@@ -23,7 +23,7 @@ def run_dot():
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
# Test
|
||||
hresult = np.dot(ha.T, hb.T).T
|
||||
hresult = np.dot(ha.T, hb).T
|
||||
dif = np.abs(result - hresult)
|
||||
np.savetxt('dif.dat', dif, '%2.4f')
|
||||
print(hresult)
|
||||
@@ -131,6 +131,6 @@ def run_batchnorm():
|
||||
print(np.max(np.abs(dg_t - dg_n)))
|
||||
print(np.max(np.abs(db_t - db_n)))
|
||||
|
||||
#run_dot()
|
||||
run_dot()
|
||||
#run_shift()
|
||||
run_batchnorm()
|
||||
#run_batchnorm()
|
||||
|
Reference in New Issue
Block a user