123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362 |
- # 根据LCNN写的train 2025/2/7
- '''
- #!/usr/bin/env python3
- import datetime
- import glob
- import os
- import os.path as osp
- import platform
- import pprint
- import random
- import shlex
- import shutil
- import subprocess
- import sys
- import numpy as np
- import torch
- import torchvision
- import yaml
- import lcnn
- from lcnn.config import C, M
- from lcnn.datasets import WireframeDataset, collate
- from lcnn.models.line_vectorizer import LineVectorizer
- from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
- from torchvision.models import resnet50
- from models.line_detect.line_rcnn import linercnn_resnet50_fpn
- def main():
- # 训练配置参数
- config = {
- # 数据集配置
- 'datadir': r'D:\python\PycharmProjects\data', # 数据集目录
- 'config_file': 'config/wireframe.yaml', # 配置文件路径
- # GPU配置
- 'devices': '0', # 使用的GPU设备
- 'identifier': 'fasterrcnn_resnet50', # 训练标识符 stacked_hourglass unet
- # 预训练模型路径
- # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth', # 预训练模型路径
- }
- # 更新配置
- C.update(C.from_yaml(filename=config['config_file']))
- M.update(C.model)
- # 设置随机数种子
- random.seed(0)
- np.random.seed(0)
- torch.manual_seed(0)
- # 设备配置
- device_name = "cpu"
- os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
- if torch.cuda.is_available():
- device_name = "cuda"
- torch.backends.cudnn.deterministic = True
- torch.cuda.manual_seed(0)
- print("Let's use", torch.cuda.device_count(), "GPU(s)!")
- else:
- print("CUDA is not available")
- device = torch.device(device_name)
- # 数据加载
- kwargs = {
- "collate_fn": collate,
- "num_workers": C.io.num_workers if os.name != "nt" else 0,
- "pin_memory": True,
- }
- train_loader = torch.utils.data.DataLoader(
- WireframeDataset(config['datadir'], dataset_type="train"),
- shuffle=True,
- batch_size=M.batch_size,
- **kwargs,
- )
- val_loader = torch.utils.data.DataLoader(
- WireframeDataset(config['datadir'], dataset_type="val"),
- shuffle=False,
- batch_size=M.batch_size_eval,
- **kwargs,
- )
- model = linercnn_resnet50_fpn().to(device)
- # 加载预训练权重
- try:
- # 加载模型权重
- checkpoint = torch.load(config['pretrained_model'], map_location=device)
- # 根据实际的检查点结构选择加载方式
- if 'model_state_dict' in checkpoint:
- # 如果是完整的检查点
- model.load_state_dict(checkpoint['model_state_dict'])
- elif 'state_dict' in checkpoint:
- # 如果是只有状态字典的检查点
- model.load_state_dict(checkpoint['state_dict'])
- else:
- # 直接加载权重字典
- model.load_state_dict(checkpoint)
- print("Successfully loaded pre-trained model weights.")
- except Exception as e:
- print(f"Error loading model weights: {e}")
- # 优化器配置
- if C.optim.name == "Adam":
- optim = torch.optim.Adam(
- filter(lambda p: p.requires_grad, model.parameters()),
- lr=C.optim.lr,
- weight_decay=C.optim.weight_decay,
- amsgrad=C.optim.amsgrad,
- )
- elif C.optim.name == "SGD":
- optim = torch.optim.SGD(
- filter(lambda p: p.requires_grad, model.parameters()),
- lr=C.optim.lr,
- weight_decay=C.optim.weight_decay,
- momentum=C.optim.momentum,
- )
- else:
- raise NotImplementedError
- # 输出目录
- outdir = osp.join(
- osp.expanduser(C.io.logdir),
- f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
- )
- os.makedirs(outdir, exist_ok=True)
- try:
- trainer = lcnn.trainer.Trainer(
- device=device,
- model=model,
- optimizer=optim,
- train_loader=train_loader,
- val_loader=val_loader,
- out=outdir,
- )
- print("Starting training...")
- trainer.train()
- print("Training completed.")
- except BaseException:
- if len(glob.glob(f"{outdir}/viz/*")) <= 1:
- shutil.rmtree(outdir)
- raise
- if __name__ == "__main__":
- main()
- '''
- import os
- from typing import Optional, Any
- import cv2
- import numpy as np
- import torch
- from models.config.config_tool import read_yaml
- from models.line_detect.dataset_LD import WirePointDataset
- from tools import utils
- from torch.utils.tensorboard import SummaryWriter
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- from skimage import io
- from models.line_detect.line_rcnn import linercnn_resnet50_fpn
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- def _loss(losses):
- total_loss = 0
- for i in losses.keys():
- if i != "loss_wirepoint":
- total_loss += losses[i]
- else:
- loss_labels = losses[i]["losses"]
- loss_labels_k = list(loss_labels[0].keys())
- for j, name in enumerate(loss_labels_k):
- loss = loss_labels[0][name].mean()
- total_loss += loss
- return total_loss
- cmap = plt.get_cmap("jet")
- norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
- sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
- sm.set_array([])
- def c(x):
- return sm.to_rgba(x)
- def imshow(im):
- plt.close()
- plt.tight_layout()
- plt.imshow(im)
- plt.colorbar(sm, fraction=0.046)
- plt.xlim([0, im.shape[0]])
- plt.ylim([im.shape[0], 0])
- def _plot_samples(self, i, index, result, targets, prefix):
- fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
- img = io.imread(fn)
- imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
- def draw_vecl(lines, sline, juncs, junts, fn):
- imshow(img)
- if len(lines) > 0 and not (lines[0] == 0).all():
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
- if i > 0 and (lines[i] == lines[0]).all():
- break
- plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
- if not (juncs[0] == 0).all():
- for i, j in enumerate(juncs):
- if i > 0 and (i == juncs[0]).all():
- break
- plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
- if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
- for i, j in enumerate(junts):
- if i > 0 and (i == junts[0]).all():
- break
- plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
- plt.savefig(fn), plt.close()
- junc = targets[i]["junc"].cpu().numpy() * 4
- jtyp = targets[i]["jtyp"].cpu().numpy()
- juncs = junc[jtyp == 0]
- junts = junc[jtyp == 1]
- rjuncs = result["juncs"][i].cpu().numpy() * 4
- rjunts = None
- if "junts" in result:
- rjunts = result["junts"][i].cpu().numpy() * 4
- lpre = targets[i]["lpre"].cpu().numpy() * 4
- vecl_target = targets[i]["lpre_label"].cpu().numpy()
- vecl_result = result["lines"][i].cpu().numpy() * 4
- score = result["score"][i].cpu().numpy()
- lpre = lpre[vecl_target == 1]
- draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
- draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
- img = cv2.imread(f"{prefix}_vecl_a.jpg")
- img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
- self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
- if __name__ == '__main__':
- cfg = r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml'
- cfg = read_yaml(cfg)
- print(f'cfg:{cfg}')
- print(cfg['model']['n_dyn_negl'])
- # net = WirepointPredictor()
- dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
- train_sampler = torch.utils.data.RandomSampler(dataset_train)
- # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
- train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
- train_collate_fn = utils.collate_fn_wirepoint
- data_loader_train = torch.utils.data.DataLoader(
- dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
- )
- dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
- val_sampler = torch.utils.data.RandomSampler(dataset_val)
- # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
- val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
- val_collate_fn = utils.collate_fn_wirepoint
- data_loader_val = torch.utils.data.DataLoader(
- dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
- )
- model = linercnn_resnet50_fpn().to(device)
- optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
- writer = SummaryWriter(cfg['io']['logdir'])
- def move_to_device(data, device):
- if isinstance(data, (list, tuple)):
- return type(data)(move_to_device(item, device) for item in data)
- elif isinstance(data, dict):
- return {key: move_to_device(value, device) for key, value in data.items()}
- elif isinstance(data, torch.Tensor):
- return data.to(device)
- else:
- return data # 对于非张量类型的数据不做任何改变
- def writer_loss(writer, losses, epoch):
- try:
- for key, value in losses.items():
- if key == 'loss_wirepoint':
- # ?? wirepoint ??????
- for subdict in losses['loss_wirepoint']['losses']:
- for subkey, subvalue in subdict.items():
- # ?? .item() ?????
- writer.add_scalar(f'loss_wirepoint/{subkey}',
- subvalue.item() if hasattr(subvalue, 'item') else subvalue,
- epoch)
- elif isinstance(value, torch.Tensor):
- writer.add_scalar(key, value.item(), epoch)
- except Exception as e:
- print(f"TensorBoard logging error: {e}")
- for epoch in range(cfg['optim']['max_epoch']):
- print(f"epoch:{epoch}")
- model.train()
- for imgs, targets in data_loader_train:
- losses = model(move_to_device(imgs, device), move_to_device(targets, device))
- # print(type(losses))
- # print(losses)
- loss = _loss(losses)
- # print(loss)
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- writer_loss(writer, losses, epoch)
- model.eval()
- with torch.no_grad():
- for batch_idx, (imgs, targets) in enumerate(data_loader_val):
- pred = model(move_to_device(imgs, device))
- # print(f"perd:{pred}")
- break
- # print(f"perd:{pred}")
- # if batch_idx == 0:
- # viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
- # H = pred["wires"]
- # _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
- # imgs, targets = next(iter(data_loader))
- #
- # model.train()
- # pred = model(imgs, targets)
- # print(f'pred:{pred}')
- # result, losses = model(imgs, targets)
- # print(f'result:{result}')
- # print(f'pred:{losses}')
|