RenLiqiang 5 months ago
parent
commit
e212a1fd73

+ 1 - 1
models/base/base_detection_net.py

@@ -92,7 +92,7 @@ class BaseDetectionNet(BaseModel):
             original_image_sizes.append((val[0], val[1]))
 
         images, targets = self.transform(images, targets)
-        # print(f'images shape from transform:{images.tensors.shape }')
+        print(f'images shape from transform:{images.tensors.shape }')
 
         # Check for degenerate boxes
         # TODO: Move this to a function
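
For reference, a minimal sketch of what this re-enabled debug print reports, assuming self.transform follows torchvision's GeneralizedRCNNTransform convention: the call returns an ImageList whose .tensors attribute is the resized, padded image batch, so the printed shape is (N, C, H_pad, W_pad). Values below are illustrative.

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(min_size=800, max_size=1333,
                                     image_mean=[0.485, 0.456, 0.406],
                                     image_std=[0.229, 0.224, 0.225])
images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
image_list, _ = transform(images)
# .tensors is the batched, padded tensor; exact H/W depend on resizing and padding
print(f'images shape from transform:{image_list.tensors.shape}')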

+ 39 - 27
models/base/transforms.py

@@ -136,35 +136,47 @@ class RandomCrop:
         self.size = size
 
     def __call__(self, img, target):
-        w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
-        th, tw = self.size
+        width, height = F.get_image_size(img)
+        crop_height, crop_width = self.size
 
-        if h <= th or w <= tw:
-            return img, target
+        # Randomly choose the crop region
+        left = random.randint(0, max(width - crop_width, 0))
+        top = random.randint(0, max(height - crop_height, 0))
+        right = min(left + crop_width, width)
+        bottom = min(top + crop_height, height)
 
-        i = random.randint(0, h - th)
-        j = random.randint(0, w - tw)
+        # Crop the image
+        img = F.crop(img, top, left, bottom - top, right - left)
 
-        img = F.crop(img, i, j, th, tw)
+        if "boxes" in target:
+            boxes = target["boxes"]
+            labels = target["labels"] if "labels" in target else None
 
-        # Adjust boxes
-        boxes = target["boxes"]
-        boxes = boxes - torch.tensor([j, i, j, i], device=boxes.device)
-        boxes = torch.clamp(boxes, min=0)
-        xmax, ymax = tw, th
-        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(max=xmax)
-        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(max=ymax)
-        target["boxes"] = boxes
+            # Translate the bounding boxes into the crop's coordinate frame
+            cropped_boxes = boxes.clone()
+            cropped_boxes[:, 0::2] -= left
+            cropped_boxes[:, 1::2] -= top
 
-        # Adjust lines
-        if "lines" in target:
-            lines = target["lines"].clone()
-            lines[..., 0] -= j
-            lines[..., 1] -= i
-            lines = torch.clamp(lines, min=0)
-            lines[..., 0] = torch.clamp(lines[..., 0], max=tw)
-            lines[..., 1] = torch.clamp(lines[..., 1], max=th)
-            target["lines"] = lines
+            # Clamp the bounding boxes to the crop region
+            cropped_boxes[:, 0::2].clamp_(min=0, max=crop_width)
+            cropped_boxes[:, 1::2].clamp_(min=0, max=crop_height)
+
+            # Compute the new widths and heights
+            w = cropped_boxes[:, 2] - cropped_boxes[:, 0]
+            h = cropped_boxes[:, 3] - cropped_boxes[:, 1]
+
+            # Filter out invalid bounding boxes (zero width or height)
+            valid_boxes_mask = (w > 0) & (h > 0)
+
+            # Keep only the valid bounding boxes
+            cropped_boxes = cropped_boxes[valid_boxes_mask]
+            if labels is not None:
+                labels = labels[valid_boxes_mask]
+
+            # Update the target
+            target["boxes"] = cropped_boxes
+            if labels is not None:
+                target["labels"] = labels
 
         return img, target
 
@@ -473,14 +485,14 @@ def get_transforms(augmention=True):
         transforms_list.append(RandomGrayscale(0.1))
 
         transforms_list.append(GaussianBlur())
-        transforms_list.append(RandomErasing())
+        # transforms_list.append(RandomErasing())
         transforms_list.append(RandomHorizontalFlip(0.5))
-        transforms_list.append(RandomVerticalFlip(0.2))
+        transforms_list.append(RandomVerticalFlip(0.5))
         # transforms_list.append(RandomPerspective())
         transforms_list.append(RandomRotation(degrees=15))
         transforms_list.append(RandomResize(512, 2048))
 
-        transforms_list.append(RandomCrop((512,512)))
+        # transforms_list.append(RandomCrop((512,512)))
 
     transforms_list.append(DefaultTransform())
 

+ 41 - 12
models/line_detect/line_dataset.py

@@ -42,6 +42,7 @@ class LineDataset(BaseDataset):
         self.target_type = target_type
         self.img_type=img_type
         self.augmentation=augmentation
+        print(f'augmentation:{augmentation}')
         # self.default_transform = DefaultTransform()
 
     def __getitem__(self, index) -> T_co:
@@ -80,15 +81,18 @@ class LineDataset(BaseDataset):
 
         target["image_id"] = torch.tensor(item)
 
-        target["boxes"], lines = get_boxes_lines(objs,shape)
+        target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
         # print(f'lines:{lines}')
-        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
-
+        # target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
+        # print(f'target points:{target["points"]}')
 
         a = torch.full((lines.shape[0],), 2).unsqueeze(1)
         lines = torch.cat((lines, a), dim=1)
 
         target["lines"] = lines.to(torch.float32).view(-1,2,3)
+
+
+
         print(f'lines:{target["lines"].shape}')
         target["img_size"]=shape
 
@@ -131,31 +135,56 @@ class LineDataset(BaseDataset):
 
 def get_boxes_lines(objs,shape):
     boxes = []
+    labels=[]
     h,w=shape
     line_point_pairs = []
+    points=[]
 
     for obj in objs:
         # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
 
         # print(f"points:{obj['points']}")
+        label=obj['label']
+        if label =='line':
+            a,b=obj['points'][0],obj['points'][1]
+
+            line_point_pairs.append(a)
+            line_point_pairs.append(b)
+
+            xmin = max(0, (min(a[0], b[0]) - 6))
+            xmax = min(w, (max(a[0], b[0]) + 6))
+            ymin = max(0, (min(a[1], b[1]) - 6))
+            ymax = min(h, (max(a[1], b[1]) + 6))
+
+            boxes.append([ xmin,ymin,  xmax,ymax])
+            labels.append(torch.tensor(2))
+
+        elif label =='point':
+             p= obj['points'][0]
+             xmin=max(0,p[0]-6)
+             xmax = min(w, p[0] +6)
+             ymin=max(0,p[1]-6)
+             ymax = min(h, p[1] + 6)
+
+             points.append(p)
+             labels.append(torch.tensor(1))
+             boxes.append([xmin, ymin, xmax, ymax])
 
-        a,b=obj['points'][0],obj['points'][1]
 
-        line_point_pairs.append(a)
-        line_point_pairs.append(b)
 
-        xmin = max(0, (min(a[0], b[0]) - 6))
-        xmax = min(w, (max(a[0], b[0]) + 6))
-        ymin = max(0, (min(a[1], b[1]) - 6))
-        ymax = min(h, (max(a[1], b[1]) + 6))
+        elif label =='arc':
 
-        boxes.append([ xmin,ymin,  xmax,ymax])
+            labels.append(torch.tensor(3))
 
     boxes=torch.tensor(boxes)
+    labels=torch.tensor(labels)
+    points=torch.tensor(points)
+    # print(f'read labels:{labels}')
+    # print(f'read points:{points}')
     line_point_pairs=torch.tensor(line_point_pairs)
 
     # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
-    return boxes,line_point_pairs
+    return boxes,line_point_pairs,points,labels
 
 if __name__ == '__main__':
     path=r"\\192.168.50.222/share/rlq/datasets/0706_"
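
A sketch of the class-id convention get_boxes_lines adopts in this commit (point = 1, line = 2, arc = 3) and the 6-pixel padding used to turn each annotation into a detection box clamped to the image; the helper name and sample input below are made up for illustration.

def obj_to_box(obj, w, h, pad=6):
    if obj["label"] == "line":
        (ax, ay), (bx, by) = obj["points"][0], obj["points"][1]
        return [max(0, min(ax, bx) - pad), max(0, min(ay, by) - pad),
                min(w, max(ax, bx) + pad), min(h, max(ay, by) + pad)], 2
    if obj["label"] == "point":
        px, py = obj["points"][0]
        return [max(0, px - pad), max(0, py - pad),
                min(w, px + pad), min(h, py + pad)], 1
    if obj["label"] == "arc":
        return None, 3  # arcs only contribute a label in the current code

print(obj_to_box({"label": "line", "points": [(20, 30), (80, 35)]}, w=512, h=512))
# ([14, 24, 86, 41], 2)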

+ 1 - 1
models/line_detect/line_detect.py

@@ -150,7 +150,7 @@ class LineDetect(BaseDetectionNet):
 
 
         if line_head is None:
-            keypoint_layers = tuple(1 for _ in range(8))
+            keypoint_layers = tuple(num_points for _ in range(8))
             line_head = LineHeads(8, keypoint_layers)
 
         # if line_predictor is None:
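
Illustrative only: with num_points=3 (as set in train_demo.py further down), the generator expression yields a tuple of eight identical entries, which is then handed to the project-specific LineHeads head whose exact signature is assumed here.

num_points = 3
keypoint_layers = tuple(num_points for _ in range(8))
print(keypoint_layers)  # (3, 3, 3, 3, 3, 3, 3, 3)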

+ 174 - 42
models/line_detect/loi_heads.py

@@ -187,6 +187,36 @@ def angle_loss_cosine(pred_dir, gt_dir):
         cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
         return 1.0 - cos_sim  # or torch.acos(cos_sim) / pi would also work
 
+
+def single_point_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tensor
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
+    # show_heatmap(gs[0],'target')
+    all_roi_heatmap = []
+    for roi, heatmap in zip(rois, gs):
+        # print(f'heatmap:{heatmap.shape}')
+        heatmap = heatmap.unsqueeze(0)
+        x1, y1, x2, y2 = map(int, roi)
+        roi_heatmap = torch.zeros_like(heatmap)
+        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
+        # show_heatmap(roi_heatmap,'roi_heatmap')
+        all_roi_heatmap.append(roi_heatmap)
+
+    all_roi_heatmap = torch.cat(all_roi_heatmap)
+    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
+
+    return all_roi_heatmap
+
 def line_points_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tensor
     print(f'rois:{rois.shape}')
@@ -278,7 +308,7 @@ def line_points_to_heatmap_(keypoints, rois, heatmap_size):
     return gs_heatmap
 
 
-def generate_gaussian_heatmaps(xs, ys, heatmap_size, sigma=2.0, device='cuda'):
+def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, device='cuda'):
     """
     Generate and merge Gaussian heatmaps for a set of points.
 
@@ -294,6 +324,7 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size, sigma=2.0, device='cuda'):
     """
 
     assert xs.shape == ys.shape, "x and y must have the same shape"
+    print(f'xs:{xs.shape}')
     N = xs.shape[0]
     print(f'N:{N}')
 
@@ -307,26 +338,32 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size, sigma=2.0, device='cuda'):
     # print(f'heatmap_size:{heatmap_size}')
     # Initialize the output heatmaps
     combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
-    for i in range(N):
-        mu_x1 = xs[i, 0].clamp(0, heatmap_size - 1).item()
-        mu_y1 = ys[i, 0].clamp(0, heatmap_size - 1).item()
 
-        # Squared distance to the point
-        dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
+    for i in range(N):
+        heatmap= torch.zeros((heatmap_size, heatmap_size), device=device)
+        for j in range(num_points):
+            mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
+            mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
 
-        # Gaussian response
-        heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
+            # Squared distance to the point
+            dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
 
-        mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
-        mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
+            # Gaussian response
+            heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
 
-        # Squared distance to the point
-        dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
+            heatmap+=heatmap1
 
-        # Gaussian response
-        heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
 
-        heatmap = heatmap1 + heatmap2
+        # mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
+        # mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
+        #
+        # # Squared distance to the point
+        # dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
+        #
+        # # Gaussian response
+        # heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
+        #
+        # heatmap = heatmap1 + heatmap2
 
         # Accumulate the current heatmap into the result
         combined_heatmap[i] = heatmap
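
A self-contained sketch of the generalized heatmap generation above, assuming the intent is one combined heatmap per sample built by summing a Gaussian bump at each of num_points locations (CPU-only here; the repo version defaults to CUDA).

import torch

def gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0):
    grid_y, grid_x = torch.meshgrid(
        torch.arange(heatmap_size, dtype=torch.float32),
        torch.arange(heatmap_size, dtype=torch.float32),
        indexing="ij",
    )
    combined = torch.zeros(xs.shape[0], heatmap_size, heatmap_size)
    for i in range(xs.shape[0]):
        for j in range(num_points):
            mu_x = xs[i, j].clamp(0, heatmap_size - 1)
            mu_y = ys[i, j].clamp(0, heatmap_size - 1)
            dist = (grid_x - mu_x) ** 2 + (grid_y - mu_y) ** 2
            combined[i] += torch.exp(-dist / (2 * sigma ** 2))
    return combined

maps = gaussian_heatmaps(torch.tensor([[10., 50.]]), torch.tensor([[20., 40.]]), heatmap_size=64)
print(maps.shape)  # torch.Size([1, 64, 64])
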
@@ -575,21 +612,22 @@ def heatmaps_to_lines(maps, rois):
 
 
 def lines_features_align(features, proposals, img_size):
-    print(f'lines_features_align features:{features.shape}')
+    print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
 
     align_feat_list = []
     for feat, proposals_per_img in zip(features, proposals):
-        # print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
-
-        feat = feat.unsqueeze(0)
-        for proposal in proposals_per_img:
-            align_feat = torch.zeros_like(feat)
-            # print(f'align_feat:{align_feat.shape}')
-            x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
-            # Copy the region inside each proposal box into the corresponding position in align_feat
-            align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
-            align_feat_list.append(align_feat)
-
+        print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
+        if proposals_per_img.shape[0]>0:
+            feat = feat.unsqueeze(0)
+            for proposal in proposals_per_img:
+                align_feat = torch.zeros_like(feat)
+                # print(f'align_feat:{align_feat.shape}')
+                x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
+                # Copy the region inside each proposal box into the corresponding position in align_feat
+                align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
+                align_feat_list.append(align_feat)
+
+    print(f'align_feat_list:{align_feat_list}')
     feats_tensor = torch.cat(align_feat_list)
 
     print(f'align features :{feats_tensor.shape}')
@@ -611,7 +649,8 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     gs_heatmaps = []
     valid = []
     for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
-        print(f'proposals_per_image:{proposals_per_image.shape}')
+        print(f'line_proposals_per_image:{proposals_per_image.shape}')
+        print(f'gt_lines:{gt_lines}')
         kp = gt_kp_in_image[midx]
         gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
         gs_heatmaps.append(gs_heatmaps_per_img)
@@ -646,6 +685,37 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     return line_loss
 
 
+def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = line_logits.shape
+    len_proposals = len(proposals)
+
+    print(f'start to compute_point_loss')
+    print(f'compute_point_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
+    if H != W:
+        raise ValueError(
+            f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+
+    gs_heatmaps = []
+    print(f'point_matched_idxs:{point_matched_idxs}')
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
+        print(f'proposals_per_image:{proposals_per_image.shape}')
+        kp = gt_kp_in_image[midx]
+        # print(f'gt_kp_in_image:{gt_kp_in_image}')
+        gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
+        gs_heatmaps.append(gs_heatmaps_per_img)
+
+    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
+    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
+
+    line_logits = line_logits.squeeze(1)
+
+    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+
+    return line_loss
+
 def lines_to_boxes(lines, img_size=511):
     """
     Input:
@@ -1244,6 +1314,7 @@ class RoIHeads(nn.Module):
             image_shapes (List[Tuple[H, W]])
             targets (List[Dict])
         """
+
         print(f'roihead forward!!!')
         if targets is not None:
             for t in targets:
@@ -1301,6 +1372,7 @@ class RoIHeads(nn.Module):
 
         if self.has_line():
             print(f'roi_heads forward has_line()!!!!')
+            print(f'labels:{labels}')
             line_proposals = [p["boxes"] for p in result]
             print(f'boxes_proposals:{len(line_proposals)}')
 
@@ -1313,28 +1385,57 @@ class RoIHeads(nn.Module):
                 num_images = len(proposals)
                 print(f'num_images:{num_images}')
                 line_proposals = []
+                point_proposals = []
+                arc_proposals = []
+
                 pos_matched_idxs = []
+                line_pos_matched_idxs = []
+                point_pos_matched_idxs = []
                 if matched_idxs is None:
                     raise ValueError("if in trainning, matched_idxs should not be None")
 
                 for img_id in range(num_images):
                     pos = torch.where(labels[img_id] > 0)[0]
-                    line_proposals.append(proposals[img_id][pos])
-                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+
+                    line_pos=torch.where(labels[img_id] ==2)[0]
+                    point_pos=torch.where(labels[img_id] ==1)[0]
+
+                    line_proposals.append(proposals[img_id][line_pos])
+                    point_proposals.append(proposals[img_id][point_pos])
+
+                    line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
+                    point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
+
+                    # pos_matched_idxs.append(matched_idxs[img_id][pos])
             else:
                 if targets is not None:
 
                     pos_matched_idxs = []
                     num_images = len(proposals)
                     line_proposals = []
+                    point_proposals=[]
+                    arc_proposals=[]
+
+                    line_pos_matched_idxs = []
+                    point_pos_matched_idxs = []
                     print(f'val num_images:{num_images}')
                     if matched_idxs is None:
                         raise ValueError("if in trainning, matched_idxs should not be None")
 
                     for img_id in range(num_images):
                         pos = torch.where(labels[img_id] > 0)[0]
-                        line_proposals.append(proposals[img_id][pos])
-                        pos_matched_idxs.append(matched_idxs[img_id][pos])
+                        # line_proposals.append(proposals[img_id][pos])
+                        # pos_matched_idxs.append(matched_idxs[img_id][pos])
+
+                        line_pos = torch.where(labels[img_id].item() == 2)[0]
+                        point_pos = torch.where(labels[img_id].item() == 1)[0]
+
+                        line_proposals.append(proposals[img_id][line_pos])
+                        point_proposals.append(proposals[img_id][point_pos])
+
+                        line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
+                        point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
+
                 else:
                     pos_matched_idxs = None
 
@@ -1347,7 +1448,14 @@ class RoIHeads(nn.Module):
             line_features = self.channel_compress(features['0'])
             #(b.8,512,512)
 
-            line_features = lines_features_align(line_features, line_proposals, image_shapes)
+
+            all_proposals=line_proposals+point_proposals
+            # print(f'all_proposals:{all_proposals}')
+            filtered_proposals = [proposal for proposal in all_proposals if proposal.shape[0] > 0]
+
+
+            line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
+            print(f'line_features from features_align:{line_features.shape}')
 
             line_features = self.line_head(line_features)
             #(N,1,512,512)
@@ -1359,36 +1467,59 @@ class RoIHeads(nn.Module):
 
             loss_line = {}
             loss_line_iou = {}
-
+            model_loss_point = {}
             if self.training:
 
                 if targets is None or pos_matched_idxs is None:
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
                 gt_lines = [t["lines"] for t in targets]
+                gt_points = [t["points"] for t in targets]
                 print(f'gt_lines:{gt_lines[0].shape}')
                 h, w = targets[0]["img_size"]
                 img_size = h
-                rcnn_loss_line = lines_point_pair_loss(
-                    line_logits, line_proposals, gt_lines, pos_matched_idxs
-                )
-                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
+                # rcnn_loss_line = lines_point_pair_loss(
+                #     line_logits, line_proposals, gt_lines, pos_matched_idxs
+                # )
+                # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
+                gt_lines_tensor=torch.cat(gt_lines)
+                gt_points_tensor = torch.cat(gt_points)
+                print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
+                print(f'gt_points_tensor:{gt_points_tensor.shape}')
+                if gt_lines_tensor.shape[0]>0:
+                    rcnn_loss_line = lines_point_pair_loss(
+                        line_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                    )
+                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+
+                if gt_points_tensor.shape[0]>0:
+                    model_loss_point = compute_point_loss(
+                        line_logits, point_proposals, gt_points, point_pos_matched_idxs
+                    )
 
                 loss_line = {"loss_line": rcnn_loss_line}
                 loss_line_iou = {'loss_line_iou': iou_loss}
+                loss_point = {"loss_point": model_loss_point}
 
             else:
                 if targets is not None:
                     h, w = targets[0]["img_size"]
                     img_size = h
                     gt_lines = [t["lines"] for t in targets]
-                    rcnn_loss_lines = lines_point_pair_loss(
-                        line_logits, line_proposals, gt_lines, pos_matched_idxs
+                    gt_points = [t["points"] for t in targets]
+
+                    rcnn_loss_line = lines_point_pair_loss(
+                        line_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                    )
+                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+
+                    model_loss_point = compute_point_loss(
+                        line_logits, point_proposals, gt_points, point_pos_matched_idxs
                     )
-                    loss_line = {"loss_line": rcnn_loss_lines}
 
-                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
+                    loss_line = {"loss_line": rcnn_loss_line}
                     loss_line_iou = {'loss_line_iou': iou_loss}
+                    loss_point={"loss_point":model_loss_point}
 
 
                 else:
@@ -1405,6 +1536,7 @@ class RoIHeads(nn.Module):
 
             losses.update(loss_line)
             losses.update(loss_line_iou)
+            losses.update(loss_point)
 
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
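
A toy sketch of the guard pattern this section moves toward: each loss term is computed only when its ground truth is non-empty, and the resulting entries (including the new loss_point) are merged into the RoI head's loss dictionary. The function name and numeric values are placeholders, not outputs of the model.

import torch

def build_losses(gt_lines_tensor, gt_points_tensor):
    losses = {}
    if gt_lines_tensor.shape[0] > 0:
        losses["loss_line"] = torch.tensor(0.42)      # placeholder for lines_point_pair_loss
        losses["loss_line_iou"] = torch.tensor(0.10)  # placeholder for line_iou_loss
    if gt_points_tensor.shape[0] > 0:
        losses["loss_point"] = torch.tensor(0.25)     # placeholder for compute_point_loss
    return losses

print(build_losses(torch.zeros(0, 2, 3), torch.zeros(3, 2)))  # only the point loss is present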

+ 4 - 3
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/rlq/datasets/0706_
+  datadir: \\192.168.50.222/share/rlq/datasets/Dataset0709
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -11,9 +11,10 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 1
+  batch_size: 2
   max_epoch: 80000
-  augmentation: True
+#  augmentation: True
+  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4

+ 1 - 1
models/line_detect/train_demo.py

@@ -16,6 +16,6 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn(num_points=2)
+    model=linedetect_newresnet18fpn(num_points=3)
 
     model.start_train(cfg='train.yaml')

+ 2 - 2
models/line_detect/trainer.py

@@ -245,8 +245,8 @@ class Trainer(BaseTrainer):
 
         self.init_params(**kwargs)
 
-        dataset_train = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='train')
-        dataset_val = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='val')
+        dataset_train = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
+        dataset_val = LineDataset(dataset_path=self.dataset_path,augmentation=False, data_type=self.data_type, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)