wireframe.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. #!/usr/bin/env python
  2. """Process Huang's wireframe dataset for L-CNN network
  3. Usage:
  4. dataset/wireframe.py <src> <dst>
  5. dataset/wireframe.py (-h | --help )
  6. Examples:
  7. python dataset/wireframe.py /datadir/wireframe data/wireframe
  8. Arguments:
  9. <src> Original data directory of Huang's wireframe dataset
  10. <dst> Directory of the output
  11. Options:
  12. -h --help Show this screen.
  13. """
  14. import os
  15. import sys
  16. import json
  17. from itertools import combinations
  18. import cv2
  19. import numpy as np
  20. import skimage.draw
  21. import matplotlib.pyplot as plt
  22. from docopt import docopt
  23. from scipy.ndimage import zoom
  24. try:
  25. sys.path.append(".")
  26. sys.path.append("..")
  27. from lcnn.utils import parmap
  28. except Exception:
  29. raise
  30. def inrange(v, shape):
  31. return 0 <= v[0] < shape[0] and 0 <= v[1] < shape[1]
  32. def to_int(x):
  33. return tuple(map(int, x))
  34. def save_heatmap(prefix, image, lines):
  35. im_rescale = (512, 512)
  36. heatmap_scale = (128, 128)
  37. fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
  38. jmap = np.zeros((1,) + heatmap_scale, dtype=np.float32)
  39. joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32)
  40. lmap = np.zeros(heatmap_scale, dtype=np.float32)
  41. lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
  42. lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
  43. lines = lines[:, :, ::-1]
  44. junc = []
  45. jids = {}
  46. def jid(jun):
  47. jun = tuple(jun[:2])
  48. if jun in jids:
  49. return jids[jun]
  50. jids[jun] = len(junc)
  51. junc.append(np.array(jun + (0,)))
  52. return len(junc) - 1
  53. lnid = []
  54. lpos, lneg = [], []
  55. for v0, v1 in lines:
  56. lnid.append((jid(v0), jid(v1)))
  57. lpos.append([junc[jid(v0)], junc[jid(v1)]])
  58. vint0, vint1 = to_int(v0), to_int(v1)
  59. jmap[0][vint0] = 1
  60. jmap[0][vint1] = 1
  61. rr, cc, value = skimage.draw.line_aa(*to_int(v0), *to_int(v1))
  62. lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
  63. for v in junc:
  64. vint = to_int(v[:2])
  65. joff[0, :, vint[0], vint[1]] = v[:2] - vint - 0.5
  66. llmap = zoom(lmap, [0.5, 0.5])
  67. lineset = set([frozenset(l) for l in lnid])
  68. for i0, i1 in combinations(range(len(junc)), 2):
  69. if frozenset([i0, i1]) not in lineset:
  70. v0, v1 = junc[i0], junc[i1]
  71. vint0, vint1 = to_int(v0[:2] / 2), to_int(v1[:2] / 2)
  72. rr, cc, value = skimage.draw.line_aa(*vint0, *vint1)
  73. lneg.append([v0, v1, i0, i1, np.average(np.minimum(value, llmap[rr, cc]))])
  74. # assert np.sum((v0 - v1) ** 2) > 0.01
  75. assert len(lneg) != 0
  76. lneg.sort(key=lambda l: -l[-1])
  77. junc = np.array(junc, dtype=np.float32)
  78. Lpos = np.array(lnid, dtype=np.int)
  79. Lneg = np.array([l[2:4] for l in lneg][:4000], dtype=np.int)
  80. lpos = np.array(lpos, dtype=np.float32)
  81. lneg = np.array([l[:2] for l in lneg[:2000]], dtype=np.float32)
  82. image = cv2.resize(image, im_rescale)
  83. # plt.subplot(131), plt.imshow(lmap)
  84. # plt.subplot(132), plt.imshow(image)
  85. # for i0, i1 in Lpos:
  86. # plt.scatter(junc[i0][1] * 4, junc[i0][0] * 4)
  87. # plt.scatter(junc[i1][1] * 4, junc[i1][0] * 4)
  88. # plt.plot([junc[i0][1] * 4, junc[i1][1] * 4], [junc[i0][0] * 4, junc[i1][0] * 4])
  89. # plt.subplot(133), plt.imshow(lmap)
  90. # for i0, i1 in Lneg[:150]:
  91. # plt.plot([junc[i0][1], junc[i1][1]], [junc[i0][0], junc[i1][0]])
  92. # plt.show()
  93. np.savez_compressed(
  94. f"{prefix}_label.npz",
  95. aspect_ratio=image.shape[1] / image.shape[0],
  96. jmap=jmap, # [J, H, W]
  97. joff=joff, # [J, 2, H, W]
  98. lmap=lmap, # [H, W]
  99. junc=junc, # [Na, 3]
  100. Lpos=Lpos, # [M, 2]
  101. Lneg=Lneg, # [M, 2]
  102. lpos=lpos, # [Np, 2, 3] (y, x, t) for the last dim
  103. lneg=lneg, # [Nn, 2, 3]
  104. )
  105. cv2.imwrite(f"{prefix}.png", image)
  106. # plt.imshow(jmap[0])
  107. # plt.savefig("/tmp/1jmap0.jpg")
  108. # plt.imshow(jmap[1])
  109. # plt.savefig("/tmp/2jmap1.jpg")
  110. # plt.imshow(lmap)
  111. # plt.savefig("/tmp/3lmap.jpg")
  112. # plt.imshow(Lmap[2])
  113. # plt.savefig("/tmp/4ymap.jpg")
  114. # plt.imshow(jwgt[0])
  115. # plt.savefig("/tmp/5jwgt.jpg")
  116. # plt.cla()
  117. # plt.imshow(jmap[0])
  118. # for i in range(8):
  119. # plt.quiver(
  120. # 8 * jmap[0] * cdir[i] * np.cos(2 * math.pi / 16 * i),
  121. # 8 * jmap[0] * cdir[i] * np.sin(2 * math.pi / 16 * i),
  122. # units="xy",
  123. # angles="xy",
  124. # scale_units="xy",
  125. # scale=1,
  126. # minlength=0.01,
  127. # width=0.1,
  128. # zorder=10,
  129. # color="w",
  130. # )
  131. # plt.savefig("/tmp/6cdir.jpg")
  132. # plt.cla()
  133. # plt.imshow(lmap)
  134. # plt.quiver(
  135. # 2 * lmap * np.cos(ldir),
  136. # 2 * lmap * np.sin(ldir),
  137. # units="xy",
  138. # angles="xy",
  139. # scale_units="xy",
  140. # scale=1,
  141. # minlength=0.01,
  142. # width=0.1,
  143. # zorder=10,
  144. # color="w",
  145. # )
  146. # plt.savefig("/tmp/7ldir.jpg")
  147. # plt.cla()
  148. # plt.imshow(jmap[1])
  149. # plt.quiver(
  150. # 8 * jmap[1] * np.cos(tdir),
  151. # 8 * jmap[1] * np.sin(tdir),
  152. # units="xy",
  153. # angles="xy",
  154. # scale_units="xy",
  155. # scale=1,
  156. # minlength=0.01,
  157. # width=0.1,
  158. # zorder=10,
  159. # color="w",
  160. # )
  161. # plt.savefig("/tmp/8tdir.jpg")
  162. def main():
  163. args = docopt(__doc__)
  164. data_root = args["<src>"]
  165. data_output = args["<dst>"]
  166. os.makedirs(data_output, exist_ok=True)
  167. for batch in ["train", "valid"]:
  168. anno_file = os.path.join(data_root, f"{batch}.json")
  169. with open(anno_file, "r") as f:
  170. dataset = json.load(f)
  171. def handle(data):
  172. im = cv2.imread(os.path.join(data_root, "images", data["filename"]))
  173. prefix = data["filename"].split(".")[0]
  174. lines = np.array(data["lines"]).reshape(-1, 2, 2)
  175. os.makedirs(os.path.join(data_output, batch), exist_ok=True)
  176. lines0 = lines.copy()
  177. lines1 = lines.copy()
  178. lines1[:, :, 0] = im.shape[1] - lines1[:, :, 0]
  179. lines2 = lines.copy()
  180. lines2[:, :, 1] = im.shape[0] - lines2[:, :, 1]
  181. lines3 = lines.copy()
  182. lines3[:, :, 0] = im.shape[1] - lines3[:, :, 0]
  183. lines3[:, :, 1] = im.shape[0] - lines3[:, :, 1]
  184. path = os.path.join(data_output, batch, prefix)
  185. save_heatmap(f"{path}_0", im[::, ::], lines0)
  186. if batch != "valid":
  187. save_heatmap(f"{path}_1", im[::, ::-1], lines1)
  188. save_heatmap(f"{path}_2", im[::-1, ::], lines2)
  189. save_heatmap(f"{path}_3", im[::-1, ::-1], lines3)
  190. print("Finishing", os.path.join(data_output, batch, prefix))
  191. parmap(handle, dataset, 16)
  192. if __name__ == "__main__":
  193. main()