|
|
@@ -1000,13 +1000,13 @@ class RoIHeads(nn.Module):
|
|
|
image_shapes (List[Tuple[H, W]])
|
|
|
targets (List[Dict])
|
|
|
"""
|
|
|
- if targets is not None:
|
|
|
- self.training = True
|
|
|
- # print(f'targets is not None')
|
|
|
-
|
|
|
- else:
|
|
|
- self.training = False
|
|
|
- # print(f'targets is None')
|
|
|
+ # if targets is not None:
|
|
|
+ # self.training = True
|
|
|
+ # # print(f'targets is not None')
|
|
|
+ #
|
|
|
+ # else:
|
|
|
+ # self.training = False
|
|
|
+ # # print(f'targets is None')
|
|
|
|
|
|
if targets is not None:
|
|
|
for t in targets:
|
|
|
@@ -1023,9 +1023,12 @@ class RoIHeads(nn.Module):
|
|
|
if self.training:
|
|
|
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
|
|
|
else:
|
|
|
- labels = None
|
|
|
- regression_targets = None
|
|
|
- matched_idxs = None
|
|
|
+ if targets is not None:
|
|
|
+ proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
|
|
|
+ else:
|
|
|
+ labels = None
|
|
|
+ regression_targets = None
|
|
|
+ matched_idxs = None
|
|
|
|
|
|
|
|
|
box_features = self.box_roi_pool(features, proposals, image_shapes)
|
|
|
@@ -1042,17 +1045,20 @@ class RoIHeads(nn.Module):
|
|
|
loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
|
|
|
losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
|
|
|
else:
|
|
|
-
|
|
|
- boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
|
|
|
- num_images = len(boxes)
|
|
|
- for i in range(num_images):
|
|
|
- result.append(
|
|
|
- {
|
|
|
- "boxes": boxes[i],
|
|
|
- "labels": labels[i],
|
|
|
- "scores": scores[i],
|
|
|
- }
|
|
|
- )
|
|
|
+ if targets is not None:
|
|
|
+ loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
|
|
|
+ losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
|
|
|
+ else:
|
|
|
+ boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
|
|
|
+ num_images = len(boxes)
|
|
|
+ for i in range(num_images):
|
|
|
+ result.append(
|
|
|
+ {
|
|
|
+ "boxes": boxes[i],
|
|
|
+ "labels": labels[i],
|
|
|
+ "scores": scores[i],
|
|
|
+ }
|
|
|
+ )
|
|
|
|
|
|
line_features = features['0']
|
|
|
if self.has_line():
|
|
|
@@ -1082,13 +1088,15 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
else:
|
|
|
# rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
|
|
|
-
|
|
|
- print(f'model inference!!!')
|
|
|
- pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
|
|
|
- result.append(line_features)
|
|
|
- result.append(pred)
|
|
|
-
|
|
|
- loss_wirepoint = {}
|
|
|
+ if targets is not None:
|
|
|
+ rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
|
|
|
+ loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
|
|
|
+ else:
|
|
|
+ print(f'model inference!!!')
|
|
|
+ pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
|
|
|
+ result.append(line_features)
|
|
|
+ result.append(pred)
|
|
|
+ loss_wirepoint = {}
|
|
|
|
|
|
losses.update(loss_wirepoint)
|
|
|
else:
|