Refactor show_datasets

RenLiqiang 7 months ago
parent
commit
ceaacdd464
1 changed file with 21 additions and 31 deletions

models/ins_detect/test_datasets.py  (+21 −31)

@@ -454,46 +454,36 @@ def show_dataset():
 
     dataset = MaskRCNNDataset(dataset_path=r'\\192.168.50.222\share\rlq\datasets\bangcai2', transforms=transforms,
                               dataset_type='train')
-    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=utils.collate_fn)
-    imgs, targets = next(iter(dataloader))
-
-    mask = np.array(targets[0]['masks'][0])
-    masks=targets[0]['masks'].to(torch.bool)
-    boxes = targets[0]['boxes']
-    print(f'boxes:{boxes}')
-    # mask[mask == 255] = 1
-    # img = np.array(imgs[2].permute(1, 2, 0)) * 255
-    img=np.array(imgs[0])
+    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)
+    for imgs, targets in dataloader:
+        masks = targets[0]['masks']
+        boxes = targets[0]['boxes']
+        print(f'boxes:{boxes}')
+        # mask[mask == 255] = 1
+        # img = np.array(imgs[2].permute(1, 2, 0)) * 255
+        show_boxes_masks(imgs, boxes, masks)
+
+
+def show_boxes_masks(imgs, boxes, masks):
+    img = np.array(imgs[0])
     img = img.astype(np.uint8)
+    masks = masks.to(torch.bool)
     print(f'img shape:{img.shape}')
     print(f'img shape:{img.shape}')
-    print(f'mask:{mask.shape}')
-    # print(f'target:{targets}')
-    # print(f'imgs:{imgs[0]}')
-    # print(f'cv2 img shape:{np.array(imgs[0]).shape}')
-    # cv2.imshow('cv2 img',img)
-    # cv2.imshow('cv2 mask', mask)
-    # plt.imshow('mask',mask)
-    mask_3channel = cv2.merge([np.zeros_like(mask), np.zeros_like(mask), mask])
-    # cv2.imshow('mask_3channel',mask_3channel)
-    print(f'mask_3channel:{mask_3channel.shape}')
-    # masked_image = cv2.addWeighted(img, 1, mask_3channel, 0.6, 0)
-    # cv2.imshow('cv2 mask img', masked_image)
-    img_tensor=torch.tensor(imgs[0],dtype=torch.uint8)
-
-    boxed_img=draw_bounding_boxes(img_tensor,boxes).permute(1, 2, 0).contiguous()
-    masked_img=draw_segmentation_masks(img_tensor,masks).permute(1, 2, 0).contiguous()
-
+    # print(f'mask:{mask.shape}')
+    # mask_3channel = cv2.merge([np.zeros_like(masks[0]), np.zeros_like(masks[0]), masks[0]])
+    # print(f'mask_3channel:{mask_3channel.shape}')
+    img_tensor = torch.tensor(imgs[0], dtype=torch.uint8)
+    boxed_img = draw_bounding_boxes(img_tensor, boxes).permute(1, 2, 0).contiguous()
+    masked_img = draw_segmentation_masks(img_tensor, masks).permute(1, 2, 0).contiguous()
     plt.imshow(imgs[0].permute(1, 2, 0))
     # plt.imshow(mask, cmap='Reds', alpha=0.5)
     plt.imshow(masked_img, cmap='Reds', alpha=0.3)
-    plt.imshow(boxed_img,cmap='Greens', alpha=0.5)
-    # drawn_boxes = draw_bounding_boxes((imgs[2] * 255).to(torch.uint8), boxes, colors="red", width=5)
-    # plt.imshow(drawn_boxes.permute(1, 2, 0))
-    # show(drawn_boxes)
+    plt.imshow(boxed_img, cmap='Greens', alpha=0.5)
     plt.show()
     cv2.waitKey(0)
 
+
 def test_cluster(img_path):
     test_img = PIL.Image.open(img_path)
     w, h = test_img.size
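
For reference, the overlay logic that this commit factors out into show_boxes_masks() can be exercised in isolation with torchvision alone. The sketch below is not part of the repository: it replaces the MaskRCNNDataset/DataLoader pipeline with synthetic boxes and masks, and the name demo_boxes_masks is illustrative. It only relies on the documented behaviour of draw_bounding_boxes and draw_segmentation_masks (uint8 CHW tensor in, uint8 CHW tensor out).

import torch
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks


def demo_boxes_masks():
    # Synthetic stand-ins for one sample from the dataloader: a uint8 CHW image,
    # N boxes as (xmin, ymin, xmax, ymax), and N boolean masks with the same H x W.
    img = torch.randint(0, 256, (3, 240, 320), dtype=torch.uint8)
    boxes = torch.tensor([[30.0, 40.0, 120.0, 160.0],
                          [150.0, 60.0, 280.0, 200.0]])
    masks = torch.zeros((2, 240, 320), dtype=torch.bool)
    masks[0, 40:160, 30:120] = True
    masks[1, 60:200, 150:280] = True

    # Both utilities take a uint8 CHW tensor and return a new uint8 CHW tensor
    # with the overlay burned in, so the two calls can be chained.
    overlaid = draw_bounding_boxes(img, boxes, colors="green", width=3)
    overlaid = draw_segmentation_masks(overlaid, masks, alpha=0.4, colors="red")

    # Convert to HWC for matplotlib and show a single composite image.
    plt.imshow(overlaid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    demo_boxes_masks()

Chaining the two draw calls and displaying one composite image is a possible simplification of the three stacked semi-transparent plt.imshow() calls in the refactored function, which tend to wash each other out; the diff's approach is kept as-is above.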