|
@@ -347,20 +347,11 @@ class WirepointPredictor(nn.Module):
|
|
|
"lmap": lmap.sigmoid(),
|
|
|
"joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
|
|
|
}
|
|
|
- # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
|
|
|
- # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
|
|
|
- # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
|
|
|
|
|
|
h = result["preds"]
|
|
|
- print(f'features shape:{features.shape}')
|
|
|
+ # print(f'features shape:{features.shape}')
|
|
|
x = self.fc1(features)
|
|
|
-
|
|
|
- # print(f'x:{x.shape}')
|
|
|
-
|
|
|
n_batch, n_channel, row, col = x.shape
|
|
|
-
|
|
|
- # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
|
|
|
-
|
|
|
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
|
|
|
for i, meta in enumerate(wires_targets):
|
|
@@ -408,12 +399,9 @@ class WirepointPredictor(nn.Module):
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
|
x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
-
|
|
|
- # print("Weight dtype:", self.fc2.weight.dtype)
|
|
|
+ print(f"pstest{ps}")
|
|
|
x = torch.cat([x, f], 1)
|
|
|
- # print("Input dtype:", x.dtype)
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
- # print("Input dtype1:", x.dtype)
|
|
|
x = self.fc2(x).flatten()
|
|
|
|
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
@@ -424,47 +412,6 @@ class WirepointPredictor(nn.Module):
|
|
|
|
|
|
# return result
|
|
|
|
|
|
- ####deprecated
|
|
|
- # def inference(self,input, idx, jcs, n_batch, ps):
|
|
|
- # if not self.training:
|
|
|
- # p = torch.cat(ps)
|
|
|
- # s = torch.sigmoid(input)
|
|
|
- # b = s > 0.5
|
|
|
- # lines = []
|
|
|
- # score = []
|
|
|
- # print(f"n_batch:{n_batch}")
|
|
|
- # for i in range(n_batch):
|
|
|
- # print(f"idx:{idx}")
|
|
|
- # p0 = p[idx[i]: idx[i + 1]]
|
|
|
- # s0 = s[idx[i]: idx[i + 1]]
|
|
|
- # mask = b[idx[i]: idx[i + 1]]
|
|
|
- # p0 = p0[mask]
|
|
|
- # s0 = s0[mask]
|
|
|
- # if len(p0) == 0:
|
|
|
- # lines.append(torch.zeros([1, self.n_out_line, 2, 2], device=p.device))
|
|
|
- # score.append(torch.zeros([1, self.n_out_line], device=p.device))
|
|
|
- # else:
|
|
|
- # arg = torch.argsort(s0, descending=True)
|
|
|
- # p0, s0 = p0[arg], s0[arg]
|
|
|
- # lines.append(p0[None, torch.arange(self.n_out_line) % len(p0)])
|
|
|
- # score.append(s0[None, torch.arange(self.n_out_line) % len(s0)])
|
|
|
- # for j in range(len(jcs[i])):
|
|
|
- # if len(jcs[i][j]) == 0:
|
|
|
- # jcs[i][j] = torch.zeros([self.n_out_junc, 2], device=p.device)
|
|
|
- # jcs[i][j] = jcs[i][j][
|
|
|
- # None, torch.arange(self.n_out_junc) % len(jcs[i][j])
|
|
|
- # ]
|
|
|
- # result["preds"]["lines"] = torch.cat(lines)
|
|
|
- # result["preds"]["score"] = torch.cat(score)
|
|
|
- # result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
|
|
|
- #
|
|
|
- # if len(jcs[i]) > 1:
|
|
|
- # result["preds"]["junts"] = torch.cat(
|
|
|
- # [jcs[i][1] for i in range(n_batch)]
|
|
|
- # )
|
|
|
- # if self.training:
|
|
|
- # del result["preds"]
|
|
|
-
|
|
|
def sample_lines(self, meta, jmap, joff):
|
|
|
with torch.no_grad():
|
|
|
junc = meta["junc_coords"] # [N, 2]
|
|
@@ -631,21 +578,21 @@ if __name__ == '__main__':
|
|
|
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
|
- dataset_train, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
|
|
|
+ dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
|
|
|
)
|
|
|
|
|
|
dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
|
|
|
val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
|
- val_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
+ val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
|
|
|
val_collate_fn = utils.collate_fn_wirepoint
|
|
|
data_loader_val = torch.utils.data.DataLoader(
|
|
|
- dataset_val, batch_sampler=val_batch_sampler, num_workers=4, collate_fn=val_collate_fn
|
|
|
+ dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
|
|
|
)
|
|
|
-
|
|
|
+
|
|
|
model = wirepointrcnn_resnet50_fpn().to(device)
|
|
|
|
|
|
- optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
+ optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
writer = SummaryWriter(cfg['io']['logdir'])
|
|
|
|
|
|
|
|
@@ -682,11 +629,12 @@ if __name__ == '__main__':
|
|
|
optimizer.step()
|
|
|
writer_loss(writer, losses)
|
|
|
|
|
|
- model.eval()
|
|
|
- with torch.no_grad():
|
|
|
- for imgs, targets in dataset_val:
|
|
|
- pred = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
-
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ for imgs, targets in data_loader_val:
|
|
|
+ print(111)
|
|
|
+ pred = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
+ print(f"pred:{pred}")
|
|
|
|
|
|
# imgs, targets = next(iter(data_loader))
|
|
|
#
|