
Debug training of keypointrcnn

RenLiqiang 5 months ago
parent
commit
9b3b098aef

+ 7 - 1
models/dataset_tool.py

@@ -224,11 +224,17 @@ def line_boxes(target):
     lines = lpre
     sline = np.ones(lpre.shape[0])
 
+    keypoints = []
+
     if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b), s) in enumerate(zip(lines, sline)):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no fixed ordering
+
+            keypoints.append([a[0], b[0]])
+            keypoints.append([a[1], b[1]])
+
             if a[1] > b[1]:
                 ymax = a[1] + 1
                 ymin = b[1] - 1
@@ -243,7 +249,7 @@ def line_boxes(target):
                 xmax = b[0] + 1
             boxs.append([ymin, xmin, ymax, xmax])
 
-    return torch.tensor(boxs)
+    return torch.tensor(boxs), torch.tensor(keypoints)
 
 
 def read_polygon_points_wire(lbl_path, shape):
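Note on the change above: line_boxes now hands back the raw line endpoints alongside the boxes. The following is a condensed, standalone sketch of that new contract using synthetic endpoints; it assumes the else-branches elided from the hunk mirror the shown one (min/max with a one-pixel margin) and is not the repository function itself.

import numpy as np
import torch

def line_boxes_sketch(lines):
    # One box per line plus the endpoint pairs, packed as in the hunk above.
    boxs, keypoints = [], []
    for a, b in lines:
        keypoints.append([a[0], b[0]])
        keypoints.append([a[1], b[1]])
        ymin, ymax = min(a[1], b[1]) - 1, max(a[1], b[1]) + 1   # assumed equivalent of the elided branches
        xmin, xmax = min(a[0], b[0]) - 1, max(a[0], b[0]) + 1
        boxs.append([ymin, xmin, ymax, xmax])
    return torch.tensor(boxs), torch.tensor(keypoints)

demo_lines = np.array([[[10.0, 20.0], [40.0, 60.0]]])   # one synthetic line: endpoints a, b
boxes, kpts = line_boxes_sketch(demo_lines)
print(boxes.shape, kpts.shape)                           # torch.Size([1, 4]) torch.Size([2, 2])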

+ 1 - 0
models/ins/maskrcnn.py

@@ -125,6 +125,7 @@ class MaskRCNNModel(nn.Module):
 
             # Create the colored mask
             colored_mask = np.zeros_like(image)
+
             colored_mask[:] = color
             colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
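For context, the surrounding maskrcnn.py code builds a solid-color image and uses the predicted binary mask to keep the color only inside the instance. A minimal sketch of that overlay pattern with dummy data follows; the final addWeighted blend is an assumed follow-up step, not shown in this hunk.

import cv2
import numpy as np

image = np.full((64, 64, 3), 128, dtype=np.uint8)        # dummy gray image
binary_mask = np.zeros((64, 64), dtype=np.uint8)
binary_mask[16:48, 16:48] = 255                          # pretend instance mask

color = (0, 255, 0)
colored_mask = np.zeros_like(image)
colored_mask[:] = color                                  # solid color everywhere
colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)  # keep color only inside the mask

overlay = cv2.addWeighted(image, 1.0, colored_mask, 0.5, 0)  # assumed blend step for visualization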
 

+ 9 - 8
models/keypoint/keypoint_dataset.py

@@ -66,7 +66,7 @@ class KeypointDataset(BaseDataset):
         return len(self.imgs)
 
     def read_target(self, item, lbl_path, shape, extra=None):
-        print(f'shape:{shape}')
+        # print(f'shape:{shape}')
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
             lable_all = json.load(file)
@@ -123,17 +123,18 @@ class KeypointDataset(BaseDataset):
 
         target["labels"] = torch.stack(labels)
         # print(f'labels:{target["labels"]}')
-        target["boxes"] = line_boxes(target)
+        # target["boxes"] = line_boxes(target)
+        target["boxes"], keypoints = line_boxes(target)
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
-        keypoints= wire_labels["junc_coords"]
-        keypoints[:,2]=2
-        # keypoints[:,0]=keypoints[:,0]/shape[0]
-        # keypoints[:, 1] = keypoints[:, 1] / shape[1]
-        target["keypoints"]=keypoints
+        # keypoints= wire_labels["junc_coords"]
+        a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
+        keypoints = torch.cat((keypoints, a), dim=1)
+        target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
+        # print(f'boxes:{target["boxes"].shape}')
         # Call this function from the __getitem__ method
         validate_keypoints(keypoints, shape[0], shape[1])
-        print(f'keypoints:{target["keypoints"].shape}')
+        # print(f'keypoints:{target["keypoints"].shape}')
         return target
 
     def show(self, idx):
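The rewritten read_target takes the endpoints returned by line_boxes, appends a constant visibility flag of 2 (COCO convention: labeled and visible), and reshapes pairs of rows into per-instance keypoints. A minimal sketch of that packing with synthetic coordinates, producing the [num_instances, num_keypoints, 3] layout torchvision's KeypointRCNN expects:

import torch

# Four synthetic endpoint rows -> two line instances with two keypoints each.
keypoints = torch.tensor([[10.0, 20.0],
                          [30.0, 40.0],
                          [50.0, 60.0],
                          [70.0, 80.0]])

vis = torch.full((keypoints.shape[0], 1), 2.0)            # visibility 2 = labeled and visible
keypoints = torch.cat((keypoints, vis), dim=1)            # (2 * num_instances, 3)
keypoints = keypoints.to(torch.float32).view(-1, 2, 3)    # (num_instances, 2 keypoints, [x, y, vis])
print(keypoints.shape)                                    # torch.Size([2, 2, 3])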

+ 5 - 1
models/keypoint/test.py

@@ -58,4 +58,8 @@ keypoints = kpts[idx]
 
 res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
 show(res)
-plt.show()
+plt.show()
+
+
+
+

+ 7 - 6
models/wirenet/head.py

@@ -1103,8 +1103,9 @@ class RoIHeads(nn.Module):
             losses.update(loss_keypoint)
 
         if self.has_wirepoint():
-            # print(f'result:{result}')
+            print(f'wirepoint result:{result}')
             wirepoint_proposals = [p["boxes"] for p in result]
+
             if self.training:
                 # during training, only focus on positive boxes
                 num_images = len(proposals)
@@ -1121,20 +1122,20 @@ class RoIHeads(nn.Module):
                 pos_matched_idxs = None
 
             # print(f'proposals:{len(proposals)}')
+            print(f'wirepoint_proposals:{wirepoint_proposals}')
             wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
 
             # tmp = keypoint_features[0][0]
             # plt.imshow(tmp.detach().numpy())
-            # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
+            print(f'wirepoint_features from roi_pool:{wirepoint_features.shape}')
             outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
-
-
+            print(f'outputs1 from head:{outputs.shape}')
 
             outputs = merge_features(outputs, wirepoint_proposals)
 
 
 
-            wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
+            # wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
 
             print(f'outpust:{outputs.shape}')
 
@@ -1208,7 +1209,7 @@ class RoIHeads(nn.Module):
 #     return merged_features
 
 def merge_features(features, proposals):
-    print(f'features:{features.shape}')
+    print(f'features in merge_features:{features.shape}')
     print(f'proposals:{len(proposals)}')
     def diagnose_input(features, proposals):
         """诊断输入数据"""

+ 14 - 0
models/wirenet/test_mask.py

@@ -0,0 +1,14 @@
+import torch
+from matplotlib import pyplot as plt
+
+img=torch.ones((128,128,3))
+mask=torch.zeros((128,128,3))
+
+mask[0:30,:,:]=1
+
+
+img[mask==1]=0
+
+
+plt.imshow(img)
+plt.show()

+ 3 - 0
models/wirenet/wirepoint_rcnn.py

@@ -430,7 +430,9 @@ class WirepointPredictor(nn.Module):
             Lneg = meta["line_neg_idx"]
 
             n_type = jmap.shape[0]
+            print(f'jmap:{jmap.shape}')
             jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+
             joff = joff.reshape(n_type, 2, -1)
             max_K = self.n_dyn_junc // n_type
             N = len(junc)
@@ -812,6 +814,7 @@ if __name__ == '__main__':
         model.train()
 
         for imgs, targets in data_loader_train:
+            print(f'targets:{targets[0]["wires"]["line_map"].shape}')
             losses = model(move_to_device(imgs, device), move_to_device(targets, device))
             loss = _loss(losses)
             print(loss)
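The new print logs the junction heat-map shape right before non_maximum_suppression(jmap). That function is not shown in this diff; in wireframe/junction detectors it is commonly implemented as a local-maximum filter over the heat map via max pooling. A minimal sketch under that assumption:

import torch
import torch.nn.functional as F

def heatmap_nms(heat, kernel=3):
    # Keep only pixels that are local maxima of the (n_type, H, W) junction heat map.
    pooled = F.max_pool2d(heat, kernel, stride=1, padding=kernel // 2)
    return heat * (heat == pooled).float()

jmap = torch.rand(1, 128, 128)                    # synthetic single-type junction heat map
print(heatmap_nms(jmap).reshape(1, -1).shape)     # flattened per type, as in the hunk above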