wireframe.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. #!/usr/bin/env python
  2. """Process Huang's wireframe dataset for L-CNN network
  3. Usage:
  4. dataset/wireframe.py <src> <dst>
  5. dataset/wireframe.py (-h | --help )
  6. Examples:
  7. python dataset/wireframe.py /datadir/wireframe data/wireframe
  8. Arguments:
  9. <src> Original data directory of Huang's wireframe dataset
  10. <dst> Directory of the output
  11. Options:
  12. -h --help Show this screen.
  13. """
  14. import os
  15. import sys
  16. import json
  17. from itertools import combinations
  18. import cv2
  19. import numpy as np
  20. import skimage.draw
  21. import matplotlib.pyplot as plt
  22. from docopt import docopt
  23. from scipy.ndimage import zoom
  # Make the project's ``lcnn`` package importable. The entries are relative to
  # the current working directory, so ``lcnn`` is found when the script is
  # launched from the directory that contains it (e.g. the documented
  # ``python dataset/wireframe.py ...``) or from a direct subdirectory of it.
  # The former ``try: ... except Exception: raise`` wrapper was a no-op and
  # has been dropped; an ImportError still propagates unchanged.
  sys.path.append(".")
  sys.path.append("..")
  from lcnn.utils import parmap
  def inrange(v, shape):
      """Tell whether the 2-D index ``v`` falls inside a grid of size ``shape``.

      Only the first two entries of ``v`` and ``shape`` are inspected, and each
      bound is half-open: ``0 <= v[i] < shape[i]``. The second coordinate is
      not evaluated when the first is already out of range.
      """
      first_ok = 0 <= v[0] < shape[0]
      return first_ok and 0 <= v[1] < shape[1]
  def to_int(x):
      """Convert each element of the iterable ``x`` with ``int`` (truncation
      toward zero for floats) and return the results as a tuple, suitable for
      indexing numpy arrays."""
      return tuple(int(component) for component in x)
  def save_heatmap(prefix, image, lines):
      """Rasterize one annotated image into L-CNN training targets and save them.

      Two files are written:
        * ``{prefix}_label.npz`` -- junction/line heat maps on a 128x128 grid
          plus junction and positive/negative line tables (see the comments on
          the ``np.savez_compressed`` call for each array's shape);
        * ``{prefix}.png`` -- ``image`` resized to 512x512.

      Args:
          prefix: Output path without extension.
          image: Image array indexed as ``image[row, col, ...]``. ``main`` passes
              (possibly flipped) arrays obtained from ``cv2.imread``.
          lines: Array of shape ``(N, 2, 2)`` with the two endpoints of each
              annotated line as ``(x, y)`` pixel coordinates in ``image``.
              The caller's array is not modified.

      Raises:
          AssertionError: if no negative line can be formed, i.e. every pair
              of junctions is already an annotated line (or there are fewer
              than two junctions).
      """
      im_rescale = (512, 512)
      heatmap_scale = (128, 128)

      fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
      jmap = np.zeros((1,) + heatmap_scale, dtype=np.float32)
      joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32)
      lmap = np.zeros(heatmap_scale, dtype=np.float32)

      # Work on a float copy so the caller's array is left untouched and integer
      # annotations are not truncated by the in-place rescaling below.
      lines = np.array(lines, dtype=np.float64)
      # Scale into heat-map coordinates, clipping just below the upper edge so
      # truncation to int always yields a valid cell index.
      lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
      lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
      # Swap to (y, x) so coordinates index the maps as [row, col].
      lines = lines[:, :, ::-1]

      junc = []
      jids = {}

      def jid(jun):
          # Index of junction ``jun`` in ``junc``; registered on first sight
          # with junction type t = 0.
          jun = tuple(jun[:2])
          if jun in jids:
              return jids[jun]
          jids[jun] = len(junc)
          junc.append(np.array(jun + (0,)))
          return len(junc) - 1

      # Positive lines: mark both endpoints in the junction map and draw the
      # anti-aliased segment into the line map (keeping the max per pixel).
      lnid = []
      lpos, lneg = [], []
      for v0, v1 in lines:
          i0, i1 = jid(v0), jid(v1)
          lnid.append((i0, i1))
          lpos.append([junc[i0], junc[i1]])
          vint0, vint1 = to_int(v0), to_int(v1)
          jmap[0][vint0] = 1
          jmap[0][vint1] = 1
          rr, cc, value = skimage.draw.line_aa(*vint0, *vint1)
          lmap[rr, cc] = np.maximum(lmap[rr, cc], value)

      # Sub-pixel offset of each junction relative to the centre of its cell.
      for v in junc:
          vint = to_int(v[:2])
          joff[0, :, vint[0], vint[1]] = v[:2] - vint - 0.5

      # Negative lines: every junction pair that is not annotated, scored by the
      # mean of min(anti-alias weight, half-resolution line map) along the
      # rasterized segment.
      llmap = zoom(lmap, [0.5, 0.5])
      lineset = {frozenset(l) for l in lnid}
      for i0, i1 in combinations(range(len(junc)), 2):
          if frozenset([i0, i1]) not in lineset:
              v0, v1 = junc[i0], junc[i1]
              vint0, vint1 = to_int(v0[:2] / 2), to_int(v1[:2] / 2)
              rr, cc, value = skimage.draw.line_aa(*vint0, *vint1)
              lneg.append([v0, v1, i0, i1, np.average(np.minimum(value, llmap[rr, cc]))])
      assert len(lneg) != 0
      # Sort negatives by descending score.
      lneg.sort(key=lambda l: -l[-1])

      junc = np.array(junc, dtype=np.float32)
      # ``np.int`` was removed in NumPy 1.24; it was an alias of builtin ``int``.
      Lpos = np.array(lnid, dtype=int)
      Lneg = np.array([l[2:4] for l in lneg][:4000], dtype=int)
      lpos = np.array(lpos, dtype=np.float32)
      lneg = np.array([l[:2] for l in lneg[:2000]], dtype=np.float32)

      # Take the aspect ratio from the source image: after the fixed 512x512
      # resize below it would always be 1.0.
      aspect_ratio = image.shape[1] / image.shape[0]
      image = cv2.resize(image, im_rescale)

      # For junc, lpos, and lneg, which store junction coordinates, the last
      # dimension is (y, x, t), where t is the junction type. In the wireframe
      # dataset t is always zero.
      np.savez_compressed(
          f"{prefix}_label.npz",
          aspect_ratio=aspect_ratio,
          jmap=jmap,  # [J, H, W] Junction heat map (J = 1)
          joff=joff,  # [J, 2, H, W] Junction offset within each pixel
          lmap=lmap,  # [H, W] Line heat map with anti-aliasing
          junc=junc,  # [Na, 3] Junction coordinates
          Lpos=Lpos,  # [M, 2] Positive lines as junction indices
          Lneg=Lneg,  # [<=4000, 2] Negative lines as junction indices
          lpos=lpos,  # [M, 2, 3] Positive lines as junction coordinates
          lneg=lneg,  # [<=2000, 2, 3] Negative lines as junction coordinates
      )
      cv2.imwrite(f"{prefix}.png", image)
  def main():
      """Convert Huang's wireframe annotations into L-CNN label files.

      Reads ``<src>/{train,valid}.json`` and the images in ``<src>/images``.
      For each image, writes a ``*_label.npz`` / ``*.png`` pair into
      ``<dst>/{train,valid}`` via ``save_heatmap``. The unflipped sample uses
      suffix ``_0``. Training images are also saved with horizontal (``_1``),
      vertical (``_2``), and combined (``_3``) flips.

      Raises:
          FileNotFoundError: (from a worker) if an image listed in the
              annotations cannot be read by ``cv2.imread``.
      """
      args = docopt(__doc__)
      data_root = args["<src>"]
      data_output = args["<dst>"]
      os.makedirs(data_output, exist_ok=True)

      for batch in ["train", "valid"]:
          anno_file = os.path.join(data_root, f"{batch}.json")
          with open(anno_file, "r") as f:
              dataset = json.load(f)
          # Create the split directory once rather than in every worker call.
          batch_dir = os.path.join(data_output, batch)
          os.makedirs(batch_dir, exist_ok=True)

          def handle(data):
              image_path = os.path.join(data_root, "images", data["filename"])
              im = cv2.imread(image_path)
              if im is None:
                  # cv2.imread reports unreadable/missing files by returning None.
                  raise FileNotFoundError(f"Cannot read image {image_path}")
              # splitext keeps dots that belong to the stem (e.g. "a.b.jpg").
              prefix = os.path.splitext(data["filename"])[0]
              # Float coordinates so flipping/rescaling keep sub-pixel precision.
              lines = np.array(data["lines"], dtype=np.float64).reshape(-1, 2, 2)
              height, width = im.shape[0], im.shape[1]

              lines0 = lines.copy()
              lines1 = lines.copy()
              lines1[:, :, 0] = width - lines1[:, :, 0]
              lines2 = lines.copy()
              lines2[:, :, 1] = height - lines2[:, :, 1]
              lines3 = lines.copy()
              lines3[:, :, 0] = width - lines3[:, :, 0]
              lines3[:, :, 1] = height - lines3[:, :, 1]

              path = os.path.join(batch_dir, prefix)
              save_heatmap(f"{path}_0", im, lines0)
              if batch != "valid":
                  save_heatmap(f"{path}_1", im[:, ::-1], lines1)
                  save_heatmap(f"{path}_2", im[::-1, :], lines2)
                  save_heatmap(f"{path}_3", im[::-1, ::-1], lines3)
              print("Finishing", path)

          # ``handle`` closes over ``batch``; parmap finishes before the loop advances.
          parmap(handle, dataset, 16)


  if __name__ == "__main__":
      main()