xue50 · 3 months ago · commit 8c208a87b7

+ 0 - 6
libs/vision_libs/models/detection/generalized_rcnn.py

@@ -108,12 +108,6 @@ class GeneralizedRCNN(nn.Module):
         losses = {}
         losses.update(detector_losses)
         losses.update(proposal_losses)
-        # print(f'1{detector_losses.keys()}')
-        # print(f'2{proposal_losses.keys()}')
-        # print(f'123{losses.keys()}')
-        print(f'self.training:{self.training}')
-        print(f'123{losses}')
-
 
         if torch.jit.is_scripting():
             if not self._has_warned:

+ 0 - 3
models/line_detect/roi_heads.py

@@ -1007,8 +1007,6 @@ class RoIHeads(nn.Module):
         else:
             self.training = False
 
-        print(f'self.training:{self.training}')
-
         if targets is not None:
             for t in targets:
                 # TODO: https://github.com/pytorch/pytorch/issues/26731
@@ -1168,6 +1166,5 @@ class RoIHeads(nn.Module):
                     r["keypoints"] = keypoint_prob
                     r["keypoints_scores"] = kps
             losses.update(loss_keypoint)
-        print(f'losses111:{losses.keys()}')
 
         return result, losses

+ 75 - 77
train——line_rcnn.py

@@ -1,4 +1,3 @@
-
 # Training script based on LCNN    2025/2/7
 '''
 #!/usr/bin/env python3
@@ -161,6 +160,8 @@ if __name__ == "__main__":
     main()
 '''
 
+
+# 2025/2/9
 import os
 from typing import Optional, Any
 
@@ -178,9 +179,14 @@ import matplotlib as mpl
 from skimage import io
 
 from models.line_detect.line_rcnn import linercnn_resnet50_fpn
+from torchvision.utils import draw_bounding_boxes
+from models.wirenet.postprocess import postprocess
+from torchvision import transforms
+from collections import OrderedDict
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+
 def _loss(losses):
     total_loss = 0
     for i in losses.keys():
@@ -215,51 +221,50 @@ def imshow(im):
     plt.ylim([im.shape[0], 0])
 
 
-def _plot_samples(self, i, index, result, targets, prefix):
-    fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
-    img = io.imread(fn)
-    imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
-
-    def draw_vecl(lines, sline, juncs, junts, fn):
-        imshow(img)
-        if len(lines) > 0 and not (lines[0] == 0).all():
-            for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                if i > 0 and (lines[i] == lines[0]).all():
-                    break
-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
-        if not (juncs[0] == 0).all():
-            for i, j in enumerate(juncs):
-                if i > 0 and (i == juncs[0]).all():
-                    break
-                plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
-        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
-            for i, j in enumerate(junts):
-                if i > 0 and (i == junts[0]).all():
-                    break
-                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
-        plt.savefig(fn), plt.close()
-
-    junc = targets[i]["junc"].cpu().numpy() * 4
-    jtyp = targets[i]["jtyp"].cpu().numpy()
-    juncs = junc[jtyp == 0]
-    junts = junc[jtyp == 1]
-    rjuncs = result["juncs"][i].cpu().numpy() * 4
-    rjunts = None
-    if "junts" in result:
-        rjunts = result["junts"][i].cpu().numpy() * 4
-
-    lpre = targets[i]["lpre"].cpu().numpy() * 4
-    vecl_target = targets[i]["lpre_label"].cpu().numpy()
-    vecl_result = result["lines"][i].cpu().numpy() * 4
-    score = result["score"][i].cpu().numpy()
-    lpre = lpre[vecl_target == 1]
-
-    draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
-    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
-
-    img = cv2.imread(f"{prefix}_vecl_a.jpg")
-    img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
-    self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred[1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.8]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
 
 
 if __name__ == '__main__':
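
For reference, a minimal usage sketch for the new show_line helper, run outside the epoch loop. It assumes the model can be built with default arguments and restored from a checkpoint; the paths 'checkpoint.pth', 'sample.png' and 'logs/line_viz' are placeholders, not taken from this commit.

    import torch
    from skimage import io
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import transforms

    from models.line_detect.line_rcnn import linercnn_resnet50_fpn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Assumption: default-argument construction and a user-supplied checkpoint.
    model = linercnn_resnet50_fpn().to(device)
    model.load_state_dict(torch.load('checkpoint.pth', map_location=device))
    model.eval()

    writer = SummaryWriter('logs/line_viz')
    img = transforms.ToTensor()(io.imread('sample.png'))  # CxHxW float in [0, 1]

    with torch.no_grad():
        # Same list-of-images calling convention as the validation loop below.
        pred = model([img.to(device)])

    show_line(img, pred, epoch=0, writer=writer)  # show_line is defined above
    writer.close()
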
@@ -292,7 +297,6 @@ if __name__ == '__main__':
     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
 
-
     def move_to_device(data, device):
         if isinstance(data, (list, tuple)):
             return type(data)(move_to_device(item, device) for item in data)
@@ -304,19 +308,30 @@ if __name__ == '__main__':
             return data  # leave non-tensor data unchanged
 
 
+    # def writer_loss(writer, losses, epoch):
+    #     try:
+    #         for key, value in losses.items():
+    #             if key == 'loss_wirepoint':
+    #                 for subdict in losses['loss_wirepoint']['losses']:
+    #                     for subkey, subvalue in subdict.items():
+    #                         writer.add_scalar(f'loss_wirepoint/{subkey}',
+    #                                           subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+    #                                           epoch)
+    #             elif isinstance(value, torch.Tensor):
+    #                 writer.add_scalar(key, value.item(), epoch)
+    #     except Exception as e:
+    #         print(f"TensorBoard logging error: {e}")
     def writer_loss(writer, losses, epoch):
         try:
             for key, value in losses.items():
                 if key == 'loss_wirepoint':
-                    # ?? wirepoint ??????
                     for subdict in losses['loss_wirepoint']['losses']:
                         for subkey, subvalue in subdict.items():
-                            # ?? .item() ?????
-                            writer.add_scalar(f'loss_wirepoint/{subkey}',
+                            writer.add_scalar(f'loss/{subkey}',
                                               subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                               epoch)
                 elif isinstance(value, torch.Tensor):
-                    writer.add_scalar(key, value.item(), epoch)
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
         except Exception as e:
             print(f"TensorBoard logging error: {e}")
 
@@ -326,37 +341,20 @@ if __name__ == '__main__':
         model.train()
 
         for imgs, targets in data_loader_train:
-
             losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-            # print(type(losses))
             # print(losses)
             loss = _loss(losses)
-            # print(loss)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             writer_loss(writer, losses, epoch)
 
-            model.eval()
-            with torch.no_grad():
-                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                    pred = model(move_to_device(imgs, device))
-                    # print(f"perd:{pred}")
-                    break
-
-                    # print(f"perd:{pred}")
-
-                # if batch_idx == 0:
-                #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
-                #     H = pred["wires"]
-                #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
-
-# imgs, targets = next(iter(data_loader))
-#
-# model.train()
-# pred = model(imgs, targets)
-# print(f'pred:{pred}')
-
-# result, losses = model(imgs, targets)
-# print(f'result:{result}')
-# print(f'pred:{losses}')
+        model.eval()
+        with torch.no_grad():
+            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                pred = model(move_to_device(imgs, device))
+                if batch_idx == 0:
+                    show_line(imgs[0], pred, epoch, writer)
+                break
+
+
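
For reference, a minimal sketch of the nested loss layout that the rewritten writer_loss walks. Only the 'loss_wirepoint' key and its inner 'losses' list come from the code above; the other keys and values are illustrative placeholders.

    import torch

    # Hypothetical losses dict as returned by the model in training mode
    # (the sub-keys "loss_jmap" and "loss_lmap" are placeholders).
    losses = {
        "loss_classifier": torch.tensor(0.7),  # logged as loss/loss_classifier
        "loss_box_reg": torch.tensor(0.3),     # logged as loss/loss_box_reg
        "loss_wirepoint": {
            "losses": [
                {"loss_jmap": torch.tensor(0.5), "loss_lmap": torch.tensor(0.2)},
            ],
        },
    }

    # writer_loss(writer, losses, epoch) writes every entry above as a scalar
    # under the shared "loss/" tag namespace in TensorBoard.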