|
|
@@ -1372,9 +1372,12 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
print(f'features from backbone:{features['0'].shape}')
|
|
|
feature_logits = self.ins_forward1(features, image_shapes, ins_proposals)
|
|
|
+ arc_equation = self.arc_equation_head(feature_logits) # [proposal和,7]
|
|
|
|
|
|
loss_ins = None
|
|
|
loss_ins_extra=None
|
|
|
+ loss_arc_equation = None
|
|
|
+ loss_arc_ends = None
|
|
|
|
|
|
if self.training:
|
|
|
|
|
|
@@ -1384,6 +1387,10 @@ class RoIHeads(nn.Module):
|
|
|
gt_inses = [t["circle_masks"] for t in targets if "circle_masks" in t]
|
|
|
gt_labels = [t["labels"] for t in targets]
|
|
|
|
|
|
+ gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
|
|
|
+ gt_mask_ends = [t["mask_ends"] for t in targets if "mask_ends" in t]
|
|
|
+ gt_mask_params = [t["mask_params"] for t in targets if "mask_params" in t]
|
|
|
+
|
|
|
print(f'gt_ins:{gt_inses[0].shape}')
|
|
|
h, w = targets[0]["img_size"]
|
|
|
img_size = h
|
|
|
@@ -1396,9 +1403,21 @@ class RoIHeads(nn.Module):
|
|
|
if gt_ins_tensor.shape[0] > 0:
|
|
|
print(f'start to compute circle_loss')
|
|
|
|
|
|
- loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses, ins_pos_matched_idxs)
|
|
|
-
|
|
|
- # loss_ins_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
|
+ loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,ins_pos_matched_idxs)
|
|
|
+ total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,
|
|
|
+ ins_proposals,
|
|
|
+ gt_mask_ends,
|
|
|
+ gt_mask_params,
|
|
|
+ ins_pos_matched_idxs,
|
|
|
+ labels)
|
|
|
+ loss_arc_ends = loss_arc_ends
|
|
|
+ if loss_arc_equation is None:
|
|
|
+ print(f'loss_arc_equation is None')
|
|
|
+ loss_arc_equation = torch.tensor(0.0, device=device)
|
|
|
+
|
|
|
+ if loss_arc_ends is None:
|
|
|
+ print(f'loss_arc_ends is None')
|
|
|
+ loss_arc_ends = torch.tensor(0.0, device=device)
|
|
|
|
|
|
if loss_ins is None:
|
|
|
print(f'loss_ins is None111')
|
|
|
@@ -1408,7 +1427,7 @@ class RoIHeads(nn.Module):
|
|
|
print(f'loss_ins_extra is None111')
|
|
|
loss_ins_extra = torch.tensor(0.0, device=device)
|
|
|
|
|
|
- loss_ins = {"loss_ins": loss_ins}
|
|
|
+ loss_ins = {"loss_ins": loss_ins,"loss_arc_equation": loss_arc_equation,"loss_arc_ends": loss_arc_ends}
|
|
|
loss_ins_extra = {"loss_ins_extra": loss_ins_extra}
|
|
|
|
|
|
else:
|
|
|
@@ -1715,3 +1734,140 @@ class RoIHeads(nn.Module):
|
|
|
if roi_features is not None:
|
|
|
print(f'roi_features from align:{roi_features.shape}')
|
|
|
return roi_features
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+import torch
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
+
|
|
|
+def compute_arc_equation_loss(arc_equation, proposals, gt_mask_ends, gt_mask_params, arc_pos_matched_idxs,
|
|
|
+ gt_labels_all):
|
|
|
+ """
|
|
|
+ Compute loss between predicted arc equations and ground truth.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ arc_equation: list of length B, each Tensor (N_i, 7)
|
|
|
+ gt_mask_ends: GT arc end masks (for angle calculation)
|
|
|
+ gt_mask_params: list of length B, each numpy array (num_gt, 5)
|
|
|
+ arc_pos_matched_idxs: list of length B, each Tensor of indices matching predictions to GT
|
|
|
+ gt_labels_all: list of length B, GT labels
|
|
|
+ """
|
|
|
+ len_proposals = len(proposals) # batch
|
|
|
+ device = arc_equation[0].device
|
|
|
+ print(
|
|
|
+ f'compute_arc_equation_loss line_logits.shape:{arc_equation.shape},len_proposals:{len_proposals},line_matched_idxs:{arc_pos_matched_idxs}')
|
|
|
+ print(f'gt_mask_ends:{gt_mask_ends}, gt_mask_params:{gt_mask_params}')
|
|
|
+
|
|
|
+ gt_angles = []
|
|
|
+ # for gt_mask_end,gt_mask_param in zip(gt_mask_ends, gt_mask_params):
|
|
|
+ # print(f'gt_mask_end:{gt_mask_end}, gt_mask_param:{gt_mask_param}')
|
|
|
+ # gt_angles.append(compute_arc_angles(gt_mask_end,gt_mask_param))
|
|
|
+ for i in range(len(gt_mask_ends)):
|
|
|
+ print(f'gt_mask_end:{gt_mask_ends[i]}, gt_mask_param:{gt_mask_params[i]}')
|
|
|
+ gt_angles.append(compute_arc_angles(gt_mask_ends[i], gt_mask_params[i]))
|
|
|
+
|
|
|
+ print(f'gt_angles:{gt_angles}')
|
|
|
+ print(f'gt_mask_params:{gt_mask_params}')
|
|
|
+ print(f'gt_labels_all:{gt_labels_all}')
|
|
|
+ print(f'arc_pos_matched_idxs:{arc_pos_matched_idxs}')
|
|
|
+
|
|
|
+ gt_sel_params = []
|
|
|
+ gt_sel_angles = []
|
|
|
+ for proposals_per_image, gt_angle, gt_params, gt_label, midx in zip(proposals, gt_angles, gt_mask_params,
|
|
|
+ gt_labels_all, arc_pos_matched_idxs):
|
|
|
+ print(f'line_proposals_per_image:{proposals_per_image.shape}')
|
|
|
+ # gt_angle = torch.tensor(gt_angle)
|
|
|
+ gt_angle = torch.stack(gt_angle, dim=0)
|
|
|
+ gt_params = torch.tensor(gt_params)
|
|
|
+ if gt_angle.shape[0] > 0:
|
|
|
+ # positions = (gt_label == 3).nonzero()[0].item()
|
|
|
+
|
|
|
+ po = gt_angle[midx.cpu()]
|
|
|
+ pa = gt_params[midx.cpu()]
|
|
|
+ print(f'po:{po},pa:{pa}')
|
|
|
+
|
|
|
+ gt_sel_angles.append(po)
|
|
|
+ gt_sel_params.append(pa)
|
|
|
+
|
|
|
+ print(f'gt_sel_angles:{gt_sel_angles}')
|
|
|
+ print(f'gt_sel_params:{gt_sel_params}')
|
|
|
+
|
|
|
+ gt_sel_angles = torch.cat(gt_sel_angles, dim=0)
|
|
|
+ gt_sel_params = torch.cat(gt_sel_params, dim=0)
|
|
|
+
|
|
|
+ print(f'gt_sel_angles:{gt_sel_angles}')
|
|
|
+ print(f'gt_sel_params:{gt_sel_params}')
|
|
|
+
|
|
|
+ pred_angles = arc_equation[:, 5:7]
|
|
|
+ pred_params = arc_equation[:, :5]
|
|
|
+
|
|
|
+ angle_loss = F.mse_loss(pred_angles, gt_sel_angles)
|
|
|
+ param_loss = F.mse_loss(pred_params.cpu(), gt_sel_params) / 10000
|
|
|
+ print(f'angle_loss:{angle_loss}, param_loss:{param_loss}')
|
|
|
+
|
|
|
+ count = sum(len(sublist) for sublist in proposals)
|
|
|
+
|
|
|
+ total_loss = (param_loss + angle_loss) / count if count > 0 else torch.tensor(0.0)
|
|
|
+
|
|
|
+ # 确保 dtype 和 device
|
|
|
+ total_loss = total_loss.float().to(device)
|
|
|
+ angle_loss = angle_loss.float().to(device)
|
|
|
+ param_loss = param_loss.float().to(device)
|
|
|
+
|
|
|
+ # if count > 0:
|
|
|
+ # total_loss = (param_loss + angle_loss) / count
|
|
|
+ # total_loss = torch.tensor(total_loss, dtype=torch.float32, device=device)
|
|
|
+ # else:
|
|
|
+ # total_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
|
|
|
+
|
|
|
+ print(f'total_loss, param_loss, angle_loss:{total_loss, param_loss, angle_loss}')
|
|
|
+
|
|
|
+ return total_loss, param_loss, angle_loss
|
|
|
+
|
|
|
+
|
|
|
+def compute_arc_angles(gt_mask_ends, gt_mask_params):
|
|
|
+ """
|
|
|
+ 给定椭圆上的一个点,计算其对应的参数角 phi(弧度)。
|
|
|
+
|
|
|
+ Parameters:
|
|
|
+ point: tuple or array-like, (x, y)
|
|
|
+ ellipse_param: tuple or array-like, (xc, yc, a, b, theta)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ phi: float, in [0, 2*pi)
|
|
|
+ """
|
|
|
+ results = []
|
|
|
+ gt_mask_params_tensor = torch.tensor(gt_mask_params,
|
|
|
+ dtype=gt_mask_ends.dtype,
|
|
|
+ device=gt_mask_ends.device)
|
|
|
+ for ends_img, params_img in zip(gt_mask_ends, gt_mask_params_tensor):
|
|
|
+ # print(f'params_img:{params_img}')
|
|
|
+ if torch.norm(params_img) < 1e-6: # L2 norm near zero
|
|
|
+ results.append(torch.zeros(2, device=params_img.device, dtype=params_img.dtype))
|
|
|
+ continue
|
|
|
+ x, y = ends_img
|
|
|
+ xc, yc, a, b, theta = params_img
|
|
|
+
|
|
|
+ # 1. 平移到中心
|
|
|
+ dx = x - xc
|
|
|
+ dy = y - yc
|
|
|
+
|
|
|
+ # 2. 逆旋转(旋转 -theta)
|
|
|
+ cos_t = torch.cos(theta)
|
|
|
+ sin_t = torch.sin(theta)
|
|
|
+ X = dx * cos_t + dy * sin_t
|
|
|
+ Y = -dx * sin_t + dy * cos_t
|
|
|
+
|
|
|
+ # 3. 归一化到单位圆(除以 a, b)
|
|
|
+ cos_phi = X / a
|
|
|
+ sin_phi = Y / b
|
|
|
+
|
|
|
+ # 4. 用 atan2 求角度(自动处理象限)
|
|
|
+ phi = torch.atan2(sin_phi, cos_phi)
|
|
|
+
|
|
|
+ # 5. 转换到 [0, 2π)
|
|
|
+ phi = torch.where(phi < 0, phi + 2 * torch.pi, phi)
|
|
|
+
|
|
|
+ results.append(phi)
|
|
|
+ return results
|