|
@@ -3,6 +3,7 @@ from matplotlib import pyplot as plt
|
|
|
|
|
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
from torch import nn
|
|
from torch import nn
|
|
|
|
|
+from torch.cuda import device
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiceLoss(nn.Module):
|
|
class DiceLoss(nn.Module):
|
|
@@ -438,12 +439,211 @@ def heatmaps_to_points(maps, rois,num_points=2):
|
|
|
|
|
|
|
|
return point_preds,point_end_scores
|
|
return point_preds,point_end_scores
|
|
|
|
|
|
|
|
|
|
+# 分4块
|
|
|
|
|
def find_max_heat_point_in_each_part(feature_map, box):
    """
    Find the hottest point inside each of four quadrants of a feature map.

    The quadrants are split at the centre of ``box`` moved 3 pixels to the
    right and 3 pixels up, clamped to the map bounds. Only channel 0 of
    ``feature_map`` is used as the heatmap.

    Args:
        feature_map (torch.Tensor): feature map of shape [C, H, W].
        box (torch.Tensor): bounding box [x_min, y_min, x_max, y_max].

    Returns:
        list: four entries ordered top-left, top-right, bottom-left,
        bottom-right. Each entry is an ``(x, y)`` tuple of Python ints
        (column first, then row), or ``None`` when the quadrant is empty or
        holds no value greater than 0.
    """
    _, height, width = feature_map.shape
    heatmap = feature_map[0]  # [H, W]

    # A zero-sized map has nothing to search; report every quadrant as empty.
    if heatmap.numel() == 0:
        return [None] * 4

    # Box centre (floor division keeps it on the pixel grid).
    center_x = (box[0] + box[2]) // 2
    center_y = (box[1] + box[3]) // 2

    # Shifted split point: 3 px right, 3 px up, clamped inside the map.
    split_x = int(min(max(center_x + 3, 0), width - 1))
    split_y = int(min(max(center_y - 3, 0), height - 1))

    # Row / column index vectors; broadcasting them yields [H, W] masks.
    rows = torch.arange(height, device=heatmap.device).unsqueeze(1)
    cols = torch.arange(width, device=heatmap.device).unsqueeze(0)
    is_upper = rows < split_y
    is_left = cols < split_x

    quadrants = [
        is_upper & is_left,    # top-left
        is_upper & ~is_left,   # top-right
        ~is_upper & is_left,   # bottom-left
        ~is_upper & ~is_left,  # bottom-right
    ]

    results = []
    for quadrant in quadrants:
        # Zero everything outside the quadrant (on a copy of the heatmap).
        quadrant_heat = heatmap.masked_fill(~quadrant, 0)

        # The largest value of the quadrant is necessarily a local maximum,
        # so a plain argmax selects the same point as running a 3x3
        # non-maximum suppression first and taking the strongest peak.
        best_value, best_index = torch.max(quadrant_heat.reshape(-1), dim=0)
        if best_value <= 0:
            results.append(None)
            continue

        y, x = divmod(best_index.item(), width)
        results.append((x, y))

    return results
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def non_maximum_suppression_2d(heatmap, kernel_size=3):
    """
    Flag the positive local maxima of a heatmap.

    A position counts as a peak when its value equals the maximum of the
    ``kernel_size`` x ``kernel_size`` window around it (stride 1) and is
    strictly greater than 0.

    Args:
        heatmap (torch.Tensor): input heatmap; presumably [H, W], a leading
            dimension is added for pooling and removed again afterwards.
        kernel_size (int): side length of the comparison window.

    Returns:
        torch.Tensor: boolean mask that is True at the peak positions.
    """
    half_window = (kernel_size - 1) // 2
    neighbourhood_max = torch.nn.functional.max_pool2d(
        heatmap.unsqueeze(0),
        kernel_size=kernel_size,
        stride=1,
        padding=half_window,
    ).squeeze(0)
    is_local_max = heatmap == neighbourhood_max
    is_positive = heatmap > 0
    return is_local_max & is_positive
|
|
|
|
|
+
|
|
|
|
|
def find_max_heat_point_in_edge_centers(feature_map, box):
    """
    Find the hottest point in each of the four "edge-middle" regions of a box.

    The map is divided into a 3x3 grid whose inner cell is centred on the
    box centre and spans one third of the box width and height. The four
    cells sharing an edge with the inner cell (top, right, bottom, left) are
    searched; the corner cells and the inner cell are ignored. Only channel
    0 of ``feature_map`` is used as the heatmap.

    Args:
        feature_map (torch.Tensor): feature map of shape [C, H, W].
        box (torch.Tensor): bounding box [x_min, y_min, x_max, y_max].

    Returns:
        list: four ``(point_x, point_y)`` tuples ordered top-middle,
        right-middle, bottom-middle, left-middle. Each coordinate is an
        index tensor of shape [1, 1] (column index first, then row index).
        NOTE(review): when a region contains no pixel at all, the returned
        index is whatever ``torch.topk`` picks on a constant map - confirm
        that callers tolerate this.
    """
    _, height, width = feature_map.shape
    device = feature_map.device

    # Box centre.
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2

    # Grid boundaries of the 3x3 layout, derived from the box width/height.
    box_width = box[2] - box[0]
    box_height = box[3] - box[1]
    x_left = center_x - box_width / 6
    x_right = center_x + box_width / 6
    y_top = center_y - box_height / 6
    y_bottom = center_y + box_height / 6

    # Row / column index vectors; broadcasting them yields [H, W] masks.
    rows = torch.arange(height, device=device).unsqueeze(1)
    cols = torch.arange(width, device=device).unsqueeze(0)
    in_middle_cols = (cols >= x_left) & (cols < x_right)
    in_middle_rows = (rows >= y_top) & (rows < y_bottom)

    # The four "edge-middle" regions, in the order they are returned.
    regions = [
        in_middle_cols & (rows < y_top),      # top-middle
        (cols >= x_right) & in_middle_rows,   # right-middle
        in_middle_cols & (rows >= y_bottom),  # bottom-middle
        (cols < x_left) & in_middle_rows,     # left-middle
    ]

    # Use the first channel as the heatmap.
    heatmap = feature_map[0]  # [H, W]

    # Pixels outside a region are pushed to the lowest representable value
    # rather than 0, so they can never beat a region whose values are all
    # negative.
    if heatmap.is_floating_point():
        lowest = float("-inf")
    else:
        lowest = torch.iinfo(heatmap.dtype).min

    results = []
    for region in regions:
        flat_region_heat = heatmap.masked_fill(~region, lowest).reshape(1, -1)
        _, flat_index = torch.topk(flat_region_heat, k=1)

        # Convert the flat index back to (column, row).
        point_x = flat_index % width
        point_y = torch.div(flat_index - point_x, width, rounding_mode="floor")

        results.append((point_x, point_y))

    return results  # [(x_top, y_top), (x_right, y_right), (x_bottom, y_bottom), (x_left, y_left)]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def heatmaps_to_circle_points(maps, rois,num_points=2):
|
|
def heatmaps_to_circle_points(maps, rois,num_points=2):
|
|
|
|
|
|
|
|
|
|
|
|
|
point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
|
|
point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
|
|
|
point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
|
|
point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
|
|
|
|
|
|
|
|
|
|
+ print(f'rois in heatmaps_to_circle_points:{type(rois), rois.shape}') # <class 'torch.Tensor'>
|
|
|
|
|
+
|
|
|
print(f'heatmaps_to_lines:{maps.shape}')
|
|
print(f'heatmaps_to_lines:{maps.shape}')
|
|
|
point_maps=maps[:,0]
|
|
point_maps=maps[:,0]
|
|
|
print(f'point_map:{point_maps.shape}')
|
|
print(f'point_map:{point_maps.shape}')
|
|
@@ -452,18 +652,24 @@ def heatmaps_to_circle_points(maps, rois,num_points=2):
|
|
|
point_roi_map = point_maps[i].unsqueeze(0)
|
|
point_roi_map = point_maps[i].unsqueeze(0)
|
|
|
print(f'point_roi_map:{point_roi_map.shape}')
|
|
print(f'point_roi_map:{point_roi_map.shape}')
|
|
|
# roi_map_probs = scores_to_probs(roi_map.copy())
|
|
# roi_map_probs = scores_to_probs(roi_map.copy())
|
|
|
- w = point_roi_map.shape[2]
|
|
|
|
|
- flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
|
|
|
|
|
- point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
|
|
|
|
|
- print(f'point index:{point_index}')
|
|
|
|
|
- # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
|
|
|
|
|
-
|
|
|
|
|
- point_x =point_index % w
|
|
|
|
|
- point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
- point_preds[i, :,0] = point_x
|
|
|
|
|
- point_preds[i, :,1] = point_y
|
|
|
|
|
|
|
+ # w = point_roi_map.shape[2]
|
|
|
|
|
+ # flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
|
|
|
|
|
+ # print(f'non_maximum_suppression :{non_maximum_suppression(point_roi_map).shape}')
|
|
|
|
|
+ # point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
|
|
|
|
|
+ # print(f'point index:{point_index}')
|
|
|
|
|
+ # point_x =point_index % w
|
|
|
|
|
+ # point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
|
|
|
|
|
+ # print(f'point_x:{point_x}, point_y:{point_y}')
|
|
|
|
|
+ # point_preds[i, :, 0] = point_x
|
|
|
|
|
+ # point_preds[i, :, 1] = point_y
|
|
|
|
|
+ roi1=rois[i]
|
|
|
|
|
+ result_points = find_max_heat_point_in_edge_centers(non_maximum_suppression(point_roi_map), roi1)
|
|
|
|
|
+
|
|
|
|
|
+ point_preds[i, :]=torch.tensor(result_points)
|
|
|
|
|
+
|
|
|
|
|
+ point_x = [point[0] for point in result_points]
|
|
|
|
|
+ point_y = [point[1] for point in result_points]
|
|
|
|
|
|
|
|
point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
|
|
point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
|
|
|
|
|
|