Selaa lähdekoodia

修改line_predictor为decoder

lstrlq 4 kuukautta sitten
vanhempi
commit
ecec69f04d

+ 4 - 3
models/line_detect/heads/head_losses.py

@@ -411,8 +411,8 @@ def heatmaps_to_lines(maps, rois):
     line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
     line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
 
-    line_maps=maps[:,1]
-
+    # line_maps=maps[:,1]
+    line_maps = maps.squeeze(1)
 
     for i in range(len(rois)):
         line_roi_map = line_maps[i].unsqueeze(0)
@@ -503,7 +503,8 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # line_logits = line_logits.view(N * K, H * W)
     # print(f'line_logits[valid]:{line_logits[valid].shape}')
     print(f'loss1 line_logits:{line_logits.shape}')
-    line_logits = line_logits[:,1,:,:]
+    # line_logits = line_logits[:,1,:,:]
+    line_logits = line_logits.squeeze(1)
     print(f'loss2 line_logits:{line_logits.shape}')
 
     # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])

+ 10 - 2
models/line_detect/line_detect.py

@@ -180,7 +180,8 @@ class LineDetect(BaseDetectionNet):
 
         if line_predictor is None and detect_line:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(in_channels=256)
+        #     line_predictor = LinePredictor(in_channels=256)
+            line_predictor = ArcUnet(Bottleneck)
 
         if point_head is None and detect_point:
             layers = tuple(num_points for _ in range(8))
@@ -478,7 +479,14 @@ def linedetect_newresnet152fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes,
+                       min_size=size,max_size=size,
+                       num_points=num_points, rpn_anchor_generator=anchor_generator,
+                       box_roi_pool=roi_pooler,
+                       detect_point=False,
+                       detect_line=True,
+                       detect_arc=False,
+                       **kwargs)
 
     return model
 

+ 3 - 1
models/line_detect/loi_heads.py

@@ -1422,7 +1422,9 @@ class RoIHeads(nn.Module):
         # cs_features= features['0']
         # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
-        cs_features=features['0']
+        # cs_features=features['0']
+
+        cs_features = features
         # filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
         #
         # if len(filtered_proposals) > 0:

+ 4 - 4
models/line_detect/train.yaml

@@ -1,10 +1,10 @@
 io:
   logdir: train_results
 #  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
-  datadir: /data/share/zyh/arc/a_dataset
+#  datadir: /data/share/zyh/arc/a_dataset
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
-#  datadir: /data/share/rlq/datasets/250718caisegangban
+  datadir: /data/share/rlq/datasets/250718caisegangban_hunhe
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/arc_datasets_100
@@ -22,8 +22,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-#  augmentation: True
-  augmentation: False
+  augmentation: True
+#  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 1 - 1
models/line_detect/train_demo.py

@@ -21,7 +21,7 @@ if __name__ == '__main__':
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
     model = linedetect_newresnet152fpn(num_points=3)
-    # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
+    model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250728_080143/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()
 

+ 1 - 1
models/line_detect/trainer.py

@@ -263,7 +263,7 @@ class Trainer(BaseTrainer):
         self.init_params(**kwargs)
 
         dataset_train = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
-        dataset_val = LineDataset(dataset_path=self.dataset_path,augmentation=False, data_type=self.data_type, dataset_type='val')
+        dataset_val = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)