york.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. #!/usr/bin/env python3
  2. """Process YorkUrban dataset for L-CNN network
  3. Usage:
  4. dataset/york.py <src> <dst>
  5. dataset/york.py (-h | --help )
  6. Examples:
  7. python dataset/york.py /datadir/york data/york
  8. Arguments:
  9. <src> Original data directory of YorkUrban
  10. <dst> Directory of the output
  11. Options:
  12. -h --help Show this screen.
  13. """
  14. import os
  15. import sys
  16. import glob
  17. import json
  18. import os.path as osp
  19. from itertools import combinations
  20. import cv2
  21. import numpy as np
  22. import skimage.draw
  23. import matplotlib.pyplot as plt
  24. from docopt import docopt
  25. from scipy.io import loadmat
  26. from scipy.ndimage import zoom
  27. try:
  28. sys.path.append(".")
  29. sys.path.append("..")
  30. from lcnn.utils import parmap
  31. except Exception:
  32. raise
  33. def inrange(v, shape):
  34. return 0 <= v[0] < shape[0] and 0 <= v[1] < shape[1]
  35. def to_int(x):
  36. return tuple(map(int, x))
  37. def save_heatmap(prefix, image, lines):
  38. im_rescale = (512, 512)
  39. heatmap_scale = (128, 128)
  40. fy, fx = heatmap_scale[1] / image.shape[0], heatmap_scale[0] / image.shape[1]
  41. jmap = np.zeros((1,) + heatmap_scale, dtype=np.float32)
  42. joff = np.zeros((1, 2) + heatmap_scale, dtype=np.float32)
  43. lmap = np.zeros(heatmap_scale, dtype=np.float32)
  44. lines[:, :, 0] = np.clip(lines[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
  45. lines[:, :, 1] = np.clip(lines[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
  46. lines = lines[:, :, ::-1]
  47. junc = []
  48. jids = {}
  49. def jid(jun):
  50. jun = tuple(jun[:2])
  51. if jun in jids:
  52. return jids[jun]
  53. jids[jun] = len(junc)
  54. junc.append(np.array(jun + (0,)))
  55. return len(junc) - 1
  56. lnid = []
  57. lpos, lneg = [], []
  58. for v0, v1 in lines:
  59. lnid.append((jid(v0), jid(v1)))
  60. lpos.append([junc[jid(v0)], junc[jid(v1)]])
  61. vint0, vint1 = to_int(v0), to_int(v1)
  62. jmap[0][vint0] = 1
  63. jmap[0][vint1] = 1
  64. rr, cc, value = skimage.draw.line_aa(*to_int(v0), *to_int(v1))
  65. lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
  66. for v in junc:
  67. vint = to_int(v[:2])
  68. joff[0, :, vint[0], vint[1]] = v[:2] - vint - 0.5
  69. llmap = zoom(lmap, [0.5, 0.5])
  70. lineset = set([frozenset(l) for l in lnid])
  71. for i0, i1 in combinations(range(len(junc)), 2):
  72. if frozenset([i0, i1]) not in lineset:
  73. v0, v1 = junc[i0], junc[i1]
  74. vint0, vint1 = to_int(v0[:2] / 2), to_int(v1[:2] / 2)
  75. rr, cc, value = skimage.draw.line_aa(*vint0, *vint1)
  76. lneg.append([v0, v1, i0, i1, np.average(np.minimum(value, llmap[rr, cc]))])
  77. # assert np.sum((v0 - v1) ** 2) > 0.01
  78. assert len(lneg) != 0
  79. lneg.sort(key=lambda l: -l[-1])
  80. junc = np.array(junc, dtype=np.float32)
  81. Lpos = np.array(lnid, dtype=np.int)
  82. Lneg = np.array([l[2:4] for l in lneg][:4000], dtype=np.int)
  83. lpos = np.array(lpos, dtype=np.float32)
  84. lneg = np.array([l[:2] for l in lneg[:2000]], dtype=np.float32)
  85. image = cv2.resize(image, im_rescale)
  86. # plt.subplot(131), plt.imshow(lmap)
  87. # plt.subplot(132), plt.imshow(image)
  88. # for i0, i1 in Lpos:
  89. # plt.scatter(junc[i0][1] * 4, junc[i0][0] * 4)
  90. # plt.scatter(junc[i1][1] * 4, junc[i1][0] * 4)
  91. # plt.plot([junc[i0][1] * 4, junc[i1][1] * 4], [junc[i0][0] * 4, junc[i1][0] * 4])
  92. # plt.subplot(133), plt.imshow(lmap)
  93. # for i0, i1 in Lneg[:150]:
  94. # plt.plot([junc[i0][1], junc[i1][1]], [junc[i0][0], junc[i1][0]])
  95. # plt.show()
  96. np.savez_compressed(
  97. f"{prefix}_label.npz",
  98. aspect_ratio=image.shape[1] / image.shape[0],
  99. jmap=jmap, # [J, H, W]
  100. joff=joff, # [J, 2, H, W]
  101. lmap=lmap, # [H, W]
  102. junc=junc, # [Na, 3]
  103. Lpos=Lpos, # [M, 2]
  104. Lneg=Lneg, # [M, 2]
  105. lpos=lpos, # [Np, 2, 3] (y, x, t) for the last dim
  106. lneg=lneg, # [Nn, 2, 3]
  107. )
  108. cv2.imwrite(f"{prefix}.png", image)
  109. # plt.imshow(jmap[0])
  110. # plt.savefig("/tmp/1jmap0.jpg")
  111. # plt.imshow(jmap[1])
  112. # plt.savefig("/tmp/2jmap1.jpg")
  113. # plt.imshow(lmap)
  114. # plt.savefig("/tmp/3lmap.jpg")
  115. # plt.imshow(Lmap[2])
  116. # plt.savefig("/tmp/4ymap.jpg")
  117. # plt.imshow(jwgt[0])
  118. # plt.savefig("/tmp/5jwgt.jpg")
  119. # plt.cla()
  120. # plt.imshow(jmap[0])
  121. # for i in range(8):
  122. # plt.quiver(
  123. # 8 * jmap[0] * cdir[i] * np.cos(2 * math.pi / 16 * i),
  124. # 8 * jmap[0] * cdir[i] * np.sin(2 * math.pi / 16 * i),
  125. # units="xy",
  126. # angles="xy",
  127. # scale_units="xy",
  128. # scale=1,
  129. # minlength=0.01,
  130. # width=0.1,
  131. # zorder=10,
  132. # color="w",
  133. # )
  134. # plt.savefig("/tmp/6cdir.jpg")
  135. # plt.cla()
  136. # plt.imshow(lmap)
  137. # plt.quiver(
  138. # 2 * lmap * np.cos(ldir),
  139. # 2 * lmap * np.sin(ldir),
  140. # units="xy",
  141. # angles="xy",
  142. # scale_units="xy",
  143. # scale=1,
  144. # minlength=0.01,
  145. # width=0.1,
  146. # zorder=10,
  147. # color="w",
  148. # )
  149. # plt.savefig("/tmp/7ldir.jpg")
  150. # plt.cla()
  151. # plt.imshow(jmap[1])
  152. # plt.quiver(
  153. # 8 * jmap[1] * np.cos(tdir),
  154. # 8 * jmap[1] * np.sin(tdir),
  155. # units="xy",
  156. # angles="xy",
  157. # scale_units="xy",
  158. # scale=1,
  159. # minlength=0.01,
  160. # width=0.1,
  161. # zorder=10,
  162. # color="w",
  163. # )
  164. # plt.savefig("/tmp/8tdir.jpg")
  165. def main():
  166. args = docopt(__doc__)
  167. data_root = args["<src>"]
  168. data_output = args["<dst>"]
  169. os.makedirs(data_output, exist_ok=True)
  170. dataset = sorted(glob.glob(osp.join(data_root, "*/*.jpg")))
  171. def handle(iname):
  172. prefix = osp.split(iname)[1].replace(".jpg", "")
  173. im = cv2.imread(iname)
  174. mat = loadmat(iname.replace(".jpg", "LinesAndVP.mat"))
  175. lines = np.array(mat["lines"]).reshape(-1, 2, 2)
  176. path = osp.join(data_output, prefix)
  177. save_heatmap(f"{path}", im[::, ::], lines)
  178. print(f"Finishing {path}")
  179. parmap(handle, dataset)
  180. if __name__ == "__main__":
  181. main()