浏览代码

WireDataset

xue50 5 月之前
父节点
当前提交
b70a8b4946
共有 44 个文件被更改,包括 126 次插入 和 109 次删除
  1. 14 14
      models/dataset_tool.py
  2. 97 26
      models/wirenet/head.py
  3. 二进制
      models/wirenet/logs/events.out.tfevents.1733792580.Li.52792.0
  4. 二进制
      models/wirenet/logs/events.out.tfevents.1733792677.Li.32744.0
  5. 二进制
      models/wirenet/logs/events.out.tfevents.1733792849.Li.52844.0
  6. 二进制
      models/wirenet/logs/events.out.tfevents.1733792868.Li.5368.0
  7. 二进制
      models/wirenet/logs/events.out.tfevents.1733793317.Li.104660.0
  8. 二进制
      models/wirenet/logs/events.out.tfevents.1733793375.Li.78948.0
  9. 二进制
      models/wirenet/logs/events.out.tfevents.1733793393.Li.5400.0
  10. 二进制
      models/wirenet/logs/events.out.tfevents.1733793423.Li.124644.0
  11. 二进制
      models/wirenet/logs/events.out.tfevents.1733794274.Li.133412.0
  12. 二进制
      models/wirenet/logs/events.out.tfevents.1733794915.Li.8636.0
  13. 二进制
      models/wirenet/logs/events.out.tfevents.1733795124.Li.130792.0
  14. 二进制
      models/wirenet/logs/events.out.tfevents.1733795192.Li.90196.0
  15. 二进制
      models/wirenet/logs/events.out.tfevents.1733795957.Li.120256.0
  16. 二进制
      models/wirenet/logs/events.out.tfevents.1733795987.Li.35360.0
  17. 二进制
      models/wirenet/logs/events.out.tfevents.1733796468.Li.101736.0
  18. 二进制
      models/wirenet/logs/events.out.tfevents.1733796499.Li.25124.0
  19. 二进制
      models/wirenet/logs/events.out.tfevents.1733796707.Li.130044.0
  20. 二进制
      models/wirenet/logs/events.out.tfevents.1733796729.Li.73460.0
  21. 二进制
      models/wirenet/logs/events.out.tfevents.1733796800.Li.97688.0
  22. 二进制
      models/wirenet/logs/events.out.tfevents.1733796835.Li.85584.0
  23. 二进制
      models/wirenet/logs/events.out.tfevents.1733796889.Li.116656.0
  24. 二进制
      models/wirenet/logs/events.out.tfevents.1733799015.Li.13032.0
  25. 二进制
      models/wirenet/logs/events.out.tfevents.1733799093.Li.64224.0
  26. 二进制
      models/wirenet/logs/events.out.tfevents.1733799170.Li.15848.0
  27. 二进制
      models/wirenet/logs/events.out.tfevents.1733799203.Li.126828.0
  28. 二进制
      models/wirenet/logs/events.out.tfevents.1733799239.Li.30468.0
  29. 二进制
      models/wirenet/logs/events.out.tfevents.1733799262.Li.43732.0
  30. 二进制
      models/wirenet/logs/events.out.tfevents.1733799341.Li.12300.0
  31. 二进制
      models/wirenet/logs/events.out.tfevents.1733800143.Li.59332.0
  32. 二进制
      models/wirenet/logs/events.out.tfevents.1733800199.Li.73984.0
  33. 二进制
      models/wirenet/logs/events.out.tfevents.1733800235.Li.83000.0
  34. 二进制
      models/wirenet/logs/events.out.tfevents.1733800269.Li.68072.0
  35. 二进制
      models/wirenet/logs/events.out.tfevents.1733800323.Li.14540.0
  36. 二进制
      models/wirenet/logs/events.out.tfevents.1733800946.Li.41800.0
  37. 二进制
      models/wirenet/logs/events.out.tfevents.1733801481.Li.61988.0
  38. 二进制
      models/wirenet/logs/events.out.tfevents.1733802629.Li.91500.0
  39. 二进制
      models/wirenet/logs/events.out.tfevents.1733802649.Li.63868.0
  40. 二进制
      models/wirenet/logs/events.out.tfevents.1733802777.Li.60488.0
  41. 二进制
      models/wirenet/logs/events.out.tfevents.1733803010.Li.4304.0
  42. 1 0
      models/wirenet/train.py
  43. 1 4
      models/wirenet/wirepoint_dataset.py
  44. 13 65
      models/wirenet/wirepoint_rcnn.py

+ 14 - 14
models/dataset_tool.py

@@ -283,24 +283,24 @@ def read_masks_from_pixels_wire(lbl_path, shape):
         lines = json.load(reader)
         mask_points = []
         for line in lines["segmentations"]:
-            mask = torch.zeros((h, w), dtype=torch.uint8)
-            parts = line["data"]
+            # mask = torch.zeros((h, w), dtype=torch.uint8)
+            # parts = line["data"]
             # print(f'parts:{parts}')
             cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
             labels.append(cls)
-            x_array = parts[0::2]
-            y_array = parts[1::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
-
-            for p in mask_points:
-                mask[p] = 1
-            masks.append(mask)
+            # x_array = parts[0::2]
+            # y_array = parts[1::2]
+            # 
+            # for x, y in zip(x_array, y_array):
+            #     x = float(x)
+            #     y = float(y)
+            #     mask_points.append((int(y * h), int(x * w)))
+
+            # for p in mask_points:
+            #     mask[p] = 1
+            # masks.append(mask)
     reader.close()
-    return labels, masks
+    return labels
 
 
 def adjacency_matrix(n, link):  # 邻接矩阵

+ 97 - 26
models/wirenet/head.py

@@ -163,6 +163,7 @@ def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
 def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
     result = {}
     result["wires"] = {}
+    print(f"ps1:{ps}")
     p = torch.cat(ps)
     s = torch.sigmoid(input)
     b = s > 0.5
@@ -1164,30 +1165,100 @@ class RoIHeads(nn.Module):
         return result, losses
 
 
+# def merge_features(features, proposals):
+#     # 假设 roi_pool_features 是你的输入张量,形状为 [600, 256, 128, 128]
+#
+#     # 使用 torch.split 按照每个图像的提议数量分割 features
+#     proposals_count = sum([p.size(0) for p in proposals])
+#     features_size = features.size(0)
+#     # (f'proposals sum:{proposals_count},features batch:{features.size(0)}')
+#     if proposals_count != features_size:
+#         raise ValueError("The length of proposals must match the batch size of features.")
+#
+#     split_features = []
+#     start_idx = 0
+#     print(f"proposals:{proposals}")
+#     for proposal in proposals:
+#         # 提取当前图像的特征
+#         current_features = features[start_idx:start_idx + proposal.size(0)]
+#         # print(f'current_features:{current_features.shape}')
+#         split_features.append(current_features)
+#         start_idx += 1
+#
+#     features_imgs = []
+#     for features_per_img in split_features:
+#         features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
+#         features_imgs.append(features_per_img)
+#
+#     merged_features = torch.cat(features_imgs, dim=0)
+#     # print(f' merged_features:{merged_features.shape}')
+#     return merged_features
+
 def merge_features(features, proposals):
-    # 假设 roi_pool_features 是你的输入张量,形状为 [600, 256, 128, 128]
-
-    # 使用 torch.split 按照每个图像的提议数量分割 features
-    proposals_count = sum([p.size(0) for p in proposals])
-    features_size = features.size(0)
-    # (f'proposals sum:{proposals_count},features batch:{features.size(0)}')
-    if proposals_count != features_size:
-        raise ValueError("The length of proposals must match the batch size of features.")
-
-    split_features = []
-    start_idx = 0
-    for proposal in proposals:
-        # 提取当前图像的特征
-        current_features = features[start_idx:start_idx + proposal.size(0)]
-        # print(f'current_features:{current_features.shape}')
-        split_features.append(current_features)
-        start_idx += 1
-
-    features_imgs = []
-    for features_per_img in split_features:
-        features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
-        features_imgs.append(features_per_img)
-
-    merged_features = torch.cat(features_imgs, dim=0)
-    # print(f' merged_features:{merged_features.shape}')
-    return merged_features
+    def diagnose_input(features, proposals):
+        """诊断输入数据"""
+        print("Input Diagnostics:")
+        print(f"Features type: {type(features)}, shape: {features.shape}")
+        print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
+        for i, p in enumerate(proposals):
+            print(f"Proposal {i} shape: {p.shape}")
+
+    def validate_inputs(features, proposals):
+        """验证输入的有效性"""
+        if features is None or proposals is None:
+            raise ValueError("Features or proposals cannot be None")
+
+        proposals_count = sum([p.size(0) for p in proposals])
+        features_size = features.size(0)
+
+        if proposals_count != features_size:
+            raise ValueError(
+                f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
+            )
+
+    def safe_max_reduction(features_per_img):
+        """安全的最大值压缩"""
+        if features_per_img.numel() == 0:
+            return torch.zeros_like(features_per_img).unsqueeze(0)
+
+        try:
+            # 沿着第0维求最大值,保持维度
+            max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
+            return max_features
+        except Exception as e:
+            print(f"Max reduction error: {e}")
+            return features_per_img.unsqueeze(0)
+
+    try:
+        # 诊断输入(可选)
+        diagnose_input(features, proposals)
+
+        # 验证输入
+        validate_inputs(features, proposals)
+
+        # 分割特征
+        split_features = []
+        start_idx = 0
+
+        for proposal in proposals:
+            # 提取当前图像的特征
+            current_features = features[start_idx:start_idx + proposal.size(0)]
+            split_features.append(current_features)
+            start_idx += proposal.size(0)
+
+        # 每张图像特征压缩
+        features_imgs = []
+        for features_per_img in split_features:
+            compressed_features = safe_max_reduction(features_per_img)
+            features_imgs.append(compressed_features)
+
+        # 合并特征
+        merged_features = torch.cat(features_imgs, dim=0)
+
+        return merged_features
+
+    except Exception as e:
+        print(f"Error in merge_features: {e}")
+        # 返回原始特征或None
+        return features
+

二进制
models/wirenet/logs/events.out.tfevents.1733792580.Li.52792.0


二进制
models/wirenet/logs/events.out.tfevents.1733792677.Li.32744.0


二进制
models/wirenet/logs/events.out.tfevents.1733792849.Li.52844.0


二进制
models/wirenet/logs/events.out.tfevents.1733792868.Li.5368.0


二进制
models/wirenet/logs/events.out.tfevents.1733793317.Li.104660.0


二进制
models/wirenet/logs/events.out.tfevents.1733793375.Li.78948.0


二进制
models/wirenet/logs/events.out.tfevents.1733793393.Li.5400.0


二进制
models/wirenet/logs/events.out.tfevents.1733793423.Li.124644.0


二进制
models/wirenet/logs/events.out.tfevents.1733794274.Li.133412.0


二进制
models/wirenet/logs/events.out.tfevents.1733794915.Li.8636.0


二进制
models/wirenet/logs/events.out.tfevents.1733795124.Li.130792.0


二进制
models/wirenet/logs/events.out.tfevents.1733795192.Li.90196.0


二进制
models/wirenet/logs/events.out.tfevents.1733795957.Li.120256.0


二进制
models/wirenet/logs/events.out.tfevents.1733795987.Li.35360.0


二进制
models/wirenet/logs/events.out.tfevents.1733796468.Li.101736.0


二进制
models/wirenet/logs/events.out.tfevents.1733796499.Li.25124.0


二进制
models/wirenet/logs/events.out.tfevents.1733796707.Li.130044.0


二进制
models/wirenet/logs/events.out.tfevents.1733796729.Li.73460.0


二进制
models/wirenet/logs/events.out.tfevents.1733796800.Li.97688.0


二进制
models/wirenet/logs/events.out.tfevents.1733796835.Li.85584.0


二进制
models/wirenet/logs/events.out.tfevents.1733796889.Li.116656.0


二进制
models/wirenet/logs/events.out.tfevents.1733799015.Li.13032.0


二进制
models/wirenet/logs/events.out.tfevents.1733799093.Li.64224.0


二进制
models/wirenet/logs/events.out.tfevents.1733799170.Li.15848.0


二进制
models/wirenet/logs/events.out.tfevents.1733799203.Li.126828.0


二进制
models/wirenet/logs/events.out.tfevents.1733799239.Li.30468.0


二进制
models/wirenet/logs/events.out.tfevents.1733799262.Li.43732.0


二进制
models/wirenet/logs/events.out.tfevents.1733799341.Li.12300.0


二进制
models/wirenet/logs/events.out.tfevents.1733800143.Li.59332.0


二进制
models/wirenet/logs/events.out.tfevents.1733800199.Li.73984.0


二进制
models/wirenet/logs/events.out.tfevents.1733800235.Li.83000.0


二进制
models/wirenet/logs/events.out.tfevents.1733800269.Li.68072.0


二进制
models/wirenet/logs/events.out.tfevents.1733800323.Li.14540.0


二进制
models/wirenet/logs/events.out.tfevents.1733800946.Li.41800.0


二进制
models/wirenet/logs/events.out.tfevents.1733801481.Li.61988.0


二进制
models/wirenet/logs/events.out.tfevents.1733802629.Li.91500.0


二进制
models/wirenet/logs/events.out.tfevents.1733802649.Li.63868.0


二进制
models/wirenet/logs/events.out.tfevents.1733802777.Li.60488.0


二进制
models/wirenet/logs/events.out.tfevents.1733803010.Li.4304.0


+ 1 - 0
models/wirenet/train.py

@@ -17,3 +17,4 @@ def _loss(losses):
         total_loss += loss
 
     return total_loss
+

+ 1 - 4
models/wirenet/wirepoint_dataset.py

@@ -101,18 +101,15 @@ class WirePointDataset(BaseDataset):
             "line_map": torch.tensor(wire['line_map']["content"]),
         }
 
-        h, w = shape
         labels = []
-        masks = []
         if self.target_type == 'polygon':
             labels, masks = read_masks_from_txt_wire(lbl_path, shape)
         elif self.target_type == 'pixel':
-            labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
 
         # print(torch.stack(masks).shape)    # [线段数, 512, 512]
         target = {}
         target["labels"] = torch.stack(labels)
-        target["masks"] = torch.stack(masks)
         target["image_id"] = torch.tensor(item)
         # return wire_labels, target
         target["wires"] = wire_labels

+ 13 - 65
models/wirenet/wirepoint_rcnn.py

@@ -347,20 +347,11 @@ class WirepointPredictor(nn.Module):
                     "lmap": lmap.sigmoid(),
                     "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                 }
-                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
-                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
-                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
 
         h = result["preds"]
-        print(f'features shape:{features.shape}')
+        # print(f'features shape:{features.shape}')
         x = self.fc1(features)
-
-        # print(f'x:{x.shape}')
-
         n_batch, n_channel, row, col = x.shape
-
-        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
-
         xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
 
         for i, meta in enumerate(wires_targets):
@@ -408,12 +399,9 @@ class WirepointPredictor(nn.Module):
         x, y = torch.cat(xs), torch.cat(ys)
         f = torch.cat(fs)
         x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-
-        # print("Weight dtype:", self.fc2.weight.dtype)
+        print(f"pstest{ps}")
         x = torch.cat([x, f], 1)
-        # print("Input dtype:", x.dtype)
         x = x.to(dtype=torch.float32)
-        # print("Input dtype1:", x.dtype)
         x = self.fc2(x).flatten()
 
         # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
@@ -424,47 +412,6 @@ class WirepointPredictor(nn.Module):
 
         # return result
 
-    ####deprecated
-    # def inference(self,input, idx, jcs, n_batch, ps):
-    #     if not self.training:
-    #         p = torch.cat(ps)
-    #         s = torch.sigmoid(input)
-    #         b = s > 0.5
-    #         lines = []
-    #         score = []
-    #         print(f"n_batch:{n_batch}")
-    #         for i in range(n_batch):
-    #             print(f"idx:{idx}")
-    #             p0 = p[idx[i]: idx[i + 1]]
-    #             s0 = s[idx[i]: idx[i + 1]]
-    #             mask = b[idx[i]: idx[i + 1]]
-    #             p0 = p0[mask]
-    #             s0 = s0[mask]
-    #             if len(p0) == 0:
-    #                 lines.append(torch.zeros([1, self.n_out_line, 2, 2], device=p.device))
-    #                 score.append(torch.zeros([1, self.n_out_line], device=p.device))
-    #             else:
-    #                 arg = torch.argsort(s0, descending=True)
-    #                 p0, s0 = p0[arg], s0[arg]
-    #                 lines.append(p0[None, torch.arange(self.n_out_line) % len(p0)])
-    #                 score.append(s0[None, torch.arange(self.n_out_line) % len(s0)])
-    #             for j in range(len(jcs[i])):
-    #                 if len(jcs[i][j]) == 0:
-    #                     jcs[i][j] = torch.zeros([self.n_out_junc, 2], device=p.device)
-    #                 jcs[i][j] = jcs[i][j][
-    #                     None, torch.arange(self.n_out_junc) % len(jcs[i][j])
-    #                 ]
-    #         result["preds"]["lines"] = torch.cat(lines)
-    #         result["preds"]["score"] = torch.cat(score)
-    #         result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-    #
-    #         if len(jcs[i]) > 1:
-    #             result["preds"]["junts"] = torch.cat(
-    #                 [jcs[i][1] for i in range(n_batch)]
-    #             )
-    #     if self.training:
-    #         del result["preds"]
-
     def sample_lines(self, meta, jmap, joff):
         with torch.no_grad():
             junc = meta["junc_coords"]  # [N, 2]
@@ -631,21 +578,21 @@ if __name__ == '__main__':
     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
     train_collate_fn = utils.collate_fn_wirepoint
     data_loader_train = torch.utils.data.DataLoader(
-        dataset_train, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
     )
 
     dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
     val_sampler = torch.utils.data.RandomSampler(dataset_val)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    val_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
     val_collate_fn = utils.collate_fn_wirepoint
     data_loader_val = torch.utils.data.DataLoader(
-        dataset_val, batch_sampler=val_batch_sampler, num_workers=4, collate_fn=val_collate_fn
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
     )
-    
+
     model = wirepointrcnn_resnet50_fpn().to(device)
 
-    optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
 
 
@@ -682,11 +629,12 @@ if __name__ == '__main__':
             optimizer.step()
             writer_loss(writer, losses)
 
-        model.eval()
-        with torch.no_grad():
-            for imgs, targets in dataset_val:
-                pred = model(move_to_device(imgs, device), move_to_device(targets, device))
-                
+            model.eval()
+            with torch.no_grad():
+                for imgs, targets in data_loader_val:
+                    print(111)
+                    pred = model(move_to_device(imgs, device), move_to_device(targets, device))
+                    print(f"pred:{pred}")
 
 # imgs, targets = next(iter(data_loader))
 #