|
|
@@ -1050,15 +1050,16 @@ class RoIHeads(nn.Module):
|
|
|
}
|
|
|
)
|
|
|
|
|
|
- features_lcnn = features['0']
|
|
|
+ line_features = features['0']
|
|
|
if self.has_line():
|
|
|
# print('has line_head')
|
|
|
# outputs = self.line_head(features_lcnn)
|
|
|
- outputs = features_lcnn[:, 0:5, :, :]
|
|
|
+ # outputs = line_features[:, 0:5, :, :]
|
|
|
+
|
|
|
|
|
|
loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
|
|
|
x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
|
|
|
- inputs=outputs, features=features_lcnn, targets=targets)
|
|
|
+ inputs=line_features, features=line_features, targets=targets)
|
|
|
|
|
|
# # line_loss(multitasklearner)
|
|
|
# if self.training:
|
|
|
@@ -1071,12 +1072,12 @@ class RoIHeads(nn.Module):
|
|
|
# loss_weight, mode_train=False)
|
|
|
|
|
|
if self.training:
|
|
|
- rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
|
|
|
+ rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
|
|
|
loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
|
|
|
else:
|
|
|
|
|
|
pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
|
|
|
- result.append(outputs)
|
|
|
+ result.append(line_features)
|
|
|
result.append(pred)
|
|
|
loss_wirepoint = {}
|
|
|
losses.update(loss_wirepoint)
|