Переглянути джерело

train arc_unet on 4080 ,有初步较好效果

lstrlq 4 місяців тому
батько
коміт
719f35fc8a

+ 3 - 3
models/line_detect/loi_heads.py

@@ -1268,16 +1268,16 @@ class RoIHeads(nn.Module):
                         # )
 
                         print(f'error :both arc_feature_logits and arc_proposals should not be None when not in training mode"')
-                        return None
+                        pass
 
-                    if feature_logits is not None:
+                    if feature_logits is not None and arc_proposals is not None:
 
                         arcs_probs, arcs_scores, arcs_point = arc_inference(feature_logits,arc_proposals, th=0)
                         for keypoint_prob, kps, kp, r in zip(arcs_probs, arcs_scores, arcs_point, result):
                             # r["arcs"] = keypoint_prob
                             r["arcs"] = feature_logits
                             r["arcs_scores"] = kps
-                            r["arcs_point"] = kp
+                            r["arcs_point"] = feature_logits
 
 
             # print(f'loss_point:{loss_point}')

+ 2 - 2
models/line_detect/train.yaml

@@ -1,13 +1,13 @@
 io:
   logdir: train_results
 #  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
-#  datadir: /data/share/zyh/arc/a_datasetb
+  datadir: /data/share/zyh/arc/a_dataset
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-  datadir: \\192.168.50.222/share/rlq/datasets/arc_datasets_100
+#  datadir: \\192.168.50.222/share/rlq/datasets/arc_datasets_100
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb

+ 2 - 2
models/line_detect/train_demo.py

@@ -17,10 +17,10 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn(num_points=3)
+    # model=linedetect_newresnet18fpn(num_points=3)
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
-    # model = linedetect_newresnet152fpn(num_points=3)
+    model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()

+ 0 - 20
models/line_detect/trainer.py

@@ -232,26 +232,6 @@ class Trainer(BaseTrainer):
             # img_tensor = np.transpose(img_tensor)
             self.writer.add_image('z-out-arc', arcs, global_step=epoch)
 
-            aa = result['arcs_point'][0]
-
-            x_coords = aa[:, 0].cpu()/800*2000
-            y_coords = aa[:, 1].cpu()/800*2000
-
-            plt.figure(figsize=(10, 8))
-            plt.imshow(im)
-            plt.scatter(x_coords, y_coords, c='red', s=0.3, label='Arc Points')
-            plt.title("Image with Arc Points")
-            plt.legend()
-            plt.axis('off')
-
-            fig = plt.gcf()
-            fig.canvas.draw()
-            image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-            image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (3,))  # H x W x 3
-            plt.close()
-
-            self.writer.add_image('z-out-result', image_from_plot, dataformats='HWC')
-
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)