@@ -1,18 +1,20 @@
-
 import os
 import time
 from datetime import datetime

+import numpy as np
 import torch
+from matplotlib import pyplot as plt
 from torch.utils.tensorboard import SummaryWriter

+from libs.vision_libs.utils import draw_bounding_boxes
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
 from models.line_detect.dataset_LD import WirePointDataset
-from models.line_detect.postprocess import box_line_, show_
-from utils.log_util import show_line, save_last_model, save_best_model
+from models.wirenet.postprocess import postprocess
 from tools import utils
+from torchvision import transforms


 def _loss(losses):

@@ -26,26 +28,38 @@ def _loss(losses):
     for j, name in enumerate(loss_labels_k):
         loss = loss_labels[0][name].mean()
         total_loss += loss
-
     return total_loss
+
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
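+# single module-level device, shared by Trainer and the helpers below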
-def move_to_device(data, device):
-    if isinstance(data, (list, tuple)):
-        return type(data)(move_to_device(item, device) for item in data)
-    elif isinstance(data, dict):
-        return {key: move_to_device(value, device) for key, value in data.items()}
-    elif isinstance(data, torch.Tensor):
-        return data.to(device)
-    else:
-        return data  # leave non-tensor data unchanged
+


 class Trainer(BaseTrainer):
-    def __init__(self, model=None,
-                 dataset=None,
-                 device='cuda',
-                 **kwargs):
+    def __init__(self, model=None, **kwargs):
+        super().__init__(model, device, **kwargs)  # device is the module-level default
+        self.model = model
+        print(f'kwargs:{kwargs}')
+        self.init_params(**kwargs)

-        super().__init__(model,dataset,device,**kwargs)
+    def init_params(self, **kwargs):
+        if kwargs:
+            print(f'train_params:{kwargs["train_params"]}')
+            self.freeze_config = kwargs['train_params']['freeze_params']
+            print(f'freeze_config:{self.freeze_config}')
+            self.dataset_path = kwargs['io']['datadir']
+            self.batch_size = kwargs['train_params']['batch_size']
+            self.num_workers = kwargs['train_params']['num_workers']
+            self.logdir = kwargs['io']['logdir']
+            self.resume_from = kwargs['train_params']['resume_from']
+            self.optim = ''
+            self.train_result_path = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
+            self.wts_path = os.path.join(self.train_result_path, 'weights')
+            self.tb_path = os.path.join(self.train_result_path, 'logs')
+            self.writer = SummaryWriter(self.tb_path)
+            self.last_model_path = os.path.join(self.wts_path, 'last.pth')
+            self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
+            self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
+            self.max_epoch = kwargs['train_params']['max_epoch']
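+    # recursively moves nested containers of tensors (lists/tuples/dicts) onto the target device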
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
@@ -57,7 +71,63 @@ class Trainer(BaseTrainer):
         else:
             return data  # leave non-tensor data unchanged
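+    # In the freeze config, True means "freeze this module" and False means
+    # "leave it trainable"; nested dicts mirror the module hierarchy
+    # (e.g. roi_heads.line_predictor.fc2).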
-    def load_best_model(self,model, optimizer, save_path, device):
+    def freeze_params(self, model):
+        """Freeze model parameters according to the freeze config."""
+        default_config = {
+            'backbone': True,  # freeze the backbone
+            'rpn': False,  # keep the RPN trainable
+            'roi_heads': {
+                'box_head': False,
+                'box_predictor': False,
+                'line_head': False,
+                'line_predictor': {
+                    'fc1': False,
+                    'fc2': {
+                        '0': False,
+                        '2': False,
+                        '4': False
+                    }
+                }
+            }
+        }
+
+        # merge the user config over the defaults (note: dict.update is shallow,
+        # so a user-supplied nested section replaces the default one wholesale)
+        default_config.update(self.freeze_config)
+        config = default_config
+
+        print("\n===== Parameter Freezing Configuration =====")
+        for name, module in model.named_children():
+            if name in config:
+                if isinstance(config[name], bool):
+                    for param in module.parameters():
+                        param.requires_grad = not config[name]
+                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
+
+                elif isinstance(config[name], dict):
+                    for subname, submodule in module.named_children():
+                        if subname in config[name]:
+                            if isinstance(config[name][subname], bool):
+                                for param in submodule.parameters():
+                                    param.requires_grad = not config[name][subname]
+                                print(
+                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
+
+                            elif isinstance(config[name][subname], dict):
+                                for subsubname, subsubmodule in submodule.named_children():
+                                    if subsubname in config[name][subname]:
+                                        leaf_cfg = config[name][subname][subsubname]
+                                        if isinstance(leaf_cfg, dict):
+                                            # fourth level (e.g. the children of fc2): freeze per child,
+                                            # instead of treating the dict itself as a truthy flag
+                                            for leafname, leafmodule in subsubmodule.named_children():
+                                                if leafname in leaf_cfg:
+                                                    for param in leafmodule.parameters():
+                                                        param.requires_grad = not leaf_cfg[leafname]
+                                                    print(
+                                                        f"{'Frozen' if leaf_cfg[leafname] else 'Trainable'} layer: {name}.{subname}.{subsubname}.{leafname}")
+                                        else:
+                                            for param in subsubmodule.parameters():
+                                                param.requires_grad = not leaf_cfg
+                                            print(
+                                                f"{'Frozen' if leaf_cfg else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
+
+        # parameter statistics
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\nTotal Parameters: {total_params:,}")
+        print(f"Trainable Parameters: {trainable_params:,}")
+        print(f"Frozen Parameters: {total_params - trainable_params:,}")
+
+    def load_best_model(self, model, optimizer, save_path, device):
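+        # restores model (and, when present, optimizer) state from a checkpoint
+        # written by save_best_model / save_last_model below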
         if os.path.exists(save_path):
             checkpoint = torch.load(save_path, map_location=device)
             model.load_state_dict(checkpoint['model_state_dict'])
@@ -70,111 +140,206 @@ class Trainer(BaseTrainer):
             print(f"No saved model found at {save_path}")
         return model, optimizer
-    def writer_loss(self, writer, losses, epoch):
+    def writer_predict_result(self, img, result, epoch):
+        img = img.cpu().detach()
+        im = img.permute(1, 2, 0)
+        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
+
+        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
+                                          colors="yellow", width=1)
+        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+        heatmaps = result[-2][0]
+        print(f'heatmaps:{heatmaps.shape}')
+        jmap = heatmaps[1:2].cpu().detach()  # junction heatmap channel
+        lmap = heatmaps[2:3].cpu().detach()  # line heatmap channel
+        self.writer.add_image("z-jmap", jmap, epoch)
+        self.writer.add_image("z-lmap", lmap, epoch)
+
+        H = result[-1]['wires']
+        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+        scores = H["score"][0].cpu().numpy()
+        # the decoder pads its output by repeating the first line; cut at the first repeat
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        # map line scores to colors; the 0.9-1.0 range follows the LCNN demo convention
+        sm = plt.cm.ScalarMappable(cmap=plt.get_cmap("jet"), norm=plt.Normalize(vmin=0.9, vmax=1.0))
+
+        def c(x):
+            return sm.to_rgba(x)
+
+        for i, t in enumerate([0]):  # score thresholds to render (only 0 here)
+            plt.gca().set_axis_off()
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            for (a, b), s in zip(nlines, nscores):
+                if s < t:
+                    continue
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+                plt.scatter(a[1], a[0], **PLTOPTS)
+                plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(im)
+            plt.tight_layout()
+            fig = plt.gcf()
+            fig.canvas.draw()
+
+            width, height = fig.get_size_inches() * fig.get_dpi()  # figure size in pixels
+            tmp_img = fig.canvas.tostring_argb()
+            tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
+            tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
+
+            img_rgb = tmp_img_np[:, :, 1:]  # keep RGB, drop the leading alpha channel
+
+            plt.close()
+
+            img2 = transforms.ToTensor()(img_rgb)
+
+            self.writer.add_image("z-output", img2, epoch)
+    def writer_loss(self, losses, epoch, phase='train'):
         try:
             for key, value in losses.items():
                 if key == 'loss_wirepoint':
                     for subdict in losses['loss_wirepoint']['losses']:
                         for subkey, subvalue in subdict.items():
-                            writer.add_scalar(f'loss/{subkey}',
-                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
-                                              epoch)
+                            self.writer.add_scalar(f'{phase}/loss/{subkey}',
+                                                   subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                                   epoch)
                 elif isinstance(value, torch.Tensor):
-                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
         except Exception as e:
             print(f"TensorBoard logging error: {e}")
-    def train_cfg(self, model:BaseModel, cfg):
-        # cfg = r'./config/wireframe.yaml'
+    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):  # new: optionally accept a freeze config
         cfg = read_yaml(cfg)
-        print(f'cfg:{cfg}')
-        # print(cfg['n_dyn_negl'])
+        # print(f'cfg:{cfg}')
+        # self.freeze_config = freeze_config or {}  # override the freeze config
+
        self.train(model, **cfg)
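+    # train() expects the whole YAML config spread into kwargs
+    # (an 'io' section plus a 'train_params' section, as read above)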
     def train(self, model, **kwargs):
-        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+
+        self.init_params(**kwargs)
+
+        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
+        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
+
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
-        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
-        train_collate_fn = utils.collate_fn_wirepoint
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
+        train_collate_fn = utils.collate_fn
+        val_collate_fn = utils.collate_fn
+
         data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
         )
-
-        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
-        val_sampler = torch.utils.data.RandomSampler(dataset_val)
-        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
-        val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
-            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
         )
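+        # drop_last=True on both samplers, so len(data_loader) below counts full batches only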
-        train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
-        wts_path = os.path.join(train_result_ptath, 'weights')
-        tb_path = os.path.join(train_result_ptath, 'logs')
-        writer = SummaryWriter(tb_path)
-
-        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
-        # writer = SummaryWriter(kwargs['io']['logdir'])
         model.to(device)

+        # apply the freeze config before building the optimizer so the
+        # requires_grad filter below actually excludes the frozen parameters
+        self.freeze_params(model)
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=kwargs['train_params']['optim']['lr']
+        )

-        # # load weights
-        # save_path = 'logs/pth/best_model.pth'
-        # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+        best_train_loss = float('inf')
+        best_val_loss = float('inf')

-        # logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
-        # os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
-        last_model_path = os.path.join(wts_path, 'last.pth')
-        best_model_path = os.path.join(wts_path, 'best.pth')
-        global_step = 0
-
-        for epoch in range(kwargs['optim']['max_epoch']):
-            print(f"epoch:{epoch}")
-            total_train_loss = 0.0
+        for epoch in range(self.max_epoch):
+            print(f"train epoch:{epoch}")

+            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)

+            # ========== Validation ==========
+            with torch.no_grad():
+                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')

+            self.save_last_model(model, self.last_model_path, epoch, optimizer)
+            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
+                                                   best_train_loss, optimizer)
+            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss,
+                                                 best_val_loss, optimizer)
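+    # one_epoch handles both phases: 'train' steps the optimizer, while 'val'
+    # is run by the caller under torch.no_grad() and only logs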
+    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
+        if phase == 'train':
             model.train()
+        if phase == 'val':
+            model.eval()

-        for imgs, targets in data_loader_train:
-            imgs = move_to_device(imgs, device)
-            targets=move_to_device(targets,device)
-            # print(f'imgs:{len(imgs)}')
-            # print(f'targets:{len(targets)}')
-            losses = model(imgs, targets)
-            loss = _loss(losses)
-            total_train_loss += loss.item()
+        total_loss = 0
+        epoch_step = 0
+        global_step = epoch * len(data_loader)  # keep the step counter monotonic across epochs
+        for imgs, targets in data_loader:
+            imgs = self.move_to_device(imgs, device)
+            targets = self.move_to_device(targets, device)
+            losses = model(imgs, targets)
+            loss = _loss(losses)
+            total_loss += loss.item()
+            if phase == 'train':
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-            self.writer_loss(writer, losses, global_step)
-            global_step+=1
+            self.writer_loss(losses, global_step, phase=phase)
+            global_step += 1

+            # visualize predictions on the first validation batch of each epoch
+            if epoch_step == 0 and phase == 'val':
+                t_start = time.time()
+                print(f'start to predict:{t_start}')
+                result = model(self.move_to_device(imgs, device))
+                t_end = time.time()
+                print(f'predict used:{t_end - t_start}')
+                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
+            epoch_step += 1

-        avg_train_loss = total_train_loss / len(data_loader_train)
-        if epoch == 0:
-            best_loss = avg_train_loss;
+        avg_loss = total_loss / len(data_loader)
+        print(f'{phase}/loss epoch {epoch}: {avg_loss:.4f}')
+        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
+        return model, avg_loss
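+    # checkpoint helpers: save_best_model keeps the lowest-loss weights seen so
+    # far and returns the updated best loss; save_last_model always overwrites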
-        writer.add_scalar('loss/train', avg_train_loss, epoch)
+    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)

+        if current_loss <= best_loss:
+            checkpoint = {
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'loss': current_loss
+            }
+            if optimizer is not None:
+                checkpoint['optimizer_state_dict'] = optimizer.state_dict()

-        if os.path.exists(f'{wts_path}/last.pt'):
-            os.remove(f'{wts_path}/last.pt')
-        # torch.save(model.state_dict(), f'{wts_path}/last.pt')
-        save_last_model(model,last_model_path,epoch,optimizer)
-        best_loss = save_best_model(model,best_model_path,epoch,avg_train_loss,best_loss,optimizer)
+            torch.save(checkpoint, save_path)
+            print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")

-        model.eval()
-        with torch.no_grad():
-            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                t_start = time.time()
-                print(f'start to predict:{t_start}')
-                pred = model(self.move_to_device(imgs, self.device))
-                t_end = time.time()
-                print(f'predict used:{t_end - t_start}')
-                if batch_idx == 0:
-                    show_line(imgs[0], pred, epoch, writer)
-                break
+            return current_loss
+
+        return best_loss
+    def save_last_model(self, model, save_path, epoch, optimizer=None):
+
+        # clean up the stale .pt file left behind by earlier runs
+        if os.path.exists(f'{self.wts_path}/last.pt'):
+            os.remove(f'{self.wts_path}/last.pt')
+
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+        }
+
+        if optimizer is not None:
+            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+        torch.save(checkpoint, save_path)


+if __name__ == '__main__':
+    print('')