浏览代码

WireDataset

xue50 5 月之前
父节点
当前提交
2383d67d99
共有 3 个文件被更改,包括 123 次插入51 次删除
  1. 17 17
      models/wirenet/head.py
  2. 19 0
      models/wirenet/train.py
  3. 87 34
      models/wirenet/wirepoint_rcnn.py

+ 17 - 17
models/wirenet/head.py

@@ -168,9 +168,9 @@ def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
     b = s > 0.5
     lines = []
     score = []
-    print(f"n_batch:{n_batch}")
+    # print(f"n_batch:{n_batch}")
     for i in range(n_batch):
-        print(f"idx:{idx}")
+        # print(f"idx:{idx}")
         p0 = p[idx[i]: idx[i + 1]]
         s0 = s[idx[i]: idx[i + 1]]
         mask = b[idx[i]: idx[i + 1]]
@@ -526,14 +526,14 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
 
 
 def keypointrcnn_inference(x, boxes):
-    print(f'x:{x.shape}')
+    # print(f'x:{x.shape}')
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
     kp_probs = []
     kp_scores = []
 
     boxes_per_image = [box.size(0) for box in boxes]
     x2 = x.split(boxes_per_image, dim=0)
-    print(f'x2:{x2}')
+    # print(f'x2:{x2}')
 
     for xx, bb in zip(x2, boxes):
         kp_prob, scores = heatmaps_to_keypoints(xx, bb)
@@ -775,13 +775,13 @@ class RoIHeads(nn.Module):
 
     def has_wirepoint(self):
         if self.wirepoint_roi_pool is None:
-            print(f'wirepoint_roi_pool is None')
+            # print(f'wirepoint_roi_pool is None')
             return False
         if self.wirepoint_head is None:
-            print(f'wirepoint_head is None')
+            # print(f'wirepoint_head is None')
             return False
         if self.wirepoint_predictor is None:
-            print(f'wirepoint_roi_predictor is None')
+            # print(f'wirepoint_roi_predictor is None')
             return False
         return True
 
@@ -1066,14 +1066,14 @@ class RoIHeads(nn.Module):
             keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
             # tmp = keypoint_features[0][0]
             # plt.imshow(tmp.detach().numpy())
-            print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
+            # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
             keypoint_features = self.keypoint_head(keypoint_features)
 
-            print(f'keypoint_features:{keypoint_features.shape}')
+            # print(f'keypoint_features:{keypoint_features.shape}')
             tmp = keypoint_features[0][0]
             plt.imshow(tmp.detach().numpy())
             keypoint_logits = self.keypoint_predictor(keypoint_features)
-            print(f'keypoint_logits:{keypoint_logits.shape}')
+            # print(f'keypoint_logits:{keypoint_logits.shape}')
             """
             接wirenet
             """
@@ -1117,23 +1117,23 @@ class RoIHeads(nn.Module):
             else:
                 pos_matched_idxs = None
 
-            print(f'proposals:{len(proposals)}')
+            # print(f'proposals:{len(proposals)}')
             wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
 
             # tmp = keypoint_features[0][0]
             # plt.imshow(tmp.detach().numpy())
-            print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
+            # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
             outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
 
             outputs = merge_features(outputs, wirepoint_proposals)
             wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
 
-            print(f'outpust:{outputs.shape}')
+            # print(f'outpust:{outputs.shape}')
 
             wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
             x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
 
-            print(f'keypoint_features:{wirepoint_features.shape}')
+            # print(f'keypoint_features:{wirepoint_features.shape}')
             if self.training:
 
                 if targets is None or pos_matched_idxs is None:
@@ -1170,7 +1170,7 @@ def merge_features(features, proposals):
     # 使用 torch.split 按照每个图像的提议数量分割 features
     proposals_count = sum([p.size(0) for p in proposals])
     features_size = features.size(0)
-    print(f'proposals sum:{proposals_count},features batch:{features.size(0)}')
+    # (f'proposals sum:{proposals_count},features batch:{features.size(0)}')
     if proposals_count != features_size:
         raise ValueError("The length of proposals must match the batch size of features.")
 
@@ -1179,7 +1179,7 @@ def merge_features(features, proposals):
     for proposal in proposals:
         # 提取当前图像的特征
         current_features = features[start_idx:start_idx + proposal.size(0)]
-        print(f'current_features:{current_features.shape}')
+        # print(f'current_features:{current_features.shape}')
         split_features.append(current_features)
         start_idx += 1
 
@@ -1189,5 +1189,5 @@ def merge_features(features, proposals):
         features_imgs.append(features_per_img)
 
     merged_features = torch.cat(features_imgs, dim=0)
-    print(f' merged_features:{merged_features.shape}')
+    # print(f' merged_features:{merged_features.shape}')
     return merged_features

+ 19 - 0
models/wirenet/train.py

@@ -0,0 +1,19 @@
+def train_epoch(model):
+    pass
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        print(f"{name}:{loss}")
+        total_loss += loss
+
+    return total_loss

+ 87 - 34
models/wirenet/wirepoint_rcnn.py

@@ -26,7 +26,6 @@ from models.wirenet.head import RoIHeads
 from models.wirenet.wirepoint_dataset import WirePointDataset
 from tools import utils
 
-
 FEATURE_DIM = 8
 
 
@@ -119,7 +118,7 @@ class WirepointRCNN(FasterRCNN):
 
         if wirepoint_roi_pool is None:
             wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
-                                                    sampling_ratio=2,)
+                                                    sampling_ratio=2, )
 
         if wirepoint_head is None:
             keypoint_layers = tuple(512 for _ in range(8))
@@ -283,7 +282,7 @@ class WirepointPredictor(nn.Module):
             )
         self.loss = nn.BCEWithLogitsLoss(reduction="none")
 
-    def forward(self, inputs,features, targets=None):
+    def forward(self, inputs, features, targets=None):
 
         # outputs, features = input
         # for out in outputs:
@@ -315,25 +314,24 @@ class WirepointPredictor(nn.Module):
         else:
             self.training = False
             t = {
-                    "junc_coords": torch.zeros(1, 2),
-                    "jtyp": torch.zeros(1, dtype=torch.uint8),
-                    "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                    "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                    "junc_map": torch.zeros([1, 1, 128, 128]),
-                    "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-                }
-            wires_targets=[t for b in range(inputs.size(0))]
-
-            wires_meta={
+                "junc_coords": torch.zeros(1, 2),
+                "jtyp": torch.zeros(1, dtype=torch.uint8),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
                 "junc_map": torch.zeros([1, 1, 128, 128]),
                 "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
             }
+            wires_targets = [t for b in range(inputs.size(0))]
 
+            wires_meta = {
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
 
         T = wires_meta.copy()
         n_jtyp = T["junc_map"].shape[1]
         offset = self.head_off
-        result={}
+        result = {}
         for stack, output in enumerate([inputs]):
             output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
             print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
@@ -396,8 +394,8 @@ class WirepointPredictor(nn.Module):
                         + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                         + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                 )
-                .reshape(n_channel, -1, self.n_pts0)
-                .permute(1, 0, 2)
+                    .reshape(n_channel, -1, self.n_pts0)
+                    .permute(1, 0, 2)
             )
             xp = self.pooling(xp)
             print(f'xp.shape:{xp.shape}')
@@ -419,13 +417,11 @@ class WirepointPredictor(nn.Module):
         # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
         return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
 
-
         # if mode != "training":
         # self.inference(x, idx, jcs, n_batch, ps)
 
         # return result
 
-
     ####deprecated
     # def inference(self,input, idx, jcs, n_batch, ps):
     #     if not self.training:
@@ -565,7 +561,6 @@ class WirepointPredictor(nn.Module):
             return line, label.float(), feat, jcs
 
 
-
 def wirepointrcnn_resnet50_fpn(
         *,
         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
@@ -596,6 +591,21 @@ def wirepointrcnn_resnet50_fpn(
     return model
 
 
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
 if __name__ == '__main__':
     cfg = 'wirenet.yaml'
     cfg = read_yaml(cfg)
@@ -603,30 +613,73 @@ if __name__ == '__main__':
     print(cfg['model']['n_dyn_negl'])
     # net = WirepointPredictor()
 
+
+    if torch.cuda.is_available():
+        device_name = "cuda"
+        torch.backends.cudnn.deterministic = True
+        torch.cuda.manual_seed(0)
+        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    else:
+        print("CUDA is not available")
+
+    device = torch.device(device_name)
+
+
     dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
     train_sampler = torch.utils.data.RandomSampler(dataset)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
     train_collate_fn = utils.collate_fn_wirepoint
     data_loader = torch.utils.data.DataLoader(
-        dataset, batch_sampler=train_batch_sampler, num_workers=10, collate_fn=train_collate_fn
+        dataset, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
     )
-    model = wirepointrcnn_resnet50_fpn()
+    model = wirepointrcnn_resnet50_fpn().to(device)
 
-    for i in cfg['optim']['max_epoch']:
-        model.train()
-        imgs, targets = next(iter(data_loader))
-        pred = model(imgs, targets)
+    optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
 
-    # imgs, targets = next(iter(data_loader))
-    #
-    # model.train()
-    # pred = model(imgs, targets)
-    # print(f'pred:{pred}')
 
-    # result, losses = model(imgs, targets)
-    # print(f'result:{result}')
-    # print(f'pred:{losses}')
+    def move_to_device(data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # 对于非张量类型的数据不做任何改变
+
+
+
+    for i in range(cfg['optim']['max_epoch']):
+        model.train()
+        # imgs, targets = next(iter(data_loader))
+        # loss = model(imgs, targets)
+        # print(loss)
+        # losses = _loss(loss)
+        # optimizer.zero_grad()
+        # loss.backward()
+        # optimizer.step()
+
+        for imgs, targets in data_loader:
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            print(losses)
+            loss = _loss(losses)
+            print(loss)
+            # 优化器优化模型
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+
+# imgs, targets = next(iter(data_loader))
+#
+# model.train()
+# pred = model(imgs, targets)
+# print(f'pred:{pred}')
+
+# result, losses = model(imgs, targets)
+# print(f'result:{result}')
+# print(f'pred:{losses}')
 '''
 ########### predict#############