| 
														
															@@ -1,6 +1,7 @@ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import torch 
														 | 
														
														 | 
														
															 import torch 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torch.utils.tensorboard import SummaryWriter 
														 | 
														
														 | 
														
															 from torch.utils.tensorboard import SummaryWriter 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+from models.base.base_model import BaseModel 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from models.base.base_trainer import BaseTrainer 
														 | 
														
														 | 
														
															 from models.base.base_trainer import BaseTrainer 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from models.config.config_tool import read_yaml 
														 | 
														
														 | 
														
															 from models.config.config_tool import read_yaml 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from models.line_detect.dataset_LD import WirePointDataset 
														 | 
														
														 | 
														
															 from models.line_detect.dataset_LD import WirePointDataset 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -21,13 +22,23 @@ def _loss(losses): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         total_loss += loss 
														 | 
														
														 | 
														
															         total_loss += loss 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     return total_loss 
														 | 
														
														 | 
														
															     return total_loss 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def move_to_device(data, device): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if isinstance(data, (list, tuple)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return type(data)(move_to_device(item, device) for item in data) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    elif isinstance(data, dict): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return {key: move_to_device(value, device) for key, value in data.items()} 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    elif isinstance(data, torch.Tensor): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return data.to(device) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    else: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return data  # 对于非张量类型的数据不做任何改变 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 class Trainer(BaseTrainer): 
														 | 
														
														 | 
														
															 class Trainer(BaseTrainer): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     def __init__(self, model=None, 
														 | 
														
														 | 
														
															     def __init__(self, model=None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                  dataset=None, 
														 | 
														
														 | 
														
															                  dataset=None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                  device='cuda', 
														 | 
														
														 | 
														
															                  device='cuda', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                  **kwargs): 
														 | 
														
														 | 
														
															                  **kwargs): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         super().__init__(model,dataset,device,**kwargs) 
														 | 
														
														 | 
														
															         super().__init__(model,dataset,device,**kwargs) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     def move_to_device(self, data, device): 
														 | 
														
														 | 
														
															     def move_to_device(self, data, device): 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -54,15 +65,15 @@ class Trainer(BaseTrainer): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         except Exception as e: 
														 | 
														
														 | 
														
															         except Exception as e: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             print(f"TensorBoard logging error: {e}") 
														 | 
														
														 | 
														
															             print(f"TensorBoard logging error: {e}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    def train_cfg(self, model, cfg): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    def train_cfg(self, model:BaseModel, cfg): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         # cfg = r'./config/wireframe.yaml' 
														 | 
														
														 | 
														
															         # cfg = r'./config/wireframe.yaml' 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         cfg = read_yaml(cfg) 
														 | 
														
														 | 
														
															         cfg = read_yaml(cfg) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         print(f'cfg:{cfg}') 
														 | 
														
														 | 
														
															         print(f'cfg:{cfg}') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        print(cfg['model']['n_dyn_negl']) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # print(cfg['n_dyn_negl']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         self.train(model, **cfg) 
														 | 
														
														 | 
														
															         self.train(model, **cfg) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    def train(self, model, **cfg): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    def train(self, model, **kwargs): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         train_sampler = torch.utils.data.RandomSampler(dataset_train) 
														 | 
														
														 | 
														
															         train_sampler = torch.utils.data.RandomSampler(dataset_train) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         # test_sampler = torch.utils.data.SequentialSampler(dataset_test) 
														 | 
														
														 | 
														
															         # test_sampler = torch.utils.data.SequentialSampler(dataset_test) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True) 
														 | 
														
														 | 
														
															         train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True) 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -71,7 +82,7 @@ class Trainer(BaseTrainer): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn 
														 | 
														
														 | 
														
															             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ) 
														 | 
														
														 | 
														
															         ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         val_sampler = torch.utils.data.RandomSampler(dataset_val) 
														 | 
														
														 | 
														
															         val_sampler = torch.utils.data.RandomSampler(dataset_val) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         # test_sampler = torch.utils.data.SequentialSampler(dataset_test) 
														 | 
														
														 | 
														
															         # test_sampler = torch.utils.data.SequentialSampler(dataset_test) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True) 
														 | 
														
														 | 
														
															         val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True) 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -82,15 +93,19 @@ class Trainer(BaseTrainer): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         # model = linenet_resnet50_fpn().to(self.device) 
														 | 
														
														 | 
														
															         # model = linenet_resnet50_fpn().to(self.device) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr']) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        writer = SummaryWriter(cfg['io']['logdir']) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        writer = SummaryWriter(kwargs['io']['logdir']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        for epoch in range(cfg['optim']['max_epoch']): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        for epoch in range(kwargs['optim']['max_epoch']): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             print(f"epoch:{epoch}") 
														 | 
														
														 | 
														
															             print(f"epoch:{epoch}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             model.train() 
														 | 
														
														 | 
														
															             model.train() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															             for imgs, targets in data_loader_train: 
														 | 
														
														 | 
														
															             for imgs, targets in data_loader_train: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                losses = model(self.move_to_device(imgs, self.device), self.move_to_device(targets, self.device)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                imgs = move_to_device(imgs, device) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                targets=move_to_device(targets,device) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                print(f'imgs:{len(imgs)}') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                print(f'targets:{len(targets)}') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                losses = model(imgs, targets) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 # print(losses) 
														 | 
														
														 | 
														
															                 # print(losses) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 loss = _loss(losses) 
														 | 
														
														 | 
														
															                 loss = _loss(losses) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 optimizer.zero_grad() 
														 | 
														
														 | 
														
															                 optimizer.zero_grad() 
														 |