Преглед изворни кода

修改删除根据点的种类进行match

RenLiqiang пре 7 месеци
родитељ
комит
7ccb211c2a
3 измењених фајлова са 18 додато и 51 уклоњено
  1. 15 50
      models/line_detect/line_predictor.py
  2. 1 0
      models/line_detect/roi_heads.py
  3. 2 1
      models/line_detect/train.yaml

+ 15 - 50
models/line_detect/line_predictor.py

@@ -242,30 +242,7 @@ class LineRCNNPredictor(nn.Module):
                 jcs.append(jc)
                 ps.append(p)
             fs.append(feat)
-            #
-            # p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            # p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            # px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            # px0 = px.floor().clamp(min=0, max=127)
-            # py0 = py.floor().clamp(min=0, max=127)
-            # px1 = (px0 + 1).clamp(min=0, max=127)
-            # py1 = (py0 + 1).clamp(min=0, max=127)
-            # px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-            #
-            # # xp: [N_LINE, N_CHANNEL, N_POINT]
-            # xp = (
-            #     (
-            #             x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-            #             + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-            #             + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-            #             + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-            #     )
-            #     .reshape(n_channel, -1, self.n_pts0)
-            #     .permute(1, 0, 2)
-            # )
-            # xp = self.pooling(xp)
-            # # print(f'xp.shape:{xp.shape}')
-            # xs.append(xp)
+
             idx.append(idx[-1] + feat.shape[0])
             # print(f'idx__:{idx}')
 
@@ -275,10 +252,7 @@ class LineRCNNPredictor(nn.Module):
         print(f'f:{f.shape}')
         # x = x.reshape(-1, self.n_pts1 * self.dim_loi)
 
-        # print("Weight dtype:", self.fc2.weight.dtype)
-        # x = torch.cat([x, f], 1)
-        # print(f'x3:{x.shape}')
-        # print("Input dtype:", x.dtype)
+
         f= f.to(dtype=torch.float32)
         # x = x.to(dtype=torch.float32)
         # print("Input dtype1:", x.dtype)
@@ -287,10 +261,7 @@ class LineRCNNPredictor(nn.Module):
         # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
         return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
 
-        # if mode != "training":
-        # self.inference(x, idx, jcs, n_batch, ps)
 
-        # return result
 
     def sample_lines(self, meta, jmap, joff,lmap):
         device = jmap.device
@@ -312,9 +283,10 @@ class LineRCNNPredictor(nn.Module):
                 print(f'jmap min:{torch.min(jmap[0])}')
                 print(f'jmap num:{(jmap > self.eval_junc_thres).float().sum().item()}')
                 print(f'jmap:{jmap}')
-                print(f'K:{K}')
+                print(f'predict K:{K}')
             else:
                 K = min(int(N * 2 + 2), max_K)
+                print(f'train  K:{K}')
             if K < 2:
                 K = 2
             device = jmap.device
@@ -331,13 +303,19 @@ class LineRCNNPredictor(nn.Module):
 
             # dist: [N_TYPE, K, N]
             dist = torch.sum((xy_ - junc) ** 2, -1)
+            print(f'dist:{dist}')
+
             cost, match = torch.min(dist, -1)
+            print(f'match:{match},cost:{cost}')
 
             # xy: [N_TYPE * K, 2]
             # match: [N_TYPE, K]
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
+            # for t in range(n_type):
+            #     match[t, jtyp[match[t]] != t] = N
+
             match[cost > 1.5 * 1.5] = N
+
+            print(f'match__ : {match}')
             match = match.flatten()
 
             _ = torch.arange(n_type * K, device=device)
@@ -379,22 +357,6 @@ class LineRCNNPredictor(nn.Module):
 
             u2v = xyu - xyv
             u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-
-
-            # print(f'xp.shape:{xp.shape}')
-
-            feat = torch.cat(
-                [
-                    xyu / 128 * self.use_cood,
-                    xyv / 128 * self.use_cood,
-                    u2v * self.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            print(f'feat  shape:{feat.shape}')
-
             # lmap = gaussian_filter(lmap, sigma=1)
             # lmap = torch.from_numpy(gaussian_filter(lmap.cpu().numpy(), sigma=1)).to('cuda:0')
 
@@ -425,11 +387,14 @@ class LineRCNNPredictor(nn.Module):
                 .permute(1, 0, 2)
             )
             xp = self.pooling(xp).squeeze(1)
+            if not self.training:
+                print(f'predict  xp values:{xp}')
             print(f'xp shape:{xp.shape}')
 
 
             xy = xy.reshape(n_type, K, 2)
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            print(f'line.shape :{line.shape}')
             return line, label.float(), xp, jcs
 
 

+ 1 - 0
models/line_detect/roi_heads.py

@@ -1055,6 +1055,7 @@ class RoIHeads(nn.Module):
             # print('has line_head')
             # outputs = self.line_head(features_lcnn)
             outputs = features_lcnn[:, 0:5, :, :]
+
             loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
             x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
                 inputs=outputs, features=features_lcnn, targets=targets)

+ 2 - 1
models/line_detect/train.yaml

@@ -1,6 +1,7 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/4_23jiagonggongjian
+#  datadir: I:/datasets/4_23jiagonggongjian
+  datadir: I:/datasets/0322_suanzaisheng
 #  datadir: I:\datasets\wirenet_1000
   resume_from:
   num_workers: 8