| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362 |
- import os
- import time
- from datetime import datetime
- import numpy as np
- import torch
- from matplotlib import pyplot as plt
- from torch.utils.tensorboard import SummaryWriter
- from libs.vision_libs.utils import draw_bounding_boxes
- from models.base.base_model import BaseModel
- from models.base.base_trainer import BaseTrainer
- from models.config.config_tool import read_yaml
- from models.line_net.dataset_LD import WirePointDataset
- from models.wirenet.postprocess import postprocess
- from tools import utils
- from torchvision import transforms
- import matplotlib as mpl
# Shared colour mapping used by `c()` to colour predicted line segments:
# values are normalised over [0.4, 1.0] and looked up in the "jet" colormap.
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
cmap = plt.get_cmap("jet")
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# The mappable is only used for its value-to-colour lookup, so an empty data array suffices.
sm.set_array([])
- def _loss(losses):
- total_loss = 0
- for i in losses.keys():
- if i != "loss_wirepoint":
- total_loss += losses[i]
- else:
- loss_labels = losses[i]["losses"]
- loss_labels_k = list(loss_labels[0].keys())
- for j, name in enumerate(loss_labels_k):
- loss = loss_labels[0][name].mean()
- total_loss += loss
- return total_loss
def c(x):
    """Map a value to an RGBA colour using the module-level ``sm`` scalar mappable."""
    rgba = sm.to_rgba(x)
    return rgba
# Run on CUDA when PyTorch reports a usable GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class Trainer(BaseTrainer):
    """Train a model on the ``WirePointDataset`` train/val splits.

    Configuration comes from keyword arguments with ``io`` and ``train_params``
    sections (normally a YAML file loaded by ``train_from_cfg``). Each run writes
    TensorBoard logs and ``last.pth`` / ``best_train.pth`` / ``best_val.pth``
    checkpoints under a timestamped directory inside ``io.logdir``.
    """

    def __init__(self, model=None, **kwargs):
        """Create the trainer.

        Args:
            model: Model to train; stored on ``self.model`` and forwarded to ``BaseTrainer``
                together with the module-level ``device``.
            **kwargs: Optional training configuration (same layout as the YAML consumed
                by ``train_from_cfg``). When empty, configuration is deferred to ``train``.
        """
        super().__init__(model, device, **kwargs)
        self.model = model
        self.init_params(**kwargs)

    def init_params(self, **kwargs):
        """Read the training configuration and set up the run's output paths and writer.

        Expects ``kwargs['train_params']`` with ``freeze_params``, ``batch_size``,
        ``num_workers``, ``resume_from`` and ``max_epoch``, and ``kwargs['io']`` with
        ``datadir`` and ``logdir``; a missing key raises ``KeyError``. Does nothing
        when called without keyword arguments.

        Every call picks a new run directory ``<logdir>/<YYYYmmdd_HHMMSS>`` with
        ``weights/`` for checkpoints and ``logs/`` for TensorBoard events.
        """
        if not kwargs:
            return
        print(f'train_params:{kwargs["train_params"]}')
        self.freeze_config = kwargs['train_params']['freeze_params']
        print(f'freeze_config:{self.freeze_config}')
        self.dataset_path = kwargs['io']['datadir']
        self.batch_size = kwargs['train_params']['batch_size']
        self.num_workers = kwargs['train_params']['num_workers']
        self.logdir = kwargs['io']['logdir']
        # NOTE(review): read but not used by any code visible here (no resume logic).
        self.resume_from = kwargs['train_params']['resume_from']
        self.optim = ''
        # Attribute name keeps its original spelling ("ptath") for compatibility.
        self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        self.wts_path = os.path.join(self.train_result_ptath, 'weights')
        self.tb_path = os.path.join(self.train_result_ptath, 'logs')
        # Re-initialising (e.g. __init__ with kwargs, then train()) would otherwise
        # leave the previous SummaryWriter open.
        previous_writer = getattr(self, 'writer', None)
        if previous_writer is not None and hasattr(previous_writer, 'close'):
            previous_writer.close()
        self.writer = SummaryWriter(self.tb_path)
        self.last_model_path = os.path.join(self.wts_path, 'last.pth')
        self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
        self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
        self.max_epoch = kwargs['train_params']['max_epoch']

    def move_to_device(self, data, device):
        """Recursively move every tensor inside ``data`` to ``device``.

        Lists/tuples keep their container type, dicts keep their keys, and
        non-tensor leaves are returned unchanged.
        """
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        if isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        if isinstance(data, torch.Tensor):
            return data.to(device)
        return data  # leave non-tensor data untouched

    @staticmethod
    def _apply_freeze_config(module, config, prefix='', depth=0):
        """Recursively apply a nested freeze config to ``module``'s named children.

        A boolean value freezes (``True``) or unfreezes (``False``) every parameter of
        the matching child; a dict value descends one level into that child's own
        children. Other values and children not named in ``config`` are left untouched.
        """
        for name, child in module.named_children():
            if name not in config:
                continue
            setting = config[name]
            qualified_name = f"{prefix}{name}"
            if isinstance(setting, bool):
                for param in child.parameters():
                    param.requires_grad = not setting
                # "module", "submodule", "sub-submodule", ... depending on nesting depth.
                label = "module" if depth == 0 else "sub-" * (depth - 1) + "submodule"
                print(f"{'Frozen' if setting else 'Trainable'} {label}: {qualified_name}")
            elif isinstance(setting, dict):
                Trainer._apply_freeze_config(child, setting, f"{qualified_name}.", depth + 1)

    def freeze_params(self, model):
        """Freeze/unfreeze ``model`` parameters according to the freeze configuration.

        The built-in defaults below are shallow-merged with ``self.freeze_config``: a
        top-level key from the user config replaces the default entry for that key
        wholesale. Nested dicts may go to any depth. (Previously a dict at the third
        level, like the default ``line_predictor.fc2``, was treated as a truthy flag and
        froze the whole sub-module.) Finally prints total/trainable/frozen parameter
        counts.
        """
        default_config = {
            'backbone': True,  # freeze the backbone
            'rpn': False,  # keep the RPN trainable
            'roi_heads': {
                'box_head': False,
                'box_predictor': False,
                'line_head': False,
                'line_predictor': {
                    'fc1': False,
                    'fc2': {
                        '0': False,
                        '2': False,
                        '4': False
                    }
                }
            }
        }
        # Merge the user's settings over the defaults (missing/None config -> defaults only).
        default_config.update(getattr(self, 'freeze_config', None) or {})
        config = default_config
        print("\n===== Parameter Freezing Configuration =====")
        self._apply_freeze_config(model, config)

        # Report parameter statistics.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nTotal Parameters: {total_params:,}")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Frozen Parameters: {total_params - trainable_params:,}")

    def load_best_model(self, model, optimizer, save_path, device):
        """Restore model (and optionally optimizer) state from a checkpoint file.

        Args:
            model: Module whose ``state_dict`` is replaced from ``'model_state_dict'``.
            optimizer: Optimizer to restore, or ``None`` to skip. Skipped with a notice
                when the checkpoint has no ``'optimizer_state_dict'``.
            save_path: Checkpoint path written by ``save_best_model``/``save_last_model``.
            device: ``map_location`` passed to ``torch.load``.

        Returns:
            ``(model, optimizer)``; returned unchanged if ``save_path`` does not exist.
        """
        if not os.path.exists(save_path):
            print(f"No saved model found at {save_path}")
            return model, optimizer
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            # Checkpoints saved with optimizer=None carry no optimizer state.
            if 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            else:
                print(f"No optimizer state in {save_path}; optimizer left unchanged")
        epoch = checkpoint.get('epoch')
        loss = checkpoint.get('loss')
        if loss is None:
            # save_last_model does not store a loss value.
            print(f"Loaded best model from {save_path} at epoch {epoch}")
        else:
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        return model, optimizer

    def writer_predict_result(self, img, result, epoch):
        """Log one input image and the model's predictions for it to TensorBoard.

        Writes the tags ``z-ori`` (input), ``z-boxes`` (``result[0]["boxes"]`` drawn in
        yellow), ``z-jmap`` / ``z-lmap`` (channels 1 and 2 of ``result[-2][0]``) and
        ``z-output`` (``result[0]["lines"]`` drawn over the image with Matplotlib).

        Args:
            img: Image tensor indexed as (C, H, W) (it is permuted with
                ``permute(1, 2, 0)``); presumably values in [0, 1] since it is scaled by
                255 before drawing boxes — TODO confirm.
            result: Model output; layout assumed from the indexing below (``result[0]``
                a dict with ``"boxes"`` and ``"lines"``, ``result[-2][0]`` a stack of
                heatmaps) — confirm against the model.
            epoch: Step value for the TensorBoard entries.
        """
        img = img.cpu().detach()
        im = img.permute(1, 2, 0)
        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")

        boxes = result[0]["boxes"]
        if isinstance(boxes, torch.Tensor):
            boxes = boxes.detach().cpu()
        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), boxes,
                                          colors="yellow", width=1)
        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

        heatmaps = result[-2][0]
        print(f'heatmaps:{heatmaps.shape}')
        jmap = heatmaps[1: 2].cpu().detach()
        lmap = heatmaps[2: 3].cpu().detach()
        self.writer.add_image("z-jmap", jmap, epoch)
        self.writer.add_image("z-lmap", lmap, epoch)

        lines = result[0]["lines"]
        if isinstance(lines, torch.Tensor):
            # postprocess() and matplotlib need host arrays, not (possibly CUDA) tensors.
            lines = lines.detach().cpu().numpy()
        # No per-line scores are read from `result`: every line gets the constant
        # placeholder score 100. The old scalar `100` crashed on `scores[:i]` and in
        # postprocess().
        # TODO(review): confirm whether the model exposes real per-line scores.
        scores = np.full(len(lines), 100.0)
        # Cut the list at the first repeat of lines[0] (presumably padding).
        for i in range(1, len(lines)):
            if (lines[i] == lines[0]).all():
                lines = lines[:i]
                scores = scores[:i]
                break

        # Post-process the lines to remove overlapping ones.
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

        point_opts = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
        score_threshold = 0
        try:
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            for (a, b), s in zip(nlines, nscores):
                if s < score_threshold:
                    continue
                # a[1] is plotted on x and a[0] on y, i.e. points are treated as (y, x).
                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
                plt.scatter(a[1], a[0], **point_opts)
                plt.scatter(b[1], b[0], **point_opts)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im)
            plt.tight_layout()
            fig = plt.gcf()
            fig.canvas.draw()
            # buffer_rgba() already has the canvas' real pixel shape (H, W, 4), unlike
            # figure-size * dpi, which can disagree with it. Copy, then drop alpha.
            img_rgb = np.array(fig.canvas.buffer_rgba())[:, :, :3]
        finally:
            # Always release the pyplot figure, even if drawing fails.
            plt.close()
        img2 = transforms.ToTensor()(img_rgb)
        self.writer.add_image("z-output", img2, epoch)

    def writer_loss(self, losses, epoch, phase='train'):
        """Log each loss term as a TensorBoard scalar under ``{phase}/loss/{name}``.

        Terms from every dict in ``losses['loss_wirepoint']['losses']`` are logged by
        their own names. Other entries are logged only if they are tensors. Logging is
        best-effort: any error is printed rather than interrupting training.
        """
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in value['losses']:
                        for subkey, subvalue in subdict.items():
                            scalar = subvalue.item() if hasattr(subvalue, 'item') else subvalue
                            self.writer.add_scalar(f'{phase}/loss/{subkey}', scalar, epoch)
                elif isinstance(value, torch.Tensor):
                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):
        """Load a training configuration with ``read_yaml`` and run ``train`` with it.

        Args:
            model: Model to train.
            cfg: Whatever ``read_yaml`` accepts (presumably a YAML file path).
            freeze_config: Currently unused; ``init_params`` takes the freeze settings
                from ``train_params.freeze_params`` in ``cfg`` instead.
        """
        cfg = read_yaml(cfg)
        self.train(model, **cfg)

    def _build_loader(self, dataset_type):
        """Return a DataLoader over one ``WirePointDataset`` split (random order, full batches only)."""
        dataset = WirePointDataset(dataset_path=self.dataset_path, dataset_type=dataset_type)
        sampler = torch.utils.data.RandomSampler(dataset)
        batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size=self.batch_size, drop_last=True)
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=utils.collate_fn
        )

    def train(self, model, **kwargs):
        """Train ``model`` for ``max_epoch`` epochs with Adam, validating after each epoch.

        After every epoch the latest weights go to ``last.pth``. The model is also
        saved to ``best_train.pth`` / ``best_val.pth`` whenever the corresponding
        average loss is <= the best seen so far.

        NOTE(review): ``freeze_params`` is not called here; only parameters whose
        ``requires_grad`` is already True are handed to the optimizer.
        """
        self.init_params(**kwargs)
        data_loader_train = self._build_loader('train')
        data_loader_val = self._build_loader('val')

        model.to(device)
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=kwargs['train_params']['optim']['lr']
        )
        # Start at +inf so the first finite epoch loss always writes the best_* checkpoints.
        best_train_loss = float('inf')
        best_val_loss = float('inf')
        for epoch in range(self.max_epoch):
            print(f"train epoch:{epoch}")
            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
            # ========== Validation ==========
            with torch.no_grad():
                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
            self.save_last_model(model, self.last_model_path, epoch, optimizer)
            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch,
                                                   epoch_train_loss, best_train_loss, optimizer)
            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch,
                                                 epoch_val_loss, best_val_loss, optimizer)

    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
        """Run one pass over ``data_loader`` and return ``(model, average loss)``.

        In ``'train'`` phase each batch is back-propagated and ``optimizer`` stepped. In
        ``'val'`` phase the model is put in eval mode and called with targets
        (returning ``(result, losses)``). On the first validation batch, an extra
        inference-only forward pass is timed and logged via ``writer_predict_result``.
        Per-batch losses are logged with ``writer_loss``. The epoch average goes to
        ``loss/{phase}``. When the loader yields no batches, the average is NaN.
        """
        if phase == 'train':
            model.train()
        elif phase == 'val':
            model.eval()
        total_loss = 0
        epoch_step = 0
        num_batches = len(data_loader)
        global_step = epoch * num_batches
        for imgs, targets in data_loader:
            imgs = self.move_to_device(imgs, device)
            targets = self.move_to_device(targets, device)
            if phase == 'val':
                _, losses = model(imgs, targets)
            else:
                losses = model(imgs, targets)
            loss = _loss(losses)
            total_loss += loss.item()
            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            self.writer_loss(losses, global_step, phase=phase)
            global_step += 1
            if epoch_step == 0 and phase == 'val':
                t_start = time.time()
                print(f'start to predict:{t_start}')
                result = model(self.move_to_device(imgs, self.device))
                t_end = time.time()
                print(f'predict used:{t_end - t_start}')
                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
            epoch_step += 1
        if num_batches == 0:
            # E.g. a split smaller than batch_size with drop_last=True.
            print(f'warning: {phase} data loader yielded no batches in epoch {epoch}')
            avg_loss = float('nan')
        else:
            avg_loss = total_loss / num_batches
        print(f'{phase}/loss epoch{epoch}:{avg_loss:.4f}')
        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
        return model, avg_loss

    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
        """Save a checkpoint when ``current_loss <= best_loss`` and return the new best loss.

        The checkpoint holds ``epoch``, ``model_state_dict``, ``loss`` and, when an
        optimizer is given, ``optimizer_state_dict``. A NaN ``current_loss`` never
        compares <= and is therefore never saved.
        """
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if current_loss <= best_loss:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': current_loss
            }
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            torch.save(checkpoint, save_path)
            print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
            return current_loss
        return best_loss

    def save_last_model(self, model, save_path, epoch, optimizer=None):
        """Overwrite ``save_path`` with the current model (and optimizer) state.

        The checkpoint holds ``epoch``, ``model_state_dict`` and, when an optimizer is
        given, ``optimizer_state_dict`` (no loss value). A stray ``last.pt`` file in the
        weights directory is deleted first — presumably a leftover from an older naming
        scheme; the checkpoint itself is written to ``save_path``.
        """
        if os.path.exists(f'{self.wts_path}/last.pt'):
            os.remove(f'{self.wts_path}/last.pt')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
if __name__ == '__main__':
    # Placeholder entry point: running the module directly only prints an empty line.
    print()
|