trainer.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import torch
  2. from torch.utils.tensorboard import SummaryWriter
  3. from models.base.base_trainer import BaseTrainer
  4. from models.config.config_tool import read_yaml
  5. from models.line_detect.dataset_LD import WirePointDataset
  6. from utils.log_util import show_line
  7. from tools import utils
  8. def _loss(losses):
  9. total_loss = 0
  10. for i in losses.keys():
  11. if i != "loss_wirepoint":
  12. total_loss += losses[i]
  13. else:
  14. loss_labels = losses[i]["losses"]
  15. loss_labels_k = list(loss_labels[0].keys())
  16. for j, name in enumerate(loss_labels_k):
  17. loss = loss_labels[0][name].mean()
  18. total_loss += loss
  19. return total_loss
  20. class Trainer(BaseTrainer):
  21. def __init__(self, model=None,
  22. dataset=None,
  23. device='cuda',
  24. **kwargs):
  25. super().__init__(model,dataset,device,**kwargs)
  26. def move_to_device(self, data, device):
  27. if isinstance(data, (list, tuple)):
  28. return type(data)(self.move_to_device(item, device) for item in data)
  29. elif isinstance(data, dict):
  30. return {key: self.move_to_device(value, device) for key, value in data.items()}
  31. elif isinstance(data, torch.Tensor):
  32. return data.to(device)
  33. else:
  34. return data # 对于非张量类型的数据不做任何改变
  35. def writer_loss(self, writer, losses, epoch):
  36. try:
  37. for key, value in losses.items():
  38. if key == 'loss_wirepoint':
  39. for subdict in losses['loss_wirepoint']['losses']:
  40. for subkey, subvalue in subdict.items():
  41. writer.add_scalar(f'loss/{subkey}',
  42. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  43. epoch)
  44. elif isinstance(value, torch.Tensor):
  45. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  46. except Exception as e:
  47. print(f"TensorBoard logging error: {e}")
  48. def train_cfg(self, model, cfg):
  49. # cfg = r'./config/wireframe.yaml'
  50. cfg = read_yaml(cfg)
  51. print(f'cfg:{cfg}')
  52. print(cfg['model']['n_dyn_negl'])
  53. self.train(model, **cfg)
  54. def train(self, model, **cfg):
  55. dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
  56. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  57. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  58. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
  59. train_collate_fn = utils.collate_fn_wirepoint
  60. data_loader_train = torch.utils.data.DataLoader(
  61. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  62. )
  63. dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
  64. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  65. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  66. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
  67. val_collate_fn = utils.collate_fn_wirepoint
  68. data_loader_val = torch.utils.data.DataLoader(
  69. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  70. )
  71. # model = linenet_resnet50_fpn().to(self.device)
  72. optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
  73. writer = SummaryWriter(cfg['io']['logdir'])
  74. for epoch in range(cfg['optim']['max_epoch']):
  75. print(f"epoch:{epoch}")
  76. model.train()
  77. for imgs, targets in data_loader_train:
  78. losses = model(self.move_to_device(imgs, self.device), self.move_to_device(targets, self.device))
  79. # print(losses)
  80. loss = _loss(losses)
  81. optimizer.zero_grad()
  82. loss.backward()
  83. optimizer.step()
  84. self.writer_loss(writer, losses, epoch)
  85. model.eval()
  86. with torch.no_grad():
  87. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  88. pred = model(self.move_to_device(imgs, self.device))
  89. if batch_idx == 0:
  90. show_line(imgs[0], pred, epoch, writer)
  91. break