#!/usr/bin/env python3
"""Training entry point: builds the data loaders, model and optimizer, then runs ``lcnn.trainer.Trainer``."""
import datetime
import glob
import os
import os.path as osp
import platform
import pprint
import random
import shlex
import shutil
import subprocess
import sys

import numpy as np
import torch
import torchvision
import yaml

import lcnn
from lcnn.config import C, M
from lcnn.datasets import WireframeDataset, collate
from lcnn.models.line_vectorizer import LineVectorizer
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
from torchvision.models import resnet50

from models.line_detect.line_rcnn import linercnn_resnet50_fpn


def _load_pretrained_weights(model: torch.nn.Module, checkpoint_path, device: torch.device) -> bool:
    """Best-effort load of pre-trained weights into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path to the checkpoint file, or a falsy value
            (``None`` / ``""``) when no checkpoint is configured.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``True`` if weights were loaded, ``False`` if loading was skipped or
        failed. Failures are reported on stdout and never raised, so training
        continues with the model's current weights.
    """
    if not checkpoint_path:
        print("No pre-trained model configured; skipping weight loading.")
        return False

    try:
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Pick the state dict according to the checkpoint layout.
        if 'model_state_dict' in checkpoint:
            # Full checkpoint with the weights stored under 'model_state_dict'.
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            # Checkpoint with the weights stored under 'state_dict'.
            state_dict = checkpoint['state_dict']
        else:
            # The checkpoint itself is the weight dictionary.
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberately broad: a bad or missing checkpoint must not abort the
        # run; the error is reported and training proceeds without it.
        print(f"Error loading model weights: {e}")
        return False

    print("Successfully loaded pre-trained model weights.")
    return True


def _build_optimizer(model: torch.nn.Module) -> torch.optim.Optimizer:
    """Create the optimizer selected by ``C.optim.name`` over trainable parameters.

    Raises:
        NotImplementedError: If ``C.optim.name`` is neither ``"Adam"`` nor ``"SGD"``.
    """
    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    if C.optim.name == "Adam":
        return torch.optim.Adam(
            trainable_params,
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    if C.optim.name == "SGD":
        return torch.optim.SGD(
            trainable_params,
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    raise NotImplementedError(f"Unsupported optimizer: {C.optim.name!r}")


def main() -> None:
    """Configure the run, build loaders/model/optimizer, and train.

    If training raises (including ``KeyboardInterrupt``), the freshly created
    output directory is removed when its ``viz`` sub-directory holds at most
    one entry, and the exception is re-raised.
    """
    # Training configuration.
    config = {
        # Dataset settings.
        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
        'config_file': 'config/wireframe.yaml',  # YAML configuration file
        # GPU settings.
        'devices': '0',  # value exported as CUDA_VISIBLE_DEVICES
        'identifier': 'fasterrcnn_resnet50',  # run tag appended to the output dir name
        # Path to a pre-trained checkpoint, or None to skip weight loading,
        # e.g. './190418-201834-f8934c6-lr4d10-312k.pth'.
        'pretrained_model': None,
    }

    # Merge the YAML file into the global config objects.
    C.update(C.from_yaml(filename=config['config_file']))
    M.update(C.model)

    # Seed every RNG in use.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Device selection. CUDA_VISIBLE_DEVICES is set before the first CUDA query.
    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    # Data loading.
    loader_kwargs = {
        "collate_fn": collate,
        # Worker processes are disabled on Windows (os.name == "nt").
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only run just triggers a warning.
        "pin_memory": device.type == "cuda",
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **loader_kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="val"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **loader_kwargs,
    )

    model = linercnn_resnet50_fpn().to(device)

    # Optionally initialise from a pre-trained checkpoint (best effort).
    _load_pretrained_weights(model, config.get('pretrained_model'), device)

    # Optimizer.
    optim = _build_optimizer(model)

    # Output directory: <logdir>/<yymmdd-HHMMSS>-<identifier>.
    outdir = osp.join(
        osp.expanduser(C.io.logdir),
        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
    )
    os.makedirs(outdir, exist_ok=True)

    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        print("Starting training...")
        trainer.train()
        print("Training completed.")
    except BaseException:
        # BaseException so that Ctrl-C is covered too: drop the output
        # directory when viz/ holds at most one entry, then re-raise.
        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
            shutil.rmtree(outdir)
        raise


if __name__ == "__main__":
    main()