Procházet zdrojové kódy

解决keypoint损失问题

lstrlq před 5 měsíci
rodič
revize
d1cd798b1d

+ 1 - 1
models/ins/train.yaml

@@ -7,7 +7,7 @@ num_classes: 5
 opt: 'adamw'
 batch_size: 2
 epochs: 10
-lr: 0.005
+lr: 0.0005
 momentum: 0.9
 weight_decay: 0.0001
 lr_step_size: 3

+ 2 - 1
models/keypoint/kepointrcnn.py

@@ -29,7 +29,8 @@ class KeypointRCNNModel(nn.Module):
         self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None,num_classes=num_classes,
                                                                               num_keypoints=num_keypoints,
                                                                               progress=False)
-        
+
+
         if transforms is None:
             self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
         # if num_classes != 0:

+ 2 - 0
models/keypoint/keypoint_dataset.py

@@ -125,11 +125,13 @@ class KeypointDataset(BaseDataset):
         # print(f'labels:{target["labels"]}')
         # target["boxes"] = line_boxes(target)
         target["boxes"], keypoints = line_boxes(target)
+        keypoints=keypoints/512
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
         # keypoints= wire_labels["junc_coords"]
         a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
         keypoints = torch.cat((keypoints, a), dim=1)
+
         target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
         # print(f'boxes:{target["boxes"].shape}')
         # 在 __getitem__ 方法中调用此函数

+ 2 - 2
models/keypoint/train.yaml

@@ -6,9 +6,9 @@ dataset_path: /home/admin/tmp/wirenet_1000
 num_classes: 2
 num_keypoints: 2
 opt: 'adamw'
-batch_size: 4
+batch_size: 8
 epochs: 50000
-lr: 0.005
+lr: 0.0002
 momentum: 0.9
 weight_decay: 0.0001
 lr_step_size: 3

+ 14 - 12
models/keypoint/trainer.py

@@ -32,7 +32,7 @@ def log_losses_to_tensorboard(writer, result, step):
     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, scaler=None):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
@@ -46,7 +46,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, wr
         lr_scheduler = torch.optim.lr_scheduler.LinearLR(
             optimizer, start_factor=warmup_factor, total_iters=warmup_iters
         )
-
+    total_train_loss=0
     for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
 
         global_step = epoch * len(data_loader) + batch_idx
@@ -107,12 +107,13 @@ def show_line(img, pred, epoch, writer):
 
     boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                       colors="yellow", width=1)
+
+    # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
+    # plt.show()
+
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
-    print(f'box:{pred["boxes"][:5,:]}')
-    print(f'line:{pred["keypoints"][:5,:]}')
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    # H = pred[1]['wires']
     lines = pred["keypoints"].detach().cpu().numpy()
     scores = pred["keypoints_scores"].detach().cpu().numpy()
     # for i in range(1, len(lines)):
@@ -133,9 +134,12 @@ def show_line(img, pred, epoch, writer):
         for (a, b), s in zip(nlines, nscores):
             if s < t:
                 continue
-            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-            plt.scatter(a[1], a[0], **PLTOPTS)
-            plt.scatter(b[1], b[0], **PLTOPTS)
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            # plt.scatter(a[1], a[0], **PLTOPTS)
+            # plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.plot([a[0], b[0]], [a[1], b[1]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[0], a[1], **PLTOPTS)
+            plt.scatter(b[0], b[1], **PLTOPTS)
         plt.gca().xaxis.set_major_locator(plt.NullLocator())
         plt.gca().yaxis.set_major_locator(plt.NullLocator())
         plt.imshow(im.cpu())
@@ -150,6 +154,7 @@ def show_line(img, pred, epoch, writer):
         writer.add_image("output", img2, epoch)
 
 
+
 def _get_iou_types(model):
     model_without_ddp = model
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
@@ -306,11 +311,8 @@ def train(model, **kwargs):
         os.mkdir(wts_path)
         os.mkdir(img_results_path)
 
-
-    total_train_loss = 0.0
-
     for epoch in range(epochs):
-        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
+        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):