trainer.py 5.1 KB

import torch
from torch.utils.tensorboard import SummaryWriter

from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from utils.log_util import show_line
from tools import utils


def _loss(losses):
    """Sum every loss term returned by the model into a single scalar."""
    total_loss = 0
    for key in losses.keys():
        if key != "loss_wirepoint":
            total_loss += losses[key]
        else:
            # "loss_wirepoint" nests its individual terms one level deeper.
            loss_labels = losses[key]["losses"]
            for name in loss_labels[0].keys():
                total_loss += loss_labels[0][name].mean()
    return total_loss
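
# Structure assumed by _loss (inferred from the loop above, not from the model's
# documentation): plain entries map a loss name to a scalar tensor, while
# "loss_wirepoint" wraps a list of per-head loss dicts under "losses", of which
# only the first element is summed. The key names below are hypothetical
# placeholders:
#
#   losses = {
#       "loss_classifier": torch.tensor(0.7),
#       "loss_box_reg": torch.tensor(0.3),
#       "loss_wirepoint": {"losses": [{"lpos": torch.tensor(0.2),
#                                      "lneg": torch.tensor(0.1)}]},
#   }
#   _loss(losses)  # -> tensor(1.3000)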
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def move_to_device(data, device):
    """Recursively move tensors inside nested lists/tuples/dicts onto `device`."""
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # non-tensor values are returned unchanged
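
# Illustration only (the batch layout here is hypothetical): non-tensor fields
# such as integer or string metadata pass through untouched.
#
#   batch = [{"image": torch.zeros(3, 64, 64), "image_id": 7}]
#   batch = move_to_device(batch, device)  # only the "image" tensor is moved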


class Trainer(BaseTrainer):
    def __init__(self, model=None, dataset=None, device='cuda', **kwargs):
        super().__init__(model, dataset, device, **kwargs)

    def move_to_device(self, data, device):
        """Same recursive transfer as the module-level helper, exposed as a method."""
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # non-tensor values are returned unchanged
    def writer_loss(self, writer, losses, epoch):
        """Log each loss term to TensorBoard, flattening the nested wirepoint losses."""
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(
                                f'loss/{subkey}',
                                subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")
    def train_cfg(self, model: BaseModel, cfg):
        # cfg = r'./config/wireframe.yaml'
        cfg = read_yaml(cfg)
        print(f'cfg:{cfg}')
        # print(cfg['n_dyn_negl'])
        self.train(model, **cfg)
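
    # The YAML passed to train_cfg() must provide at least the keys consumed in
    # train() below. A minimal sketch of the parsed dict (the values shown are
    # placeholders, not the project's actual settings):
    #
    #   {
    #       "io": {"datadir": "path/to/wireframe/data", "logdir": "runs/line_detect"},
    #       "optim": {"lr": 1e-4, "max_epoch": 30},
    #   }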
    def train(self, model, **kwargs):
        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
        train_collate_fn = utils.collate_fn_wirepoint
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
        )

        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
        val_collate_fn = utils.collate_fn_wirepoint
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
        )

        # model = linenet_resnet50_fpn().to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
        writer = SummaryWriter(kwargs['io']['logdir'])
        model.to(device)
        for epoch in range(kwargs['optim']['max_epoch']):
            print(f"epoch:{epoch}")

            # --- training ---
            model.train()
            for imgs, targets in data_loader_train:
                imgs = move_to_device(imgs, device)
                targets = move_to_device(targets, device)
                print(f'imgs:{len(imgs)}')
                print(f'targets:{len(targets)}')
                losses = model(imgs, targets)
                # print(losses)
                loss = _loss(losses)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                self.writer_loss(writer, losses, epoch)

            # --- validation: visualize predictions on the first batch only ---
            model.eval()
            with torch.no_grad():
                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                    pred = model(self.move_to_device(imgs, self.device))
                    if batch_idx == 0:
                        show_line(imgs[0], pred, epoch, writer)
                        break
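

if __name__ == "__main__":
    # Minimal usage sketch. It reuses the model factory and config path mentioned
    # in the commented-out lines above (linenet_resnet50_fpn, ./config/wireframe.yaml);
    # the import path for the factory is an assumption and may differ in the
    # actual project layout.
    from models.line_detect.linenet import linenet_resnet50_fpn  # assumed location

    trainer = Trainer()
    model = linenet_resnet50_fpn()
    trainer.train_cfg(model, cfg=r'./config/wireframe.yaml')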