Ver código fonte

修改predict_demo.py,修正尺寸问题

xue50 7 meses atrás
pai
commit
d5a7ead8d6
1 arquivos alterados com 4 adições e 5 exclusões
  1. 4 5
      models/line_detect/predict.py

+ 4 - 5
models/line_detect/predict.py

@@ -28,7 +28,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 def box_line_(imgs, pred):  # 默认置信度
     im = imgs.permute(1, 2, 0).cpu().numpy()
-    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * np.array([2000, 2000])
+    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     scores = pred[-1]['wires']['score'].cpu().numpy()[0]
 
     # print(f'111:{len(lines)}')
@@ -114,8 +114,7 @@ def show_all(imgs, pred, threshold, save_path):
 
     boxes = pred[0]['boxes'].cpu().numpy()
     box_scores = pred[0]['scores'].cpu().numpy()
-    # lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
-    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * np.array([2000, 2000])
+    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     scores = pred[-1]['wires']['score'].cpu().numpy()[0]
 
     for i in range(1, len(lines)):
@@ -331,7 +330,7 @@ class Predict:
         im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
         # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
         if im.shape != (512, 512, 3):
-            im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+            im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
         img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
         t_end = time.time()
         print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
@@ -425,7 +424,7 @@ class Predict1:
         im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
         # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
         if im.shape != (512, 512, 3):
-            im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+            im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
         img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
         t_end = time.time()
         print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")