|
@@ -26,6 +26,8 @@ from models.wirenet.head import RoIHeads
|
|
|
from models.wirenet.wirepoint_dataset import WirePointDataset
|
|
|
from tools import utils
|
|
|
|
|
|
+from torch.utils.tensorboard import SummaryWriter
|
|
|
+
|
|
|
FEATURE_DIM = 8
|
|
|
|
|
|
|
|
@@ -290,7 +292,7 @@ class WirepointPredictor(nn.Module):
|
|
|
# outputs=merge_features(outputs,100)
|
|
|
batch, channel, row, col = inputs.shape
|
|
|
print(f'outputs:{inputs.shape}')
|
|
|
- print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
|
|
|
+ # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
|
|
|
|
|
|
if targets is not None:
|
|
|
self.training = True
|
|
@@ -334,7 +336,7 @@ class WirepointPredictor(nn.Module):
|
|
|
result = {}
|
|
|
for stack, output in enumerate([inputs]):
|
|
|
output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
|
|
|
- print(f"Stack {stack} output shape: {output.shape}") # 打印每层的输出形状
|
|
|
+ # print(f"Stack {stack} output shape: {output.shape}") # 打印每层的输出形状
|
|
|
jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
|
|
|
lmap = output[offset[0]: offset[1]].squeeze(0)
|
|
|
joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
|
|
@@ -353,11 +355,11 @@ class WirepointPredictor(nn.Module):
|
|
|
print(f'features shape:{features.shape}')
|
|
|
x = self.fc1(features)
|
|
|
|
|
|
- print(f'x:{x.shape}')
|
|
|
+ # print(f'x:{x.shape}')
|
|
|
|
|
|
n_batch, n_channel, row, col = x.shape
|
|
|
|
|
|
- print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
|
|
|
+ # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
|
|
|
|
|
|
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
|
|
@@ -365,7 +367,7 @@ class WirepointPredictor(nn.Module):
|
|
|
p, label, feat, jc = self.sample_lines(
|
|
|
meta, h["jmap"][i], h["joff"][i],
|
|
|
)
|
|
|
- print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
|
|
|
+ # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
|
|
|
ys.append(label)
|
|
|
if self.training and self.do_static_sampling:
|
|
|
p = torch.cat([p, meta["lpre"]])
|
|
@@ -398,10 +400,10 @@ class WirepointPredictor(nn.Module):
|
|
|
.permute(1, 0, 2)
|
|
|
)
|
|
|
xp = self.pooling(xp)
|
|
|
- print(f'xp.shape:{xp.shape}')
|
|
|
+ # print(f'xp.shape:{xp.shape}')
|
|
|
xs.append(xp)
|
|
|
idx.append(idx[-1] + xp.shape[0])
|
|
|
- print(f'idx__:{idx}')
|
|
|
+ # print(f'idx__:{idx}')
|
|
|
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
@@ -409,9 +411,9 @@ class WirepointPredictor(nn.Module):
|
|
|
|
|
|
# print("Weight dtype:", self.fc2.weight.dtype)
|
|
|
x = torch.cat([x, f], 1)
|
|
|
- print("Input dtype:", x.dtype)
|
|
|
+ # print("Input dtype:", x.dtype)
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
- print("Input dtype1:", x.dtype)
|
|
|
+ # print("Input dtype1:", x.dtype)
|
|
|
x = self.fc2(x).flatten()
|
|
|
|
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
@@ -613,7 +615,6 @@ if __name__ == '__main__':
|
|
|
print(cfg['model']['n_dyn_negl'])
|
|
|
# net = WirepointPredictor()
|
|
|
|
|
|
-
|
|
|
if torch.cuda.is_available():
|
|
|
device_name = "cuda"
|
|
|
torch.backends.cudnn.deterministic = True
|
|
@@ -624,18 +625,28 @@ if __name__ == '__main__':
|
|
|
|
|
|
device = torch.device(device_name)
|
|
|
|
|
|
-
|
|
|
- dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
|
|
|
- train_sampler = torch.utils.data.RandomSampler(dataset)
|
|
|
+ dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
|
|
|
+ train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
|
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
|
- data_loader = torch.utils.data.DataLoader(
|
|
|
- dataset, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
|
|
|
+ data_loader_train = torch.utils.data.DataLoader(
|
|
|
+ dataset_train, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
|
|
|
+ )
|
|
|
+
|
|
|
+ dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
|
|
|
+ val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
|
+ # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
|
+ val_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
+ val_collate_fn = utils.collate_fn_wirepoint
|
|
|
+ data_loader_val = torch.utils.data.DataLoader(
|
|
|
+ dataset_val, batch_sampler=val_batch_sampler, num_workers=4, collate_fn=val_collate_fn
|
|
|
)
|
|
|
+
|
|
|
model = wirepointrcnn_resnet50_fpn().to(device)
|
|
|
|
|
|
optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
+ writer = SummaryWriter(cfg['io']['logdir'])
|
|
|
|
|
|
|
|
|
def move_to_device(data, device):
|
|
@@ -649,27 +660,33 @@ if __name__ == '__main__':
|
|
|
return data # 对于非张量类型的数据不做任何改变
|
|
|
|
|
|
|
|
|
+ def writer_loss(writer, losses):
|
|
|
+ # 记录每个损失项到TensorBoard
|
|
|
+ for key, value in losses.items():
|
|
|
+ if isinstance(value, dict): # 如果value本身也是一个字典(例如'loss_wirepoint')
|
|
|
+ for subkey, subvalue in value['losses'][0].items():
|
|
|
+ writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch)
|
|
|
+ else:
|
|
|
+ writer.add_scalar(key, value.item(), epoch)
|
|
|
+
|
|
|
|
|
|
- for i in range(cfg['optim']['max_epoch']):
|
|
|
+ for epoch in range(cfg['optim']['max_epoch']):
|
|
|
model.train()
|
|
|
- # imgs, targets = next(iter(data_loader))
|
|
|
- # loss = model(imgs, targets)
|
|
|
- # print(loss)
|
|
|
- # losses = _loss(loss)
|
|
|
- # optimizer.zero_grad()
|
|
|
- # loss.backward()
|
|
|
- # optimizer.step()
|
|
|
-
|
|
|
- for imgs, targets in data_loader:
|
|
|
+
|
|
|
+ for imgs, targets in data_loader_train:
|
|
|
losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
- print(losses)
|
|
|
loss = _loss(losses)
|
|
|
print(loss)
|
|
|
- # 优化器优化模型
|
|
|
optimizer.zero_grad()
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
+ writer_loss(writer, losses)
|
|
|
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ for imgs, targets in dataset_val:
|
|
|
+ pred = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
+
|
|
|
|
|
|
# imgs, targets = next(iter(data_loader))
|
|
|
#
|