|
@@ -347,20 +347,11 @@ class WirepointPredictor(nn.Module):
|
|
|
"lmap": lmap.sigmoid(),
|
|
|
"joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
|
|
|
}
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
|
|
|
h = result["preds"]
|
|
|
- print(f'features shape:{features.shape}')
|
|
|
+
|
|
|
x = self.fc1(features)
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
n_batch, n_channel, row, col = x.shape
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
|
|
|
for i, meta in enumerate(wires_targets):
|
|
@@ -408,12 +399,9 @@ class WirepointPredictor(nn.Module):
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
|
x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
-
|
|
|
-
|
|
|
+ print(f"pstest{ps}")
|
|
|
x = torch.cat([x, f], 1)
|
|
|
-
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
-
|
|
|
x = self.fc2(x).flatten()
|
|
|
|
|
|
|
|
@@ -424,47 +412,6 @@ class WirepointPredictor(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
def sample_lines(self, meta, jmap, joff):
|
|
|
with torch.no_grad():
|
|
|
junc = meta["junc_coords"]
|
|
@@ -631,21 +578,21 @@ if __name__ == '__main__':
|
|
|
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
|
- dataset_train, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
|
|
|
+ dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
|
|
|
)
|
|
|
|
|
|
dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
|
|
|
val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
|
|
|
|
- val_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
+ val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
|
|
|
val_collate_fn = utils.collate_fn_wirepoint
|
|
|
data_loader_val = torch.utils.data.DataLoader(
|
|
|
- dataset_val, batch_sampler=val_batch_sampler, num_workers=4, collate_fn=val_collate_fn
|
|
|
+ dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
|
|
|
)
|
|
|
-
|
|
|
+
|
|
|
model = wirepointrcnn_resnet50_fpn().to(device)
|
|
|
|
|
|
- optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
+ optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
writer = SummaryWriter(cfg['io']['logdir'])
|
|
|
|
|
|
|
|
@@ -682,11 +629,12 @@ if __name__ == '__main__':
|
|
|
optimizer.step()
|
|
|
writer_loss(writer, losses)
|
|
|
|
|
|
- model.eval()
|
|
|
- with torch.no_grad():
|
|
|
- for imgs, targets in dataset_val:
|
|
|
- pred = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
-
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ for imgs, targets in data_loader_val:
|
|
|
+ print(111)
|
|
|
+ pred = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
+ print(f"pred:{pred}")
|
|
|
|
|
|
|
|
|
|