|
|
@@ -49,6 +49,48 @@ from torchvision.transforms import functional as F
|
|
|
import torch
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+def fit_circle(points):
|
|
|
+ """
|
|
|
+ Fit a circle to a set of points (at least 3).
|
|
|
+
|
|
|
+ Args:
|
|
|
+ points: torch.Tensor æ numpy array, shape (N, 2)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ center (cx, cy), radius r
|
|
|
+ """
|
|
|
+ # å¦ææ¯ torch.Tensorï¼å
转为 numpy
|
|
|
+ if isinstance(points, torch.Tensor):
|
|
|
+ if points.dim() == 3:
|
|
|
+ points = points[0] # 廿 batch 维度
|
|
|
+ points = points.detach().cpu().numpy()
|
|
|
+
|
|
|
+ if not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2):
|
|
|
+ raise ValueError(f"Expected points shape (N, 2), got {points.shape}")
|
|
|
+
|
|
|
+ x = points[:, 0].astype(float)
|
|
|
+ y = points[:, 1].astype(float)
|
|
|
+
|
|
|
+ # ç¡®ä¿ A æ¯äºç»´æ°ç»
|
|
|
+ A = np.column_stack((x, y, np.ones_like(x))) # ä½¿ç¨ column_stack ä»£æ¿ stack å¯è½æ´æ¸
æ°
|
|
|
+ b = -(x ** 2 + y ** 2)
|
|
|
+
|
|
|
+ try:
|
|
|
+ sol, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)
|
|
|
+ except np.linalg.LinAlgError as e:
|
|
|
+ print(f"Linear algebra error occurred: {e}")
|
|
|
+ raise ValueError("Could not fit circle to points.")
|
|
|
+
|
|
|
+ D, E, F = sol
|
|
|
+
|
|
|
+ cx = -D / 2.0
|
|
|
+ cy = -E / 2.0
|
|
|
+ r = np.sqrt(cx ** 2 + cy ** 2 - F)
|
|
|
+
|
|
|
+ return (cx, cy), r
|
|
|
+
|
|
|
# 由低到高蓝黄红
|
|
|
def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
|
|
|
"""
|
|
|
@@ -232,6 +274,42 @@ class Trainer(BaseTrainer):
|
|
|
# img_tensor = np.transpose(img_tensor)
|
|
|
self.writer.add_image('z-out-arc', arcs, global_step=epoch)
|
|
|
|
|
|
+ if 'circles' in result:
|
|
|
+ # points=result['circles']
|
|
|
+ # points=points.squeeze(1)
|
|
|
+ ppp=result['circles']
|
|
|
+ bbb=result['boxes']
|
|
|
+ print(f'boxes shape:{bbb.shape}')
|
|
|
+ print(f'ppp:{ppp.shape}')
|
|
|
+ points = result['circles']
|
|
|
+ points = points.squeeze(1)
|
|
|
+ print(f'points shape:{points.shape}')
|
|
|
+
|
|
|
+ circle_image = img.cpu().numpy().transpose((1, 2, 0)) # CHW -> HWC
|
|
|
+ circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
|
|
|
+
|
|
|
+
|
|
|
+ if isinstance(points, torch.Tensor):
|
|
|
+ points = points.cpu().numpy()
|
|
|
+
|
|
|
+ for point_set in points:
|
|
|
+ center, radius = fit_circle(point_set)
|
|
|
+ cx, cy = map(int, center)
|
|
|
+
|
|
|
+ circle_image = cv2.circle(circle_image, (cx, cy), int(radius), (0, 0, 255), 2)
|
|
|
+
|
|
|
+ for point in point_set:
|
|
|
+ px, py = map(int, point)
|
|
|
+ circle_image = cv2.circle(circle_image, (px, py), 3, (0, 255, 255), -1)
|
|
|
+
|
|
|
+ img_rgb = cv2.cvtColor(circle_image, cv2.COLOR_BGR2RGB)
|
|
|
+ img_tensor = img_rgb.transpose((2, 0, 1)) # HWC -> CHW
|
|
|
+ img_tensor = img_tensor / 255.0 # å½ä¸åå° [0, 1]
|
|
|
+
|
|
|
+
|
|
|
+ # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
|
|
|
+ self.writer.add_image('z-out-circle', img_tensor, global_step=epoch)
|
|
|
+
|
|
|
# cv2.imshow('arc', img_rgb)
|
|
|
# cv2.waitKey(1000000)
|
|
|
|