Pārlūkot izejas kodu

修改线段采样函数与前向传播采样方法不一致导致推理线段效果不好的bug

RenLiqiang 8 mēneši atpakaļ
vecāks
revīzija
d986c2687c
2 mainītis faili ar 78 papildinājumiem un 39 dzēšanām
  1. 77 38
      models/line_detect/line_predictor.py
  2. 1 1
      models/line_detect/roi_heads.py

+ 77 - 38
models/line_detect/line_predictor.py

@@ -132,7 +132,7 @@ class LineRCNNPredictor(nn.Module):
         else:
             self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
             self.fc2 = nn.Sequential(
-                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.Linear(self.dim_loi *  FEATURE_DIM, self.dim_fc),
                 nn.ReLU(inplace=True),
                 nn.Linear(self.dim_fc, self.dim_fc),
                 nn.ReLU(inplace=True),
@@ -213,10 +213,11 @@ class LineRCNNPredictor(nn.Module):
         print(f'features shape:{features.shape}')
         print(f'inputs shape :{inputs.shape}')
         # x = self.fc1(features)
-        x = inputs[:,2:3,:,:].sigmoid()
-        print(f'x:{x.shape}')
+        lmap = inputs[:,2:3,:,:].sigmoid()
+        x=lmap
+        print(f'x:{lmap.shape}')
 
-        n_batch, n_channel, row, col = x.shape
+        n_batch, n_channel, row, col = lmap.shape
         # n_batch, n_channel, row, col = x.shape
 
 
@@ -226,9 +227,9 @@ class LineRCNNPredictor(nn.Module):
 
         for i, meta in enumerate(wires_targets):
             p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i],
+                meta, h["jmap"][i], h["joff"][i],lmap[i]
             )
-            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
             ys.append(label)
             if self.training and self.do_static_sampling:
                 p = torch.cat([p, meta["lpre"]])
@@ -239,43 +240,47 @@ class LineRCNNPredictor(nn.Module):
                 jcs.append(jc)
                 ps.append(p)
             fs.append(feat)
-
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                .reshape(n_channel, -1, self.n_pts0)
-                .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            # print(f'xp.shape:{xp.shape}')
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
+            #
+            # p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            # p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            # px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            # px0 = px.floor().clamp(min=0, max=127)
+            # py0 = py.floor().clamp(min=0, max=127)
+            # px1 = (px0 + 1).clamp(min=0, max=127)
+            # py1 = (py0 + 1).clamp(min=0, max=127)
+            # px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+            #
+            # # xp: [N_LINE, N_CHANNEL, N_POINT]
+            # xp = (
+            #     (
+            #             x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+            #             + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+            #             + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+            #             + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+            #     )
+            #     .reshape(n_channel, -1, self.n_pts0)
+            #     .permute(1, 0, 2)
+            # )
+            # xp = self.pooling(xp)
+            # # print(f'xp.shape:{xp.shape}')
+            # xs.append(xp)
+            idx.append(idx[-1] + feat.shape[0])
             # print(f'idx__:{idx}')
 
-        x, y = torch.cat(xs), torch.cat(ys)
+        # x, y = torch.cat(xs), torch.cat(ys)
+        y=torch.cat(ys)
         f = torch.cat(fs)
-        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+        print(f'f:{f.shape}')
+        # x = x.reshape(-1, self.n_pts1 * self.dim_loi)
 
         # print("Weight dtype:", self.fc2.weight.dtype)
-        x = torch.cat([x, f], 1)
+        # x = torch.cat([x, f], 1)
+        # print(f'x3:{x.shape}')
         # print("Input dtype:", x.dtype)
-        x = x.to(dtype=torch.float32)
+        f= f.to(dtype=torch.float32)
+        # x = x.to(dtype=torch.float32)
         # print("Input dtype1:", x.dtype)
-        x = self.fc2(x).flatten()
+        x = self.fc2(f).flatten()
 
         # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
         return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
@@ -285,7 +290,7 @@ class LineRCNNPredictor(nn.Module):
 
         # return result
 
-    def sample_lines(self, meta, jmap, joff):
+    def sample_lines(self, meta, jmap, joff,lmap):
         device = jmap.device
         with torch.no_grad():
             junc = meta["junc_coords"].to(device)  # [N, 2]
@@ -367,6 +372,10 @@ class LineRCNNPredictor(nn.Module):
 
             u2v = xyu - xyv
             u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+
+
+            # print(f'xp.shape:{xp.shape}')
+
             feat = torch.cat(
                 [
                     xyu / 128 * self.use_cood,
@@ -377,11 +386,41 @@ class LineRCNNPredictor(nn.Module):
                 ],
                 1,
             )
+            print(f'feat  shape:{feat.shape}')
+
             line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+            # print(f'line:{line.shape}')
+            n_channel, row, col = lmap.shape
+            p=line
+            print(f'p.shape :{p.shape}')
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            x=lmap
+            xp = (
+                (
+                        x[ :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[ :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[ :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[ :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp).squeeze(1)
+            print(f'xp shape:{xp.shape}')
+
 
             xy = xy.reshape(n_type, K, 2)
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
+            return line, label.float(), xp, jcs
 
 
 

+ 1 - 1
models/line_detect/roi_heads.py

@@ -253,7 +253,7 @@ def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
         result["preds"]["junts"] = torch.cat(
             [jcs[i][1] for i in range(n_batch)]
         )
-
+    print(f'predic result:{result}')
     return result