[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for
TensorFlow Graph Execution
This commit is contained in:
@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
|
||||
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
|
||||
E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
|
||||
|
||||
a = tf.placeholder(tf.float16, a_shape, name="a")
|
||||
b = tf.placeholder(tf.float16, b_shape, name="b")
|
||||
e = tf.placeholder(tf.float16, c_shape, name="e")
|
||||
feed_dict = { a: A.astype(np.float16),
|
||||
b: B.astype(np.float16),
|
||||
a = tf.placeholder(tf.float32, a_shape, name="a")
|
||||
b = tf.placeholder(tf.float32, b_shape, name="b")
|
||||
e = tf.placeholder(tf.float32, c_shape, name="e")
|
||||
feed_dict = { a: A.astype(np.float32),
|
||||
b: B.astype(np.float32),
|
||||
e: E }
|
||||
|
||||
c = triton.ops.einsum(einsum, a, b, bench=bench)
|
||||
|
Reference in New Issue
Block a user