123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- #!/usr/bin/env python3
- import datetime
- import glob
- import os
- import os.path as osp
- import platform
- import pprint
- import random
- import shlex
- import shutil
- import subprocess
- import sys
- import numpy as np
- import torch
- import torchvision
- import yaml
- import lcnn
- from lcnn.config import C, M
- from lcnn.datasets import WireframeDataset, collate
- from lcnn.models.line_vectorizer import LineVectorizer
- from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
- from torchvision.models import resnet50
- from models.line_detect.line_rcnn import linercnn_resnet50_fpn
def _load_pretrained_weights(model, checkpoint_path, device):
    """Best-effort load of pretrained weights into ``model`` (in place).

    The checkpoint may store the weights under ``"model_state_dict"``, under
    ``"state_dict"``, or be a bare state dict.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path of a file readable by ``torch.load``; a falsy
            value (``None`` / ``""``) skips loading.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        True if weights were loaded, False if loading was skipped or failed.
    """
    if not checkpoint_path:
        print("No pre-trained model configured; training from scratch.")
        return False
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Pick the weights according to how the checkpoint was saved.
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberately best-effort (as before): report and keep training
        # with the freshly initialised weights.
        print(f"Error loading model weights: {e}")
        return False
    print("Successfully loaded pre-trained model weights.")
    return True


def _build_optimizer(model):
    """Create the optimizer selected by ``C.optim.name`` over trainable params.

    Raises:
        NotImplementedError: if ``C.optim.name`` is neither "Adam" nor "SGD".
    """
    # Only parameters with requires_grad=True are optimised.
    params = [p for p in model.parameters() if p.requires_grad]
    if C.optim.name == "Adam":
        return torch.optim.Adam(
            params,
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    if C.optim.name == "SGD":
        return torch.optim.SGD(
            params,
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    raise NotImplementedError(f"Unsupported optimizer: {C.optim.name}")


def main(overrides=None):
    """Train ``linercnn_resnet50_fpn`` on the wireframe dataset.

    Args:
        overrides: Optional dict whose entries replace the defaults in the
            local ``config`` dict (``datadir``, ``config_file``, ``devices``,
            ``identifier``, ``pretrained_model``). ``None`` keeps the defaults.

    Raises:
        Re-raises any exception (including KeyboardInterrupt) from trainer
        construction or training, after possibly removing the run's output
        directory.
    """
    # Training configuration.
    config = {
        # Dataset root directory.
        'datadir': r'D:\python\PycharmProjects\data',
        # YAML file loaded into the global config C.
        'config_file': 'config/wireframe.yaml',
        # Value written to CUDA_VISIBLE_DEVICES.
        'devices': '0',
        # Suffix of the output directory name
        # (other identifiers used before: stacked_hourglass, unet).
        'identifier': 'fasterrcnn_resnet50',
        # Checkpoint to initialise from, e.g.
        # './190418-201834-f8934c6-lr4d10-312k.pth'; None trains from scratch.
        'pretrained_model': None,
    }
    if overrides:
        config.update(overrides)

    # Restrict the visible GPUs before anything can initialise CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']

    # Load the YAML config into C, and its model section into M.
    C.update(C.from_yaml(filename=config['config_file']))
    M.update(C.model)

    # Fixed seeds for Python, NumPy and torch RNGs.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        device_name = "cpu"
        print("CUDA is not available")
    device = torch.device(device_name)

    # Data loading.
    loader_kwargs = {
        "collate_fn": collate,
        # No worker processes on Windows (os.name == "nt").
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        # Pinned host memory only helps host-to-GPU copies; on CPU it just
        # triggers a DataLoader warning.
        "pin_memory": device.type == "cuda",
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **loader_kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="val"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **loader_kwargs,
    )

    model = linercnn_resnet50_fpn().to(device)
    _load_pretrained_weights(model, config.get('pretrained_model'), device)
    optim = _build_optimizer(model)

    # Output directory: <logdir>/<yymmdd-HHMMSS>-<identifier>
    outdir = osp.join(
        osp.expanduser(C.io.logdir),
        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}",
    )
    os.makedirs(outdir, exist_ok=True)

    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        print("Starting training...")
        trainer.train()
        print("Training completed.")
    except BaseException:
        # On failure or interruption, discard the run directory if it holds
        # at most one entry under viz/. glob.escape keeps special characters
        # in outdir from being treated as wildcards.
        viz_pattern = osp.join(glob.escape(outdir), "viz", "*")
        if len(glob.glob(viz_pattern)) <= 1:
            # ignore_errors so a cleanup failure cannot mask the original error.
            shutil.rmtree(outdir, ignore_errors=True)
        raise


if __name__ == "__main__":
    main()
|