Sfoglia il codice sorgente

ins mask is may ok

Your Name 1 settimana fa
parent
commit
049567712b

+ 10 - 4
libs/vision_libs/models/detection/transform.py

@@ -218,10 +218,10 @@ class GeneralizedRCNNTransform(nn.Module):
             points = resize_keypoints(points, (h, w), image.shape[-2:])
             target["points"] = points
 
-        if "circle_masks" in target:
-            arc_mask = target["circle_masks"]
-            arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
-            target["circle_masks"] = arc_mask
+        # if "circle_masks" in target:
+        #     arc_mask = target["circle_masks"]
+        #     arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
+        #     target["circle_masks"] = arc_mask
 
         if "circles" in target:
             arc_mask = target["circles"]
@@ -231,12 +231,16 @@ class GeneralizedRCNNTransform(nn.Module):
 
         if "mask_ends" in target:
             arc_mask = target["mask_ends"]
+            # print(f'arc_mask in 1:{arc_mask}')
             arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
+            # print(f'arc_mask in 2:{arc_mask}')
             target["mask_ends"] = arc_mask
         if "mask_params" in target:
             arc_mask = target["mask_params"]
+            # print(f'arc_mask in 3:{arc_mask}')
             arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
             arc_mask[:,2:4] = resize_keypoints(arc_mask[:,2:4], (h, w), image.shape[-2:])
+            # print(f'arc_mask in 4:{arc_mask}')
             target["mask_params"] = arc_mask
         return image, target
 
@@ -333,10 +337,12 @@ class GeneralizedRCNNTransform(nn.Module):
 
             if "arcs" in pred:
                 arc_mask = pred["arcs"]
+                # print(f'arcs in 1:{arc_mask}')
                 arc_mask[:,0:2] = resize_keypoints(arc_mask[:,0:2],im_s, o_im_s)
                 arc_mask[:,2:4] = resize_keypoints(arc_mask[:,2:4], im_s, o_im_s)
                 arc_mask[:,5:7] = resize_keypoints(arc_mask[:,5:7], im_s, o_im_s)
                 arc_mask[:,7:9] = resize_keypoints(arc_mask[:,7:9], im_s, o_im_s)
+                # print(f'arcs in 2:{arc_mask}')
                 result[i]["arcs"] = arc_mask
 
         return result

+ 1 - 1
models/line_detect/heads/arc/arc_heads.py

@@ -121,7 +121,7 @@ class ArcEquationHead(nn.Module):
         # Last two values are auxiliary points
         # Now mapped to the same spatial range as image
         # ------------------------------------------------
-        arc_params[..., 5] = 7   # x auxiliary
+        arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * W   # x auxiliary
         arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * H   # y auxiliary
         arc_params[..., 7] = torch.sigmoid(arc_params[..., 7]) * W  # x auxiliary
         arc_params[..., 8] = torch.sigmoid(arc_params[..., 8]) * H  # y auxiliary

+ 5 - 36
models/line_detect/line_dataset.py

@@ -101,14 +101,6 @@ class LineDataset(BaseDataset):
         #boxes, line_point_pairs, points, labels, mask_ends, mask_params
         boxes, lines, points, labels, arc_ends, arc_params = get_boxes_lines(objs, shape)
 
-        # print_params(arc_ends, arc_params)
-
-        # if points is not None:
-            # target["points"] = points
-        # if lines is not None:
-        #     a = torch.full((lines.shape[0],), 2).unsqueeze(1)
-        #     lines = torch.cat((lines, a), dim=1)
-        #     target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
         if lines is not None:
             label_3d = labels.view(-1, 1, 1).expand(-1, 2, -1)  # [N] -> [N,2,1]
             line1 = torch.cat([lines, label_3d], dim=-1)  # [N,2,3]
@@ -118,40 +110,19 @@ class LineDataset(BaseDataset):
             target['mask_ends'] = arc_ends
         if arc_params is not None:
             target['mask_params'] = arc_params
-
-
-
             # arc_angles = compute_arc_angles(arc_ends, arc_params)
 
-
-            # print_params(arc_ends,arc_params)
             arc_masks = []
             for i in range(len(arc_params)):
                 mask = arc_to_mask_safe(arc_params[i], arc_ends[i], shape=(2000, 2000),debug=False)
                 arc_masks.append(mask)
-            # print_params(arc_masks)
             target['circle_masks'] = torch.stack(arc_masks, dim=0)
 
-            # save_full_mask(torch.stack(arc_masks, dim=0), "arc_masks",
-            #                "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset",
-            #                force_save=False,image=image,show_on_image=True)
-
-
-
+        print(f'target[circle_masks]:{target["circle_masks"].shape}')
 
 
         target["boxes"] = boxes
         target["labels"] = labels
-        # target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
-        # print(f'lines:{lines}')
-        # target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
-        # print(f'target points:{target["points"]}')
-
-        # target["lines"] = lines.to(torch.float32).view(-1,2,3)
-
-        # print(f'')
-
-        # print(f'lines:{target["lines"].shape}')
         target["img_size"] = shape
 
         # validate_keypoints(lines, shape[0], shape[1])
@@ -199,14 +170,12 @@ class LineDataset(BaseDataset):
             plt.show()
 
         if show_type == 'circle_masks':
-            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
-                                              colors="yellow", width=1)
-            # arc = target['arc']
             arc_mask = target['circle_masks']
             # print(f'taget circle:{arc.shape}')
             print(f'target circle_masks:{arc_mask.shape}')
             combined = torch.cat(list(arc_mask), dim=1)
-            plt.imshow(combined)
+            print(f'combine:{combined.shape}')
+            plt.imshow(arc_mask[-1])
             plt.show()
 
         if show_type == 'circle_masks11':
@@ -682,6 +651,6 @@ def get_boxes_lines(objs, shape):
 
 
 if __name__ == '__main__':
-    path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
+    path = r'/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/a_dataset'
     dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
-    dataset.show(19, show_type='arc_yuan_point_ellipse')
+    dataset.show(19, show_type='circle_masks')

+ 40 - 40
models/line_detect/loi_heads.py

@@ -1433,15 +1433,15 @@ class RoIHeads(nn.Module):
                         print(f'start to compute circle_loss')
 
                         loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,ins_pos_matched_idxs)
-                        total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_ends,gt_mask_params,ins_pos_matched_idxs,labels)
-                        loss_arc_ends = loss_arc_ends
-                    if loss_arc_equation is None:
-                        print(f'loss_arc_equation is None')
-                        loss_arc_equation = torch.tensor(0.0, device=device)
-
-                    if loss_arc_ends is None:
-                        print(f'loss_arc_ends is None')
-                        loss_arc_ends = torch.tensor(0.0, device=device)
+                        # total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_ends,gt_mask_params,ins_pos_matched_idxs,labels)
+                        # loss_arc_ends = loss_arc_ends
+                    # if loss_arc_equation is None:
+                    #     print(f'loss_arc_equation is None')
+                    #     loss_arc_equation = torch.tensor(0.0, device=device)
+                    #
+                    # if loss_arc_ends is None:
+                    #     print(f'loss_arc_ends is None')
+                    #     loss_arc_ends = torch.tensor(0.0, device=device)
 
                     if loss_ins is None:
                         print(f'loss_ins is None111')
@@ -1452,9 +1452,9 @@ class RoIHeads(nn.Module):
                         loss_ins_extra = torch.tensor(0.0, device=device)
 
                     loss_ins = {"loss_ins": loss_ins}
-                    loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
-                    loss_arc_equation = {"loss_arc_equation": loss_arc_equation}
-                    loss_arc_ends = {"loss_arc_ends": loss_arc_ends}
+                    # loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
+                    # loss_arc_equation = {"loss_arc_equation": loss_arc_equation}
+                    # loss_arc_ends = {"loss_arc_ends": loss_arc_ends}
 
 
                 else:
@@ -1475,48 +1475,48 @@ class RoIHeads(nn.Module):
 
                             loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,
                                                            ins_pos_matched_idxs)
-                            total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_ends,gt_mask_params,ins_pos_matched_idxs,labels)
-
-                            loss_arc_ends = loss_arc_ends
+                            # total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_ends,gt_mask_params,ins_pos_matched_idxs,labels)
+                            #
+                            # loss_arc_ends = loss_arc_ends
 
                             # loss_ins_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)
 
-                        if loss_ins is None:
-                            print(f'loss_ins is None111')
-                            loss_ins = torch.tensor(0.0, device=device)
-
-                        if loss_ins_extra is None:
-                            print(f'loss_ins_extra is None111')
-                            loss_ins_extra = torch.tensor(0.0, device=device)
+                        # if loss_ins is None:
+                        #     print(f'loss_ins is None111')
+                        #     loss_ins = torch.tensor(0.0, device=device)
 
-                        if loss_arc_equation is None:
-                            print(f'loss_arc_equation is None')
-                            loss_arc_equation = torch.tensor(0.0, device=device)
+                        # if loss_ins_extra is None:
+                        #     print(f'loss_ins_extra is None111')
+                        #     loss_ins_extra = torch.tensor(0.0, device=device)
 
-                        if loss_arc_ends is None:
-                            print(f'loss_arc_ends is None')
-                            loss_arc_ends = torch.tensor(0.0, device=device)
+                        # if loss_arc_equation is None:
+                        #     print(f'loss_arc_equation is None')
+                        #     loss_arc_equation = torch.tensor(0.0, device=device)
+                        #
+                        # if loss_arc_ends is None:
+                        #     print(f'loss_arc_ends is None')
+                        #     loss_arc_ends = torch.tensor(0.0, device=device)
 
                         if loss_ins is None:
                             print(f'loss_ins is None111')
                             loss_ins = torch.tensor(0.0, device=device)
 
-                        if loss_ins_extra is None:
-                            print(f'loss_ins_extra is None111')
-                            loss_ins_extra = torch.tensor(0.0, device=device)
+                        # if loss_ins_extra is None:
+                        #     print(f'loss_ins_extra is None111')
+                        #     loss_ins_extra = torch.tensor(0.0, device=device)
 
                         loss_ins = {"loss_ins": loss_ins}
-                        loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
-                        loss_arc_equation = {"loss_arc_equation": loss_arc_equation}
-                        loss_arc_ends = {"loss_arc_ends": loss_arc_ends}
+                        # loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
+                        # loss_arc_equation = {"loss_arc_equation": loss_arc_equation}
+                        # loss_arc_ends = {"loss_arc_ends": loss_arc_ends}
 
 
 
                     else:
                         loss_ins = {}
-                        loss_ins_extra = {}
-                        loss_arc_equation = {}
-                        loss_arc_ends = {}
+                        # loss_ins_extra = {}
+                        # loss_arc_equation = {}
+                        # loss_arc_ends = {}
                         if feature_logits is None or ins_proposals is None:
                             raise ValueError(
                                 "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
@@ -1553,9 +1553,9 @@ class RoIHeads(nn.Module):
                 print(f'loss_ins:{loss_ins}')
                 print(f'loss_ins_extra:{loss_ins_extra}')
                 losses.update(loss_ins)
-                losses.update(loss_ins_extra)
-                losses.update(loss_arc_equation)
-                losses.update(loss_arc_ends)
+                # losses.update(loss_ins_extra)
+                # losses.update(loss_arc_equation)
+                # losses.update(loss_arc_ends)
                 print(f'losses:{losses}')