Przeglądaj źródła

添加记录验证机损失功能

RenLiqiang 7 miesięcy temu
rodzic
commit
3b8e1a3d58

+ 7 - 10
models/base/base_detection_net.py

@@ -38,11 +38,15 @@ class BaseDetectionNet(BaseModel):
         self._has_warned = False
 
     @torch.jit.unused
-    def eager_outputs(self, losses, detections):
+    def eager_outputs(self, losses, detections,targets=None):
         # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
         if self.training:
             return  losses
-        return detections
+        else:
+            if targets is not None:
+                return detections,losses
+            else:
+                return detections
 
 
 
@@ -62,13 +66,6 @@ class BaseDetectionNet(BaseModel):
                 like `scores`, `labels` and `mask` (for Mask R-CNN models).
 
         """
-        if targets is not None:
-            self.training = True
-            # print(f'targets is not None')
-
-        else:
-            self.training = False
-            # print(f'targets is None')
 
         if self.training:
             if targets is None:
@@ -134,4 +131,4 @@ class BaseDetectionNet(BaseModel):
                 self._has_warned = True
             return losses, detections
         else:
-            return self.eager_outputs(losses, detections)
+            return self.eager_outputs(losses, detections,targets)

+ 14 - 9
models/line_detect/111.py

@@ -128,17 +128,17 @@ class Trainer(BaseTrainer):
             print(f"No saved model found at {save_path}")
         return model, optimizer
 
-    def writer_loss(self, writer, losses, epoch):
+    def writer_loss(self, writer, losses, epoch,mode='train'):
         try:
             for key, value in losses.items():
                 if key == 'loss_wirepoint':
                     for subdict in losses['loss_wirepoint']['losses']:
                         for subkey, subvalue in subdict.items():
-                            writer.add_scalar(f'loss/{subkey}',
+                            writer.add_scalar(f'{mode}/loss/{subkey}',
                                               subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                               epoch)
                 elif isinstance(value, torch.Tensor):
-                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+                    writer.add_scalar(f'{mode}/loss/{key}', value.item(), epoch)
         except Exception as e:
             print(f"TensorBoard logging error: {e}")
 
@@ -184,7 +184,8 @@ class Trainer(BaseTrainer):
         last_model_path = os.path.join(wts_path, 'last.pth')
         best_train_model_path = os.path.join(wts_path, 'best_train.pth')
         best_val_model_path = os.path.join(wts_path, 'best_val.pth')
-        global_step = 0
+        global_train_step = 0
+        global_val_step = 0
 
         for epoch in range(kwargs['optim']['max_epoch']):
             print(f"epoch:{epoch}")
@@ -199,8 +200,8 @@ class Trainer(BaseTrainer):
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-                self.writer_loss(writer, losses, global_step)
-                global_step += 1
+                self.writer_loss(writer, losses, global_train_step)
+                global_train_step += 1
 
 
 
@@ -208,8 +209,9 @@ class Trainer(BaseTrainer):
             print(f'model.eval!!')
             # ========== Validation ==========
             total_val_loss = 0.0
+            batch_idx=0
             with torch.no_grad():
-                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                for imgs, targets in data_loader_val:
                     t_start = time.time()
                     print(f'start to predict:{t_start}')
 
@@ -217,7 +219,9 @@ class Trainer(BaseTrainer):
                     targets = move_to_device(targets, device)
                     print(f'targets:{targets}')
 
-                    losses = model(imgs, targets)
+                    _,losses = model(imgs, targets)
+                    self.writer_loss(writer, losses, global_val_step,mode='val')
+                    global_val_step+=1
                     print(f'val losses:{losses}')
                     loss = _loss(losses)
                     total_val_loss += loss.item()
@@ -229,7 +233,8 @@ class Trainer(BaseTrainer):
                     print(f'predict used:{t_end - t_start}')
                     if batch_idx == 0:
                         show_line(imgs[0], pred, epoch, writer)
-                    break
+                        batch_idx+=1
+
 
             avg_val_loss = total_val_loss / len(data_loader_val)
             # print(f'avg_val_loss:{avg_val_loss}')

+ 36 - 28
models/line_detect/roi_heads.py

@@ -1000,13 +1000,13 @@ class RoIHeads(nn.Module):
             image_shapes (List[Tuple[H, W]])
             targets (List[Dict])
         """
-        if targets is not None:
-            self.training = True
-            # print(f'targets is not None')
-
-        else:
-            self.training = False
-            # print(f'targets is None')
+        # if targets is not None:
+        #     self.training = True
+        #     # print(f'targets is not None')
+        #
+        # else:
+        #     self.training = False
+        #     # print(f'targets is None')
 
         if targets is not None:
             for t in targets:
@@ -1023,9 +1023,12 @@ class RoIHeads(nn.Module):
         if self.training:
             proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
-            labels = None
-            regression_targets = None
-            matched_idxs = None
+            if targets is not None:
+                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+            else:
+                labels = None
+                regression_targets = None
+                matched_idxs = None
 
 
         box_features = self.box_roi_pool(features, proposals, image_shapes)
@@ -1042,17 +1045,20 @@ class RoIHeads(nn.Module):
             loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
         else:
-
-            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
-            num_images = len(boxes)
-            for i in range(num_images):
-                result.append(
-                    {
-                        "boxes": boxes[i],
-                        "labels": labels[i],
-                        "scores": scores[i],
-                    }
-                )
+            if targets is not None:
+                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+            else:
+                boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+                num_images = len(boxes)
+                for i in range(num_images):
+                    result.append(
+                        {
+                            "boxes": boxes[i],
+                            "labels": labels[i],
+                            "scores": scores[i],
+                        }
+                    )
 
         line_features = features['0']
         if self.has_line():
@@ -1082,13 +1088,15 @@ class RoIHeads(nn.Module):
 
             else:
                 # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
-
-                print(f'model inference!!!')
-                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
-                result.append(line_features)
-                result.append(pred)
-
-                loss_wirepoint = {}
+                if targets is not None:
+                    rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
+                    loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+                else:
+                    print(f'model inference!!!')
+                    pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                    result.append(line_features)
+                    result.append(pred)
+                    loss_wirepoint = {}
 
             losses.update(loss_wirepoint)
         else: