
Modify resnet layer1 downsample

lstrlq 5 months ago
parent
commit
cb668aabd0

+ 2 - 2
models/base/high_reso_resnet.py

@@ -186,10 +186,10 @@ class ResNet(nn.Module):
 
 
         self.encoder0 = nn.Sequential(
-            nn.Conv2d(3, self.inplanes, kernel_size=3,padding=1, bias=False),
+            nn.Conv2d(3, self.inplanes, kernel_size=3,padding=1,stride=1, bias=False),
             self._norm_layer(self.inplanes),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
+            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         )
         # self.encoder0 = nn.Sequential(
         #     nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False),
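With this change the stem keeps the 3x3 conv at stride 1 and moves the 2x downsampling into the max-pool (previously both ran at stride 1, so encoder0 preserved full resolution). A minimal standalone sketch of the new stem, assuming inplanes=64 and BatchNorm2d as the norm layer (both are assumptions, not shown in the diff):

    import torch
    import torch.nn as nn

    inplanes = 64  # assumed stem width
    encoder0 = nn.Sequential(
        nn.Conv2d(3, inplanes, kernel_size=3, padding=1, stride=1, bias=False),
        nn.BatchNorm2d(inplanes),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # stride 2 now halves H and W
    )

    x = torch.randn(1, 3, 1024, 1024)
    print(encoder0(x).shape)  # torch.Size([1, 64, 512, 512]); with the old stride=1 pool it stayed 1024x1024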

+ 2 - 2
models/base/transforms.py

@@ -481,7 +481,7 @@ def get_transforms(augmention=True):
 
     if augmention:
 
-        transforms_list.append(ColorJitter())
+        # transforms_list.append(ColorJitter())
         transforms_list.append(RandomGrayscale(0.1))
 
         transforms_list.append(GaussianBlur())
@@ -490,7 +490,7 @@ def get_transforms(augmention=True):
         transforms_list.append(RandomVerticalFlip(0.5))
         # transforms_list.append(RandomPerspective())
         transforms_list.append(RandomRotation(degrees=15))
-        transforms_list.append(RandomResize(512, 2048))
+        # transforms_list.append(RandomResize(512, 2048))
 
         # transforms_list.append(RandomCrop((512,512)))
 

+ 17 - 3
models/line_detect/line_dataset.py

@@ -1,3 +1,5 @@
+import imageio
+import numpy as np
 from torch.utils.data.dataset import T_co
 
 from libs.vision_libs.utils import draw_keypoints
@@ -48,9 +50,21 @@ class LineDataset(BaseDataset):
     def __getitem__(self, index) -> T_co:
         img_path = os.path.join(self.img_path, self.imgs[index])
 
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
+        if self.data_type == 'tiff':
+            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
+            img = imageio.v3.imread(img_path)[:,:,0]
+            print(f'img shape:{img.shape}')
+            w, h = img.shape[:2]
+            img=img.reshape(w,h,1)
+            img_3channel = np.zeros((w, h, 3), dtype=img.dtype)
+            img_3channel[:, :, 2] = img[:, :, 0]
+
+
+            img = torch.from_numpy(img_3channel).permute(2, 1, 0)
+        else:
+            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+            img = PIL.Image.open(img_path).convert('RGB')
+            w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
 

+ 10 - 10
models/line_detect/line_detect.py

@@ -178,7 +178,7 @@ class LineDetect(BaseDetectionNet):
 
         if line_predictor is None and detect_line:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(in_channels=128)
+            line_predictor = LinePredictor(in_channels=256)
 
         if point_head is None and detect_point:
             layers = tuple(num_points for _ in range(8))
@@ -322,7 +322,7 @@ def linedetect_newresnet18fpn(
     if num_points is None:
         num_points = 3
 
-
+    size=1024
     backbone =resnet18fpn()
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -340,7 +340,7 @@ def linedetect_newresnet18fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
     return model
 
@@ -359,8 +359,8 @@ def linedetect_newresnet50fpn(
     if num_points is None:
         num_points = 3
 
-
-    backbone =resnet50fpn()
+    size=768
+    backbone =resnet50fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
     roi_pooler = MultiScaleRoIAlign(
@@ -376,7 +376,7 @@ def linedetect_newresnet50fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
 
 
@@ -398,7 +398,7 @@ def linedetect_newresnet101fpn(
         num_points = 3
 
 
-    backbone =resnet101fpn()
+    backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
     roi_pooler = MultiScaleRoIAlign(
@@ -432,7 +432,7 @@ def linedetect_maxvitfpn(
     if num_points is None:
         num_points = 3
 
-    size=224*2
+    size=224*3
 
     maxvit = MaxVitBackbone(input_size=(size,size))
     # print(maxvit.named_children())
@@ -547,9 +547,9 @@ def linedetect_resnet18_fpn(
         num_classes = 3
     if num_points is None:
         num_points = 3
-
+    size=1024
     backbone = resnet_fpn_backbone(backbone_name='resnet18',weights=None)
-    model = LineDetect(backbone, num_classes, num_points=num_points, **kwargs)
+    model = LineDetect(backbone,min_size=size,max_size=size , num_classes=num_classes, num_points=num_points, **kwargs)
 
     return model
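Passing min_size=max_size pins the resize step to a fixed working resolution per builder (1024 for the resnet18 variants, 768 for resnet50). A minimal sketch of the effect, under the assumption that LineDetect follows torchvision's GeneralizedRCNN convention of forwarding these values to its internal GeneralizedRCNNTransform:

    import torch
    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    t = GeneralizedRCNNTransform(min_size=768, max_size=768,
                                 image_mean=[0.485, 0.456, 0.406],
                                 image_std=[0.229, 0.224, 0.225])
    images, _ = t([torch.randn(3, 1200, 900)])
    # The longer side is scaled down to 768 (here roughly 768x576),
    # then the batch is padded to a multiple of 32.
    print(images.tensors.shape)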
 

+ 1 - 1
models/line_detect/loi_heads.py

@@ -586,7 +586,7 @@ class RoIHeads(nn.Module):
         self.detect_arc =detect_arc
 
         self.channel_compress = nn.Sequential(
-            nn.Conv2d(128, 8, kernel_size=1),
+            nn.Conv2d(256, 8, kernel_size=1),
             nn.BatchNorm2d(8),
             nn.ReLU(inplace=True)
         )
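The 1x1 compression has to accept the 256-channel ROI features now produced by the FPN (the same reason LinePredictor's in_channels moved from 128 to 256 above). A minimal sketch, assuming ROI-pooled features of shape (N, 256, 14, 14):

    import torch
    import torch.nn as nn

    channel_compress = nn.Sequential(
        nn.Conv2d(256, 8, kernel_size=1),
        nn.BatchNorm2d(8),
        nn.ReLU(inplace=True),
    )
    roi_feats = torch.randn(4, 256, 14, 14)   # e.g. 4 ROIs from MultiScaleRoIAlign
    print(channel_compress(roi_feats).shape)  # torch.Size([4, 8, 14, 14])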

+ 1 - 0
models/line_detect/train.yaml

@@ -1,6 +1,7 @@
 io:
   logdir: train_results
   datadir: /data/share/rlq/datasets/250718caisegangban
+#  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2

+ 3 - 3
models/line_detect/train_demo.py

@@ -18,11 +18,11 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=3)
-    # model=linedetect_newresnet50fpn(num_points=3)
+    model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    model=linedetect_maxvitfpn()
+    # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()
-    model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250718_151833/weights/best_val.pth')
+    # model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250718_153419/weights/best_val.pth')
     # model=linedetect_swin_transformer_fpn(type='t')
     model.start_train(cfg='train.yaml')