|
|
@@ -1323,7 +1323,7 @@ class RoIHeads(nn.Module):
|
|
|
if matched_idxs is None:
|
|
|
raise ValueError("if in trainning, matched_idxs should not be None")
|
|
|
for img_id in range(num_images):
|
|
|
- circle_pos = torch.where(labels[img_id] == 1)[0]
|
|
|
+ circle_pos = torch.where(labels[img_id] == 4)[0]
|
|
|
circle_proposals.append(proposals[img_id][circle_pos])
|
|
|
circle_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
|
|
|
else:
|
|
|
@@ -1338,7 +1338,7 @@ class RoIHeads(nn.Module):
|
|
|
raise ValueError("if in trainning, matched_idxs should not be None")
|
|
|
|
|
|
for img_id in range(num_images):
|
|
|
- circle_pos = torch.where(labels[img_id] == 1)[0]
|
|
|
+ circle_pos = torch.where(labels[img_id] == 4)[0]
|
|
|
circle_proposals.append(proposals[img_id][circle_pos])
|
|
|
circle_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
|
|
|
|
|
|
@@ -1354,7 +1354,7 @@ class RoIHeads(nn.Module):
|
|
|
if targets is None or circle_pos_matched_idxs is None:
|
|
|
raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
|
|
|
|
|
|
- gt_circles = [t["circle"] for t in targets if "circle" in t]
|
|
|
+ gt_circles = [t["circles"] for t in targets if "circles" in t]
|
|
|
|
|
|
print(f'gt_circle:{gt_circles[0].shape}')
|
|
|
h, w = targets[0]["img_size"]
|
|
|
@@ -1374,13 +1374,13 @@ class RoIHeads(nn.Module):
|
|
|
print(f'loss_circle is None111')
|
|
|
loss_circle = torch.tensor(0.0, device=device)
|
|
|
|
|
|
- loss_point = {"loss_circle": loss_circle}
|
|
|
+ loss_circle = {"loss_circle": loss_circle}
|
|
|
|
|
|
else:
|
|
|
if targets is not None:
|
|
|
h, w = targets[0]["img_size"]
|
|
|
img_size = h
|
|
|
- gt_circles = [t["circle"] for t in targets if "circle" in t]
|
|
|
+ gt_circles = [t["circles"] for t in targets if "circles" in t]
|
|
|
|
|
|
gt_circles_tensor = torch.zeros(0, 0)
|
|
|
if len(gt_circles) > 0:
|
|
|
@@ -1390,7 +1390,7 @@ class RoIHeads(nn.Module):
|
|
|
if gt_circles_tensor.shape[0] > 0:
|
|
|
print(f'start to compute circle_loss')
|
|
|
|
|
|
- loss_circle = compute_circle_loss(feature_logits, point_proposals, gt_circles,
|
|
|
+ loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
|
|
|
circle_pos_matched_idxs)
|
|
|
|
|
|
if loss_circle is None:
|
|
|
@@ -1599,7 +1599,7 @@ class RoIHeads(nn.Module):
|
|
|
return roi_features
|
|
|
|
|
|
def circle_forward1(self, features, image_shapes, proposals):
|
|
|
- print(f'point_proposals:{len(proposals)}')
|
|
|
+ print(f'circle_proposals:{len(proposals)}')
|
|
|
# cs_features= features['0']
|
|
|
# print(f'features-0:{features['0'].shape}')
|
|
|
# cs_features = self.channel_compress(features['0'])
|
|
|
@@ -1614,7 +1614,7 @@ class RoIHeads(nn.Module):
|
|
|
# print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
|
|
|
|
|
|
feature_logits = self.circle_predictor(cs_features)
|
|
|
- print(f'feature_logits from line_head:{feature_logits.shape}')
|
|
|
+ print(f'feature_logits from circle_head:{feature_logits.shape}')
|
|
|
|
|
|
roi_features = features_align(feature_logits, proposals, image_shapes)
|
|
|
if roi_features is not None:
|