diff --git a/baselines/common/tile_images.py b/baselines/common/tile_images.py new file mode 100644 index 0000000..929da89 --- /dev/null +++ b/baselines/common/tile_images.py @@ -0,0 +1,23 @@ +import numpy as np + +def tile_images(img_nhwc): + """ + Tile N images into one big PxQ image + (P,Q) are chosen to be as close as possible, and if N + is square, then P=Q. + + input: img_nhwc, list or array of images, ndim=4 once turned into array + n = batch index, h = height, w = width, c = channel + returns: + bigim_HWc, ndarray with ndim=3 + """ + img_nhwc = np.asarray(img_nhwc) + N, h, w, c = img_nhwc.shape + H = int(np.ceil(np.sqrt(N))) + W = int(np.ceil(float(N)/H)) + img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) + img_HWhwc = img_nhwc.reshape(H, W, h, w, c) + img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) + img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) + return img_Hh_Ww_c +