Kaynağa Gözat

添加深度值归一化功能

lstrlq 7 ay önce
ebeveyn
işleme
ebb3401c24

+ 3 - 0
models/base/backbone_factory.py

@@ -46,6 +46,9 @@ def get_mobilenet_v3_large_fpn():
     return backbone
 def get_convnext_fpn():
     convnext = models.convnext_base(pretrained=True)
+    # convnext = models.convnext_small(pretrained=True)
+    # convnext = models.convnext_large(pretrained=True)
+
     in_channels_list = [128, 256, 512, 1024]
     backbone_with_fpn = BackboneWithFPN(
         convnext.features,

+ 3 - 3
models/line_detect/111.py

@@ -150,7 +150,7 @@ class Trainer(BaseTrainer):
     def train(self, model, **kwargs):
         dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=8, drop_last=True)
         train_collate_fn = utils.collate_fn_wirepoint
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
@@ -234,8 +234,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
     # model = LineNet('line_net.yaml')
     # model = linenet_resnet50_fpn().to(device)
-    # model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
-    model=get_line_net_convnext_fpn(num_classes=2).to(device)
+    model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
+    # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_resnet18_fpn()
     trainer = Trainer()
     trainer.train_cfg(model,cfg='./train.yaml')

+ 13 - 0
models/line_detect/dataset_LD.py

@@ -84,6 +84,10 @@ class WirePointDataset(BaseDataset):
         w, h = img.shape[:2]
 
         img=torch.from_numpy(img_3channel).permute(2, 0, 1)
+
+        img=self.zscore_normalize_depth(img)
+
+
         # img=img.transpose(2,0,1)
         # print(f'dataset img shape2:{img.shape}')
 
@@ -101,6 +105,15 @@ class WirePointDataset(BaseDataset):
     def __len__(self):
         return len(self.imgs)
 
+    def zscore_normalize_depth(self,img):  # Z-score normalize img[2] (the depth channel, per the method name) and return a modified copy of img
+        depth = img[2]  # index 2 along the first dim; NOTE(review): assumes a channel-first floating-point tensor — confirm at the call site
+        mean = depth.mean()  # scalar mean over every element of that channel
+        std = depth.std()  # scalar standard deviation over the same elements
+        depth_normalized = (depth - mean) / (std + 1e-8)  # 1e-8 keeps the division finite when std is 0
+        img_normalized = img.clone()  # work on a copy so the caller's tensor is not modified
+        img_normalized[2] = depth_normalized  # only index 2 is replaced; the other channels are left as they were
+        return img_normalized
+
     def read_target(self, item, lbl_path, shape, extra=None):
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:

+ 1 - 1
models/line_detect/line_net.py

@@ -339,7 +339,7 @@ _COMMON_META = {
 }
 
 
-def create_efficientnetv2_backbone(name='efficientnet_v2_l', pretrained=True):
+def create_efficientnetv2_backbone(name='efficientnet_v2_m', pretrained=True):
     # 加载EfficientNetV2模型
     if name == 'efficientnet_v2_s':
         weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None

+ 1 - 0
models/line_detect/roi_heads.py

@@ -1074,6 +1074,7 @@ class RoIHeads(nn.Module):
             if self.training:
                 rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
                 loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+
             else:
 
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: \\192.168.50.222\share\lm\04\0510\0510_split
+  datadir: /data/share/zyh/512/init/last
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
   resume_from: