RenLiqiang 4 місяців тому
батько
коміт
f7d29d626b

+ 43 - 0
models/line_detect/heads/circle_heads.py

@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+
+class CircleHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
+class CirclePredictor(nn.Module):
+    def __init__(self, in_channels, out_channels=1 ):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            out_channels,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        # print(f'before kps_score_lowres x:{x.shape}')
+        x = self.kps_score_lowres(x)
+        # print(f'kps_score_lowres x:{x.shape}')
+        return x
+        # return torch.nn.functional.interpolate(
+        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        # )

+ 97 - 3
models/line_detect/heads/head_losses.py

@@ -143,6 +143,37 @@ def single_point_to_heatmap(keypoints, rois, heatmap_size):
 
     return all_roi_heatmap
 
+def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
+    # type: (Tensor, Tensor, int) -> Tensor
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0].unsqueeze(1)
+    y = keypoints[..., 1].unsqueeze(1)
+
+
+    gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=1.0)
+    # show_heatmap(gs[0],'target')
+    all_roi_heatmap = []
+    for roi, heatmap in zip(rois, gs):
+        # show_heatmap(heatmap, 'target')
+        # print(f'heatmap:{heatmap.shape}')
+        heatmap = heatmap.unsqueeze(0)
+        x1, y1, x2, y2 = map(int, roi)
+        roi_heatmap = torch.zeros_like(heatmap)
+        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
+        # show_heatmap(roi_heatmap[0],'roi_heatmap')
+        all_roi_heatmap.append(roi_heatmap)
+
+    all_roi_heatmap = torch.cat(all_roi_heatmap)
+    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+
+    return all_roi_heatmap
+
 def line_points_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tensor
     print(f'rois:{rois.shape}')
@@ -377,7 +408,7 @@ def non_maximum_suppression(a):
     return a * mask
 
 
-def heatmaps_to_points(maps, rois):
+def heatmaps_to_points(maps, rois,num_points=2):
 
 
     point_preds = torch.zeros((len(rois),  2), dtype=torch.float32, device=maps.device)
@@ -393,7 +424,7 @@ def heatmaps_to_points(maps, rois):
         # roi_map_probs = scores_to_probs(roi_map.copy())
         w = point_roi_map.shape[2]
         flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
-        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
+        point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
         print(f'point index:{point_index}')
         # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
 
@@ -643,6 +674,51 @@ def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
 
     return line_loss
 
+
+def compute_circle_loss(circle_logits, proposals, gt_circles, circle_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = circle_logits.shape
+    len_proposals = len(proposals)
+
+    empty_count = 0
+    non_empty_count = 0
+
+    for prop in proposals:
+        if prop.shape[0] == 0:
+            empty_count += 1
+        else:
+            non_empty_count += 1
+
+    print(f"Empty proposals count: {empty_count}")
+    print(f"Non-empty proposals count: {non_empty_count}")
+
+    print(f'starte to compute_circle_loss')
+    print(f'compute_circle_loss circle_logits.shape:{circle_logits.shape},len_proposals:{len_proposals}')
+    if H != W:
+        raise ValueError(
+            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+
+    gs_heatmaps = []
+    # print(f'point_matched_idxs:{point_matched_idxs}')
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_circles, circle_matched_idxs):
+        print(f'proposals_per_image:{proposals_per_image.shape}')
+        kp = gt_kp_in_image[midx]
+        # print(f'gt_kp_in_image:{gt_kp_in_image}')
+        gs_heatmaps_per_img = points_to_heatmap(kp, proposals_per_image,num_points=4, heatmap_size=discretization_size)
+        gs_heatmaps.append(gs_heatmaps_per_img)
+
+    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
+    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{circle_logits.squeeze(1).shape}')
+
+    circle_logits = circle_logits[:, 0]
+    print(f'circle_logits:{circle_logits.shape}')
+
+    circle_loss = F.cross_entropy(circle_logits, gs_heatmaps)
+
+    return circle_loss
+
 def lines_to_boxes(lines, img_size=511):
     """
     输入:
@@ -767,13 +843,31 @@ def point_inference(x, point_boxes):
     x2 = x.split(boxes_per_image, dim=0)
 
     for xx, bb in zip(x2, point_boxes):
-        point_prob,point_scores = heatmaps_to_points(xx, bb)
+        point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=1)
 
         points_probs.append(point_prob.unsqueeze(1))
         points_scores.append(point_scores)
 
     return points_probs,points_scores
 
+def circle_inference(x, point_boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+
+    points_probs = []
+    points_scores = []
+
+    boxes_per_image = [box.size(0) for box in point_boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, point_boxes):
+        point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=4)
+
+        points_probs.append(point_prob.unsqueeze(1))
+        points_scores.append(point_scores)
+
+    return points_probs,points_scores
+
+
 def line_inference(x, line_boxes):
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
     lines_probs = []

+ 32 - 4
models/line_detect/line_dataset.py

@@ -94,7 +94,7 @@ class LineDataset(BaseDataset):
         target = {}
 
         target["image_id"] = torch.tensor(item)
-        boxes, lines, points, arc_mask,labels = get_boxes_lines(objs, shape)
+        boxes, lines, points, arc_mask,circle_4points,labels = get_boxes_lines(objs, shape)
 
 
         if points is not None:
@@ -111,6 +111,9 @@ class LineDataset(BaseDataset):
         # else:
         #     print(f'not arc_mask dataset')
 
+        if circle_4points is not None:
+            target['circle']=circle_4points
+
         target["boxes"]=boxes
         target["labels"]=labels
         # target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
@@ -144,7 +147,9 @@ class LineDataset(BaseDataset):
         if show_type=='all':
             boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                                   colors="yellow", width=1)
-            keypoint_img=draw_keypoints(boxed_image,target['points'].unsqueeze(1),colors='red',width=3)
+            circle=target['circle']
+            print(f'taget circle:{circle.shape}')
+            keypoint_img=draw_keypoints(boxed_image,circle,colors='red',width=3)
             plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
             plt.show()
 
@@ -178,6 +183,8 @@ def get_boxes_lines(objs,shape):
     line_point_pairs = []
     points=[]
     line_mask=[]
+    circle_4points=[]
+
 
     for obj in objs:
         # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
@@ -228,6 +235,21 @@ def get_boxes_lines(objs,shape):
 
             labels.append(torch.tensor(3))
 
+        elif label == 'circle' :
+            circle_4points.append(obj['points'])
+
+            xmin = max(obj['xmin'] - 6, 0)
+
+            xmax = min(obj['xmax'] + 6, w)
+
+            ymin = max(obj['ymin'] - 6, 0)
+
+            ymax = min(obj['ymax'] + 6, h)
+
+            boxes.append([xmin, ymin, xmax, ymax])
+
+            labels.append(torch.tensor(3))
+
     boxes=torch.tensor(boxes)
     print(f'boxes:{boxes.shape}')
     labels=torch.tensor(labels)
@@ -250,9 +272,15 @@ def get_boxes_lines(objs,shape):
     else:
         line_mask=torch.tensor(line_mask,dtype=torch.float32)
         print(f'arc_mask shape :{line_mask.shape},{line_mask.dtype}')
-    return boxes,line_point_pairs,points,line_mask, labels
+
+    if len(circle_4points)==0:
+        circle_4points=None
+    else:
+        circle_4points=torch.tensor(circle_4points,dtype=torch.float32)
+
+    return boxes,line_point_pairs,points,line_mask,circle_4points, labels
 
 if __name__ == '__main__':
-    path=r"\\192.168.50.222\share\rlq\datasets\Dataset0709_"
+    path=r"\\192.168.50.222/share/zyh/data/rgb_4point/a_dataset"
     dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=False, data_type='jpg')
     dataset.show(1,show_type='all')

+ 25 - 13
models/line_detect/line_detect.py

@@ -23,7 +23,8 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 from .heads.arc_heads import ArcHeads, ArcPredictor
-from .heads.decoder import FPNDecoder
+from .heads.arc_unet import ArcUnet
+from .heads.circle_heads import CircleHeads, CirclePredictor
 from .heads.line_heads import LinePredictor
 from .heads.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
@@ -99,7 +100,11 @@ class LineDetect(BaseDetectionNet):
             point_head=None,
             point_predictor=None,
 
-            # arc parameters
+            circle_head=None,
+            circle_predictor=None,
+            circle_roi_pool=None,
+
+    # arc parameters
             arc_roi_pool=None,
             arc_head=None,
             arc_predictor=None,
@@ -107,6 +112,7 @@ class LineDetect(BaseDetectionNet):
             detect_point=False,
             detect_line=False,
             detect_arc=True,
+            detect_circle=False,
             **kwargs,
 
     ):
@@ -162,6 +168,7 @@ class LineDetect(BaseDetectionNet):
             detect_point=detect_point,
             detect_line=detect_line,
             detect_arc=detect_arc,
+            detect_circle=detect_circle,
         )
 
         if image_mean is None:
@@ -181,7 +188,6 @@ class LineDetect(BaseDetectionNet):
         if line_predictor is None and detect_line:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
             line_predictor = LinePredictor(in_channels=256)
-            # line_predictor = ArcUnet(Bottleneck)
 
         if point_head is None and detect_point:
             layers = tuple(num_points for _ in range(8))
@@ -197,7 +203,15 @@ class LineDetect(BaseDetectionNet):
         if detect_arc and arc_predictor is None:
             layers = tuple(num_points for _ in range(8))
             # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
-            arc_predictor=FPNDecoder(Bottleneck)
+            arc_predictor=ArcUnet(Bottleneck)
+
+        if detect_circle and circle_head is None:
+            layers = tuple(num_points for _ in range(8))
+            circle_head = CircleHeads(8, layers)
+        if detect_circle and circle_predictor is None:
+            layers = tuple(num_points for _ in range(8))
+            # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
+            circle_predictor = CirclePredictor(in_channels=256)
 
 
 
@@ -213,6 +227,10 @@ class LineDetect(BaseDetectionNet):
         self.roi_heads.arc_head = arc_head
         self.roi_heads.arc_predictor = arc_predictor
 
+        self.roi_heads.circle_roi_pool = circle_roi_pool
+        self.roi_heads.circle_head = circle_head
+        self.roi_heads.circle_predictor = circle_predictor
+
     def start_train(self, cfg):
         # cfg = read_yaml(cfg)
         self.trainer = Trainer()
@@ -400,9 +418,10 @@ def linedetect_newresnet50fpn(
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
     model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
-                       detect_point=True,
+                       detect_point=False,
                        detect_line=False,
                        detect_arc=False,
+                       detect_circle=True,
                        **kwargs)
 
 
@@ -479,14 +498,7 @@ def linedetect_newresnet152fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes,
-                       min_size=size,max_size=size,
-                       num_points=num_points, rpn_anchor_generator=anchor_generator,
-                       box_roi_pool=roi_pooler,
-                       detect_point=False,
-                       detect_line=True,
-                       detect_arc=False,
-                       **kwargs)
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
     return model
 

+ 158 - 2
models/line_detect/loi_heads.py

@@ -13,7 +13,8 @@ import libs.vision_libs.models.detection._utils as det_utils
 from collections import OrderedDict
 
 from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
-    lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference
+    lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference, compute_circle_loss, \
+    circle_inference
 
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -531,6 +532,10 @@ class RoIHeads(nn.Module):
             point_head=None,
             point_predictor=None,
 
+            circle_head=None,
+            circle_predictor=None,
+            circle_roi_pool=None,
+
             # arc parameters
             arc_roi_pool=None,
             arc_head=None,
@@ -545,8 +550,9 @@ class RoIHeads(nn.Module):
             keypoint_predictor=None,
 
             detect_point=True,
-            detect_line=True,
+            detect_line=False,
             detect_arc=False,
+            detect_circle=False,
     ):
         super().__init__()
 
@@ -580,6 +586,10 @@ class RoIHeads(nn.Module):
         self.arc_head = arc_head
         self.arc_predictor = arc_predictor
 
+        self.circle_roi_pool = circle_roi_pool
+        self.circle_head = circle_head
+        self.circle_predictor = circle_predictor
+
 
 
         self.mask_roi_pool = mask_roi_pool
@@ -593,6 +603,7 @@ class RoIHeads(nn.Module):
         self.detect_point =detect_point
         self.detect_line =detect_line
         self.detect_arc =detect_arc
+        self.detect_circle=detect_circle
 
         self.channel_compress = nn.Sequential(
             nn.Conv2d(256, 8, kernel_size=1),
@@ -645,6 +656,15 @@ class RoIHeads(nn.Module):
         #     return False
         return True
 
+    def has_circle(self):
+        # if self.line_roi_pool is None:
+        #     return False
+        if self.circle_head is None:
+            return False
+        # if self.line_predictor is None:
+        #     return False
+        return True
+
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         matched_idxs = []
@@ -1284,7 +1304,120 @@ class RoIHeads(nn.Module):
             losses.update(loss_arc)
             print(f'losses:{losses}')
 
+        if self.has_circle and self.detect_circle:
+            print(f'roi_heads forward has_circle()!!!!')
+            # print(f'labels:{labels}')
+            circle_proposals = [p["boxes"] for p in result]
+            print(f'boxes_proposals:{len(circle_proposals)}')
+
+            # if line_proposals is None or len(line_proposals) == 0:
+            #     # 返回空特征或者跳过该部分计算
+            #     return torch.empty(0, C, H, W).to(features['0'].device)
+
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                print(f'num_images:{num_images}')
+                circle_proposals = []
+                circle_pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+                for img_id in range(num_images):
+                    circle_pos = torch.where(labels[img_id] == 1)[0]
+                    circle_proposals.append(proposals[img_id][circle_pos])
+                    circle_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
+            else:
+                if targets is not None:
+
+                    num_images = len(proposals)
+                    circle_proposals = []
+
+                    circle_pos_matched_idxs = []
+                    print(f'val num_images:{num_images}')
+                    if matched_idxs is None:
+                        raise ValueError("if in trainning, matched_idxs should not be None")
+
+                    for img_id in range(num_images):
+                        circle_pos = torch.where(labels[img_id] == 1)[0]
+                        circle_proposals.append(proposals[img_id][circle_pos])
+                        circle_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
+
+                else:
+                    pos_matched_idxs = None
+
+            feature_logits = self.circle_forward1(features, image_shapes, circle_proposals)
+
+            loss_circle = None
+
+            if self.training:
+
+                if targets is None or circle_pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_circles = [t["circle"] for t in targets if "circle" in t]
+
+                print(f'gt_circle:{gt_circles[0].shape}')
+                h, w = targets[0]["img_size"]
+                img_size = h
+
+                gt_circles_tensor = torch.zeros(0, 0)
+                if len(gt_circles) > 0:
+                    gt_circles_tensor = torch.cat(gt_circles)
+                    print(f'gt_circles_tensor:{gt_circles_tensor.shape}')
 
+                if gt_circles_tensor.shape[0] > 0:
+                    print(f'start to compute circle_loss')
+
+                    loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
+
+                if loss_circle is None:
+                    print(f'loss_circle is None111')
+                    loss_circle = torch.tensor(0.0, device=device)
+
+                loss_point = {"loss_circle": loss_circle}
+
+            else:
+                if targets is not None:
+                    h, w = targets[0]["img_size"]
+                    img_size = h
+                    gt_circles = [t["circle"] for t in targets if "circle" in t]
+
+                    gt_circles_tensor = torch.zeros(0, 0)
+                    if len(gt_circles) > 0:
+                        gt_circles_tensor = torch.cat(gt_circles)
+                        print(f'gt_circles_tensor:{gt_circles_tensor.shape}')
+
+                    if gt_circles_tensor.shape[0] > 0:
+                        print(f'start to compute circle_loss')
+
+                        loss_circle = compute_circle_loss(feature_logits, point_proposals, gt_circles,
+                                                        circle_pos_matched_idxs)
+
+                    if loss_circle is None:
+                        print(f'loss_circle is None111')
+                        loss_circle = torch.tensor(0.0, device=device)
+
+                    loss_circle = {"loss_circle": loss_circle}
+
+
+
+                else:
+                    loss_circle = {}
+                    if feature_logits is None or circle_proposals is None:
+                        raise ValueError(
+                            "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                        )
+
+                    if feature_logits is not None:
+
+                        circles_probs, circles_scores = circle_inference(feature_logits, point_proposals)
+                        for keypoint_prob, kps, r in zip(circles_probs, circles_scores, result):
+                            r["circles"] = keypoint_prob
+                            r["circles_scores"] = kps
+
+            print(f'loss_circle:{loss_circle}')
+            losses.update(loss_circle)
+            print(f'losses:{losses}')
 
 
         if self.has_mask():
@@ -1465,6 +1598,29 @@ class RoIHeads(nn.Module):
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
 
+    def circle_forward1(self, features, image_shapes, proposals):
+        print(f'point_proposals:{len(proposals)}')
+        # cs_features= features['0']
+        # print(f'features-0:{features['0'].shape}')
+        # cs_features = self.channel_compress(features['0'])
+        cs_features=features['0']
+        # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
+        #
+        # if len(filtered_proposals) > 0:
+        #     filtered_proposals_tensor = torch.cat(filtered_proposals)
+        #     print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+        #     proposals=filtered_proposals
+        # point_proposals_tensor = torch.cat(proposals)
+        # print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
+
+        feature_logits = self.circle_predictor(cs_features)
+        print(f'feature_logits from line_head:{feature_logits.shape}')
+
+        roi_features = features_align(feature_logits, proposals, image_shapes)
+        if roi_features is not None:
+            print(f'roi_features from align:{roi_features.shape}')
+        return roi_features
+
 
     def arc_forward1(self, features, image_shapes, proposals):
         print(f'arc_proposals:{len(proposals)}')

+ 5 - 5
models/line_detect/train.yaml

@@ -1,13 +1,13 @@
 io:
   logdir: train_results
 #  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
-#  datadir: /data/share/zyh/arc/a_dataset
+#  datadir: /data/share/zyh/arc/a_datasetb
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
-  datadir: /data/share/rlq/datasets/250718caisegangban_hunhe
+#  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-#  datadir: \\192.168.50.222/share/rlq/datasets/arc_datasets_100
+  datadir: \\192.168.50.222/share/zyh/data/rgb_4point/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb
@@ -22,8 +22,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-  augmentation: True
-#  augmentation: False
+#  augmentation: True
+  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 3 - 3
models/line_detect/train_demo.py

@@ -18,10 +18,10 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=3)
-    # model=linedetect_newresnet50fpn(num_points=3)
+    model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
-    model = linedetect_newresnet152fpn(num_points=3)
-    # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250728_080143/weights/best_val.pth')
+    # model = linedetect_newresnet152fpn(num_points=3)
+    # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()