浏览代码

WireDataset

xue50 5 月之前
父节点
当前提交
0139015731
共有 2 个文件被更改,包括 47 次插入30 次删除
  1. 3 3
      models/wirenet/head.py
  2. 44 27
      models/wirenet/wirepoint_rcnn.py

+ 3 - 3
models/wirenet/head.py

@@ -775,13 +775,13 @@ class RoIHeads(nn.Module):
 
     def has_wirepoint(self):
         if self.wirepoint_roi_pool is None:
-            # print(f'wirepoint_roi_pool is None')
+            print(f'wirepoint_roi_pool is None')
             return False
         if self.wirepoint_head is None:
-            # print(f'wirepoint_head is None')
+            print(f'wirepoint_head is None')
             return False
         if self.wirepoint_predictor is None:
-            # print(f'wirepoint_roi_predictor is None')
+            print(f'wirepoint_roi_predictor is None')
             return False
         return True
 

+ 44 - 27
models/wirenet/wirepoint_rcnn.py

@@ -26,6 +26,8 @@ from models.wirenet.head import RoIHeads
 from models.wirenet.wirepoint_dataset import WirePointDataset
 from tools import utils
 
+from torch.utils.tensorboard import SummaryWriter
+
 FEATURE_DIM = 8
 
 
@@ -290,7 +292,7 @@ class WirepointPredictor(nn.Module):
         # outputs=merge_features(outputs,100)
         batch, channel, row, col = inputs.shape
         print(f'outputs:{inputs.shape}')
-        print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
 
         if targets is not None:
             self.training = True
@@ -334,7 +336,7 @@ class WirepointPredictor(nn.Module):
         result = {}
         for stack, output in enumerate([inputs]):
             output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-            print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
             jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
             lmap = output[offset[0]: offset[1]].squeeze(0)
             joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
@@ -353,11 +355,11 @@ class WirepointPredictor(nn.Module):
         print(f'features shape:{features.shape}')
         x = self.fc1(features)
 
-        print(f'x:{x.shape}')
+        # print(f'x:{x.shape}')
 
         n_batch, n_channel, row, col = x.shape
 
-        print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
 
         xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
 
@@ -365,7 +367,7 @@ class WirepointPredictor(nn.Module):
             p, label, feat, jc = self.sample_lines(
                 meta, h["jmap"][i], h["joff"][i],
             )
-            print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
             ys.append(label)
             if self.training and self.do_static_sampling:
                 p = torch.cat([p, meta["lpre"]])
@@ -398,10 +400,10 @@ class WirepointPredictor(nn.Module):
                     .permute(1, 0, 2)
             )
             xp = self.pooling(xp)
-            print(f'xp.shape:{xp.shape}')
+            # print(f'xp.shape:{xp.shape}')
             xs.append(xp)
             idx.append(idx[-1] + xp.shape[0])
-            print(f'idx__:{idx}')
+            # print(f'idx__:{idx}')
 
         x, y = torch.cat(xs), torch.cat(ys)
         f = torch.cat(fs)
@@ -409,9 +411,9 @@ class WirepointPredictor(nn.Module):
 
         # print("Weight dtype:", self.fc2.weight.dtype)
         x = torch.cat([x, f], 1)
-        print("Input dtype:", x.dtype)
+        # print("Input dtype:", x.dtype)
         x = x.to(dtype=torch.float32)
-        print("Input dtype1:", x.dtype)
+        # print("Input dtype1:", x.dtype)
         x = self.fc2(x).flatten()
 
         # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
@@ -613,7 +615,6 @@ if __name__ == '__main__':
     print(cfg['model']['n_dyn_negl'])
     # net = WirepointPredictor()
 
-
     if torch.cuda.is_available():
         device_name = "cuda"
         torch.backends.cudnn.deterministic = True
@@ -624,18 +625,28 @@ if __name__ == '__main__':
 
     device = torch.device(device_name)
 
-
-    dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
-    train_sampler = torch.utils.data.RandomSampler(dataset)
+    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+    train_sampler = torch.utils.data.RandomSampler(dataset_train)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
     train_collate_fn = utils.collate_fn_wirepoint
-    data_loader = torch.utils.data.DataLoader(
-        dataset, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
+    )
+
+    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+    val_sampler = torch.utils.data.RandomSampler(dataset_val)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    val_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    val_collate_fn = utils.collate_fn_wirepoint
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=4, collate_fn=val_collate_fn
     )
+    
     model = wirepointrcnn_resnet50_fpn().to(device)
 
     optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
+    writer = SummaryWriter(cfg['io']['logdir'])
 
 
     def move_to_device(data, device):
@@ -649,27 +660,33 @@ if __name__ == '__main__':
             return data  # 对于非张量类型的数据不做任何改变
 
 
+    def writer_loss(writer, losses):
+        # 记录每个损失项到TensorBoard
+        for key, value in losses.items():
+            if isinstance(value, dict):  # 如果value本身也是一个字典(例如'loss_wirepoint')
+                for subkey, subvalue in value['losses'][0].items():
+                    writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch)
+            else:
+                writer.add_scalar(key, value.item(), epoch)
+
 
-    for i in range(cfg['optim']['max_epoch']):
+    for epoch in range(cfg['optim']['max_epoch']):
         model.train()
-        # imgs, targets = next(iter(data_loader))
-        # loss = model(imgs, targets)
-        # print(loss)
-        # losses = _loss(loss)
-        # optimizer.zero_grad()
-        # loss.backward()
-        # optimizer.step()
-
-        for imgs, targets in data_loader:
+
+        for imgs, targets in data_loader_train:
             losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-            print(losses)
             loss = _loss(losses)
             print(loss)
-            # 优化器优化模型
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
+            writer_loss(writer, losses)
 
+        model.eval()
+        with torch.no_grad():
+            for imgs, targets in dataset_val:
+                pred = model(move_to_device(imgs, device), move_to_device(targets, device))
+                
 
 # imgs, targets = next(iter(data_loader))
 #