postprocess.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import numpy as np
  2. def pline(x1, y1, x2, y2, x, y):
  3. px = x2 - x1
  4. py = y2 - y1
  5. dd = px * px + py * py
  6. u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
  7. dx = x1 + u * px - x
  8. dy = y1 + u * py - y
  9. return dx * dx + dy * dy
  10. def psegment(x1, y1, x2, y2, x, y):
  11. px = x2 - x1
  12. py = y2 - y1
  13. dd = px * px + py * py
  14. u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
  15. dx = x1 + u * px - x
  16. dy = y1 + u * py - y
  17. return dx * dx + dy * dy
  18. def plambda(x1, y1, x2, y2, x, y):
  19. px = x2 - x1
  20. py = y2 - y1
  21. dd = px * px + py * py
  22. return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
  23. def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
  24. nlines, nscores = [], []
  25. for (p, q), score in zip(lines, scores):
  26. start, end = 0, 1
  27. for a, b in nlines:
  28. if (
  29. min(
  30. max(pline(*p, *q, *a), pline(*p, *q, *b)),
  31. max(pline(*a, *b, *p), pline(*a, *b, *q)),
  32. )
  33. > threshold ** 2
  34. ):
  35. continue
  36. lambda_a = plambda(*p, *q, *a)
  37. lambda_b = plambda(*p, *q, *b)
  38. if lambda_a > lambda_b:
  39. lambda_a, lambda_b = lambda_b, lambda_a
  40. lambda_a -= tol
  41. lambda_b += tol
  42. # case 1: skip (if not do_clip)
  43. if start < lambda_a and lambda_b < end:
  44. continue
  45. # not intersect
  46. if lambda_b < start or lambda_a > end:
  47. continue
  48. # cover
  49. if lambda_a <= start and end <= lambda_b:
  50. start = 10
  51. break
  52. # case 2 & 3:
  53. if lambda_a <= start and start <= lambda_b:
  54. start = lambda_b
  55. if lambda_a <= end and end <= lambda_b:
  56. end = lambda_a
  57. if start >= end:
  58. break
  59. if start >= end:
  60. continue
  61. nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
  62. nscores.append(score)
  63. return np.array(nlines), np.array(nscores)
  64. def postprocess_keypoint(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
  65. nlines, nscores = [], []
  66. for (p, q), score in zip(lines, scores):
  67. start, end = 0, 1
  68. for a, b in nlines:
  69. if (
  70. min(
  71. max(pline(*p, *q, *a), pline(*p, *q, *b)),
  72. max(pline(*a, *b, *p), pline(*a, *b, *q)),
  73. )
  74. > threshold ** 2
  75. ):
  76. continue
  77. lambda_a = plambda(*p, *q, *a)
  78. lambda_b = plambda(*p, *q, *b)
  79. if lambda_a > lambda_b:
  80. lambda_a, lambda_b = lambda_b, lambda_a
  81. lambda_a -= tol
  82. lambda_b += tol
  83. # case 1: skip (if not do_clip)
  84. if start < lambda_a and lambda_b < end:
  85. continue
  86. # not intersect
  87. if lambda_b < start or lambda_a > end:
  88. continue
  89. # cover
  90. if lambda_a <= start and end <= lambda_b:
  91. start = 10
  92. break
  93. # case 2 & 3:
  94. if lambda_a <= start and start <= lambda_b:
  95. start = lambda_b
  96. if lambda_a <= end and end <= lambda_b:
  97. end = lambda_a
  98. if start >= end:
  99. break
  100. if start >= end:
  101. continue
  102. nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
  103. nscores.append(min(score[0],score[1]))
  104. return np.array(nlines), np.array(nscores)