[CODEGEN][TRANSFORM] some bug-fixes for FP32 einsum
This commit is contained in:
@@ -168,7 +168,7 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
||||
# Benchmark
|
||||
torch.set_num_threads(1)
|
||||
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
- dtype = torch.cuda.HalfTensor
|
||||
+ dtype = torch.cuda.FloatTensor
|
||||
# initialize input tensors
|
||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||
|
Reference in New Issue
Block a user