Browse Source

degbug arc on 4080

lstrlq 4 tháng trước cách đây
mục cha
commit
4148bf71da

+ 13 - 7
models/line_detect/heads/head_losses.py

@@ -518,16 +518,21 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
         print(f'proposals_per_image:{proposals_per_image.shape}')
         kp = gt_kp_in_image[midx]
         # print(f'gt_kp_in_image:{gt_kp_in_image}')
-        gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
-        gs_heatmaps.append(gs_heatmaps_per_img)
+        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
+            gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
+            gs_heatmaps.append(gs_heatmaps_per_img)
 
-    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
-    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
+    if len(gs_heatmaps)>0:
+        gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
+        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
 
-    line_logits = feature_logits[:, 0]
-    print(f'single_point_logits:{line_logits.shape}')
+        line_logits = feature_logits[:, 0]
+        print(f'single_point_logits:{line_logits.shape}')
 
-    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+        line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+
+    else:
+        line_loss=100
 
     return line_loss
 
@@ -760,6 +765,7 @@ def arc_inference(x, arc_boxes):
     print(f'arc_boxes:{len(arc_boxes)}')
 
     boxes_per_image = [box.size(0) for box in arc_boxes]
+    print(f'arc boxes_per_image:{boxes_per_image}')
     x2 = x.split(boxes_per_image, dim=0)
 
     for xx, bb in zip(x2, arc_boxes):

+ 2 - 1
models/line_detect/loi_heads.py

@@ -1271,7 +1271,8 @@ class RoIHeads(nn.Module):
 
                         arcs_probs, arcs_scores = arc_inference(feature_logits,arc_proposals)
                         for keypoint_prob, kps, r in zip(arcs_probs, arcs_scores, result):
-                            r["arcs"] = keypoint_prob
+                            # r["arcs"] = keypoint_prob
+                            r["arcs"] = feature_logits
                             r["arcs_scores"] = kps
 
             # print(f'loss_point:{loss_point}')

+ 1 - 1
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
 
-  datadir: /data/share/rlq/datasets/arc_datasets
+  datadir: /data/share/zyh/arc/a_datasetb
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban

+ 6 - 15
models/line_detect/trainer.py

@@ -224,25 +224,16 @@ class Trainer(BaseTrainer):
             self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
         if 'arcs' in result:
-            arcs = result['arcs']
-            img_np = img.numpy()
-            img_np=img_np.transpose(1,2,0)
-            # cv2.imshow('original', img_np*255)
-            # cv2.waitKey(100000)
-            for arc in arcs:
-                print(f'arc len:{len(arc)}')
-                for i in range(1, len(arc)):
-                    pt1 = (int(arc[i - 1][0]), int(arc[i - 1][1]))
-                    pt2 = (int(arc[i][0]), int(arc[i][1]))
-                    cv2.line(img_np, pt1, pt2, color=(255, 0, 0), thickness=2)
+            arcs = result['arcs'][0]
 
-            img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
 
+            # img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
 
 
-            img_tensor =torch.tensor(img_rgb)
-            img_tensor = np.transpose(img_tensor)
-            self.writer.add_image('z-out-arc', img_tensor, global_step=epoch)
+
+            # img_tensor =torch.tensor(img_rgb)
+            # img_tensor = np.transpose(img_tensor)
+            self.writer.add_image('z-out-arc', arcs, global_step=epoch)
 
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)