Browse Source

debug keypoint

RenLiqiang 5 months ago
parent
commit
f9f89a66bc

+ 1 - 0
libs/vision_libs/models/detection/rpn.py

@@ -393,6 +393,7 @@ class RegionProposalNetwork(torch.nn.Module):
 
 
         # 过滤出在 boxes 内的线段
         # 过滤出在 boxes 内的线段
         lines =self.filter_lines_inside_boxes(lines_all, boxes)
         lines =self.filter_lines_inside_boxes(lines_all, boxes)
+        print(f'filter_lines:{lines}')
 
 
 
 
         losses = {}
         losses = {}

+ 1 - 1
models/keypoint/kepointrcnn.py

@@ -278,7 +278,7 @@ def keypointrcnn_resnet18_fpn(
 if __name__ == '__main__':
 if __name__ == '__main__':
     # ins_model = MaskRCNNModel(num_classes=5)
     # ins_model = MaskRCNNModel(num_classes=5)
     keypoint_model = KeypointRCNNModel(num_keypoints=2)
     keypoint_model = KeypointRCNNModel(num_keypoints=2)
-    wts_path='./train_results/20241227_231659/weights/best.pt'
+    # wts_path='./train_results/20241227_231659/weights/best.pt'
 
 
 
 
     # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
     # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'

+ 51 - 46
models/keypoint/keypoint_dataset.py

@@ -1,5 +1,6 @@
 from torch.utils.data.dataset import T_co
 from torch.utils.data.dataset import T_co
 
 
+from libs.vision_libs.utils import draw_keypoints
 from models.base.base_dataset import BaseDataset
 from models.base.base_dataset import BaseDataset
 
 
 import glob
 import glob
@@ -22,7 +23,7 @@ from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.dataloader import default_collate
 
 
 import matplotlib.pyplot as plt
 import matplotlib.pyplot as plt
-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
 
 def validate_keypoints(keypoints, image_width, image_height):
 def validate_keypoints(keypoints, image_width, image_height):
     for kp in keypoints:
     for kp in keypoints:
@@ -121,10 +122,12 @@ class KeypointDataset(BaseDataset):
         # return wire_labels, target
         # return wire_labels, target
         target["wires"] = wire_labels
         target["wires"] = wire_labels
 
 
-        target["labels"] = torch.stack(labels)
+        # target["labels"] = torch.stack(labels)
+
         # print(f'labels:{target["labels"]}')
         # print(f'labels:{target["labels"]}')
         # target["boxes"] = line_boxes(target)
         # target["boxes"] = line_boxes(target)
         target["boxes"], keypoints = line_boxes(target)
         target["boxes"], keypoints = line_boxes(target)
+        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
         # keypoints=keypoints/512
         # keypoints=keypoints/512
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
 
@@ -147,59 +150,61 @@ class KeypointDataset(BaseDataset):
         sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
         sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
         sm.set_array([])
         sm.set_array([])
 
 
-        def imshow(im):
-            plt.close()
-            plt.tight_layout()
-            plt.imshow(im)
-            plt.colorbar(sm, fraction=0.046)
-            plt.xlim([0, im.shape[0]])
-            plt.ylim([im.shape[0], 0])
-
-        def draw_vecl(lines, sline, juncs, junts, fn=None):
-            img_path = os.path.join(self.img_path, self.imgs[idx])
-            imshow(io.imread(img_path))
-            if len(lines) > 0 and not (lines[0] == 0).all():
-                for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                    if i > 0 and (lines[i] == lines[0]).all():
-                        break
-                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
-            if not (juncs[0] == 0).all():
-                for i, j in enumerate(juncs):
-                    if i > 0 and (i == juncs[0]).all():
-                        break
-                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
-
-
-            img_path = os.path.join(self.img_path, self.imgs[idx])
-            img = PIL.Image.open(img_path).convert('RGB')
-            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
-                                              colors="yellow", width=1)
-            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
-            plt.show()
 
 
-            plt.show()
-            if fn != None:
-                plt.savefig(fn)
 
 
-        junc = target['wires']['junc_coords'].cpu().numpy() * 4
-        jtyp = target['wires']['jtyp'].cpu().numpy()
-        juncs = junc[jtyp == 0]
-        junts = junc[jtyp == 1]
 
 
-        lpre = target['wires']["lpre"].cpu().numpy() * 4
-        vecl_target = target['wires']["lpre_label"].cpu().numpy()
-        lpre = lpre[vecl_target == 1]
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+        keypoint_img=draw_keypoints(boxed_image,target['keypoints'],colors='red',width=3)
+        plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
+        plt.show()
+
+
 
 
-        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
-        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
 
 
 
     def show_img(self, img_path):
     def show_img(self, img_path):
         pass
         pass
 
 
-
+def line_boxes(target):
+    boxs = []
+    lpre = target['wires']["lpre"].cpu().numpy()
+    vecl_target = target['wires']["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+
+    keypoints = []
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+
+            keypoints.append([a[1], a[0]])
+            keypoints.append([b[1], b[0]])
+
+            if a[1] > b[1]:
+                ymax = a[1] + 1
+                ymin = b[1] - 1
+            else:
+                ymin = a[1] - 1
+                ymax = b[1] + 1
+            if a[0] > b[0]:
+                xmax = a[0] + 1
+                xmin = b[0] - 1
+            else:
+                xmin = a[0] - 1
+                xmax = b[0] + 1
+            boxs.append([ymin, xmin, ymax, xmax])
+
+    return torch.tensor(boxs), torch.tensor(keypoints)
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-    path=r"I:\datasets\wirenet_1000"
+    path=r"\\192.168.50.222/share/lm/Dataset_all"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
-    dataset.show(7)
+    dataset.show(10)

+ 1 - 1
models/keypoint/train.yaml

@@ -1,6 +1,6 @@
 
 
 
 
-dataset_path: I:\datasets\wirenet_1000
+dataset_path: \\192.168.50.222\share\rlq\datasets\250612
 
 
 #train parameters
 #train parameters
 num_classes: 2
 num_classes: 2

+ 4 - 40
models/keypoint/trainer.py

@@ -8,6 +8,7 @@ import torchvision
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 
 
+from libs.vision_libs.utils import draw_keypoints
 from models.wirenet.postprocess import postprocess_keypoint
 from models.wirenet.postprocess import postprocess_keypoint
 from torchvision.utils import draw_bounding_boxes
 from torchvision.utils import draw_bounding_boxes
 from torchvision import transforms
 from torchvision import transforms
@@ -33,7 +34,7 @@ def log_losses_to_tensorboard(writer, result, step):
 
 
 
 
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
-    model.train1()
+    model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"
     header = f"Epoch: [{epoch}]"
@@ -112,46 +113,9 @@ def show_line(img, pred, epoch, writer):
     # plt.show()
     # plt.show()
 
 
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    keypoint_img = draw_keypoints((img * 255).to(torch.uint8), pred['keypoints'], colors='red', width=3)
 
 
-    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    lines = pred["keypoints"].detach().cpu().numpy()
-    scores = pred["keypoints_scores"].detach().cpu().numpy()
-    # for i in range(1, len(lines)):
-    #     if (lines[i] == lines[0]).all():
-    #         lines = lines[:i]
-    #         scores = scores[:i]
-    #         break
-
-    # postprocess lines to remove overlapped lines
-    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
-    # print(f'nscores:{nscores}')
-
-    for i, t in enumerate([0.5]):
-        plt.gca().set_axis_off()
-        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-        plt.margins(0, 0)
-        for (a, b), s in zip(nlines, nscores):
-            if s < t:
-                continue
-            # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-            # plt.scatter(a[1], a[0], **PLTOPTS)
-            # plt.scatter(b[1], b[0], **PLTOPTS)
-            plt.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
-            plt.scatter(a[0], a[1], **PLTOPTS)
-            plt.scatter(b[0], b[1], **PLTOPTS)
-        plt.gca().xaxis.set_major_locator(plt.NullLocator())
-        plt.gca().yaxis.set_major_locator(plt.NullLocator())
-        plt.imshow(im.cpu())
-        plt.tight_layout()
-        fig = plt.gcf()
-        fig.canvas.draw()
-        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
-            fig.canvas.get_width_height()[::-1] + (3,))
-        plt.close()
-        img2 = transforms.ToTensor()(image_from_plot)
-
-        writer.add_image("output", img2, epoch)
+    writer.add_image("output", keypoint_img, epoch)