#!/usr/bin/env python3
"""Training entry point for an ``lcnn`` line-detection model.

The script reads a YAML configuration into the global ``C``/``M`` config
objects, builds train/validation data loaders over ``WireframeDataset``,
constructs the backbone selected by ``M.backbone``, wraps it in
``MultitaskLearner`` and ``LineVectorizer``, optionally loads pre-trained
weights and freezes parts of the model, and finally runs
``lcnn.trainer.Trainer``.
"""
import datetime
import glob
import os
import os.path as osp
import platform
import pprint
import random
import shlex
import shutil
import subprocess
import sys

import numpy as np
import torch
import torchvision
import yaml
from torchvision.models import resnet50

import lcnn
from lcnn.config import C, M
from lcnn.datasets import WireframeDataset, collate
from lcnn.models.line_vectorizer import LineVectorizer
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner

# Seed used for every random number generator touched by this script.
SEED = 0

# Base channel count passed to the U-Net backbone. The original code looked
# this up in the DataLoader kwargs dict (which never contains it) and always
# fell back to 64, so the value is simply made explicit here.
UNET_BASE_CHANNELS = 64

# Backbones that are built with nothing but a ``num_classes`` argument.
_SIMPLE_BACKBONES = ("resnet50", "resnet501", "fasterrcnn_resnet50")


def _count_params(module):
    """Return ``(total, trainable)`` parameter element counts of *module*."""
    total = 0
    trainable = 0
    for param in module.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable


def print_model_structure(model):
    """Print parameter statistics for *model* and two levels of children.

    Args:
        model: A ``torch.nn.Module`` (anything exposing ``parameters()`` and
            ``named_children()``).
    """
    print("\n========= Model Structure =========")
    # Overall model information.
    print("Model Type:", type(model))

    total_params, trainable_params = _count_params(model)
    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Non-trainable Parameters: {total_params - trainable_params:,}")

    # Per-module parameter counts, followed by each module's direct children.
    print("\n===== Detailed Model Components =====")
    for name, module in model.named_children():
        module_params, module_trainable_params = _count_params(module)
        print(f"\nmodel.named:{name}:")
        print(f" Total Parameters: {module_params:,}")
        print(f" Trainable Parameters: {module_trainable_params:,}")

        for subname, submodule in module.named_children():
            sub_params, sub_trainable_params = _count_params(submodule)
            print(f" {subname}:")
            print(f" Total Parameters: {sub_params:,}")
            print(f" Trainable Parameters: {sub_trainable_params:,}")


def verify_freeze_params(model, freeze_config):
    """Print ``requires_grad`` of every parameter in modules marked frozen.

    Only modules (and ``fc2`` submodules) whose entry in *freeze_config* is
    truthy are reported; nothing is modified.

    Args:
        model: The model whose direct children are inspected.
        freeze_config: Mapping of child-module name to a freeze flag, with an
            optional ``'fc2_submodules'`` mapping for the children of ``fc2``.
    """
    print("\n===== Verifying Parameter Freezing =====")
    for name, module in model.named_children():
        if freeze_config.get(name):
            print(f"\nChecking module: {name}")
            for param_name, param in module.named_parameters():
                print(f" {param_name}: requires_grad = {param.requires_grad}")

        # The children of ``fc2`` can carry their own freeze flags.
        if name == 'fc2' and 'fc2_submodules' in freeze_config:
            submodule_flags = freeze_config['fc2_submodules']
            for subname, submodule in module.named_children():
                if submodule_flags.get(subname):
                    print(f"\nChecking fc2 submodule: {subname}")
                    for param_name, param in submodule.named_parameters():
                        print(f" {param_name}: requires_grad = {param.requires_grad}")


def freeze_params(model, freeze_config=None):
    """Set ``requires_grad`` on model parameters according to a freeze config.

    A flag of ``True`` freezes the module (``requires_grad = False``); a flag
    of ``False`` makes it trainable. The ``fc2`` module is processed first as
    a whole and its listed submodules are processed afterwards, so a
    ``fc2_submodules`` flag overrides the ``fc2`` flag for that submodule.

    Args:
        model: The model whose direct children are (un)frozen in place.
        freeze_config: Optional mapping that overrides the all-trainable
            defaults. Nested dict values are merged into the matching default
            dict; any other value replaces the default outright.
    """
    # Defaults: everything trainable.
    default_config = {
        'backbone': False,
        'fc1': False,
        'fc2': False,
        'fc2_submodules': {
            '0': False,  # first child of fc2
            '2': False,  # third child of fc2
            '4': False,  # fifth child of fc2
        },
        'pooling': False,
        'loss': False,
    }

    # Merge the caller's overrides. Only merge dict-into-dict; unknown keys or
    # mismatched types are assigned directly instead of raising
    # KeyError/AttributeError.
    if freeze_config is not None:
        for key, value in freeze_config.items():
            current = default_config.get(key)
            if isinstance(value, dict) and isinstance(current, dict):
                current.update(value)
            else:
                default_config[key] = value

    print("\n===== Parameter Freezing Configuration =====")
    for name, module in model.named_children():
        # Whole-module flag.
        if name in default_config:
            frozen = default_config[name]
            for param in module.parameters():
                param.requires_grad = not frozen
            if not frozen:
                print(f"Module {name} is trainable")
            else:
                print(f"Freezing module: {name}")

        # Per-submodule flags for fc2 (applied after the whole-module flag).
        if name == 'fc2' and 'fc2_submodules' in default_config:
            submodule_flags = default_config['fc2_submodules']
            for subname, submodule in module.named_children():
                if subname in submodule_flags:
                    sub_frozen = submodule_flags[subname]
                    for param in submodule.parameters():
                        param.requires_grad = not sub_frozen
                    if not sub_frozen:
                        print(f"Submodule fc2.{subname} is trainable")
                    else:
                        print(f"Freezing submodule: fc2.{subname}")

    # Summary after freezing.
    total_params, trainable_params = _count_params(model)
    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Frozen Parameters: {total_params - trainable_params:,}")


def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) with a replaced box predictor.

    Args:
        num_classes: Number of output classes of the new ``FastRCNNPredictor``.

    Returns:
        The torchvision detection model with its ``roi_heads.box_predictor``
        swapped for a freshly initialised predictor.
    """
    # NOTE(review): ``pretrained=True`` is the legacy torchvision argument;
    # newer releases prefer ``weights=...`` - confirm against the installed
    # torchvision version before changing it.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # Input feature size of the existing classification head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the head so that it predicts ``num_classes`` classes.
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model


def _seed_everything(seed):
    """Seed the Python, NumPy and torch CPU random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _select_device(visible_devices):
    """Restrict visible GPUs and return the ``torch.device`` to train on."""
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(SEED)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
        return torch.device("cuda")
    print("CUDA is not available")
    return torch.device("cpu")


def _build_loaders(datadir):
    """Create the ``(train_loader, val_loader)`` pair for *datadir*."""
    loader_kwargs = {
        "collate_fn": collate,
        # Worker processes are disabled on Windows (os.name == "nt").
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        "pin_memory": True,
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(datadir, dataset_type="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **loader_kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(datadir, dataset_type="val"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **loader_kwargs,
    )
    return train_loader, val_loader


def _build_model():
    """Build the backbone named by ``M.backbone`` and wrap it for training."""
    backbone = M.backbone
    num_classes = sum(sum(M.head_size, []))

    if backbone == "stacked_hourglass":
        print(f"backbone == stacked_hourglass")
        model = lcnn.models.hg(
            depth=M.depth,
            head=MultitaskHead,
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=num_classes,
        )
        print(f"model.shape:{model}")
    elif backbone == "unet":
        print(f"backbone == unet")
        model = lcnn.models.unet(
            num_classes=num_classes,
            num_stacks=M.num_stacks,
            base_channels=UNET_BASE_CHANNELS,
        )
    elif backbone in _SIMPLE_BACKBONES:
        print(f"backbone == {backbone}")
        # Looked up lazily so that only the selected constructor must exist.
        model = getattr(lcnn.models, backbone)(num_classes=num_classes)
    else:
        raise NotImplementedError(f"Unsupported backbone: {backbone!r}")

    # Every backbone gets the same two wrappers.
    return LineVectorizer(MultitaskLearner(model))


def _load_pretrained(model, checkpoint_path, device):
    """Best-effort load of pre-trained weights into *model*; never raises."""
    if not checkpoint_path:
        print("No pre-trained model configured; skipping weight loading.")
        return

    try:
        # SECURITY: torch.load unpickles the file - only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Accept a full checkpoint, a {'state_dict': ...} wrapper, or a bare
        # state dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    except Exception as e:
        # Deliberately non-fatal: training continues with the current weights.
        print(f"Error loading model weights: {e}")
    else:
        print("Successfully loaded pre-trained model weights.")


def _build_optimizer(model):
    """Create the optimizer named by ``C.optim.name`` over trainable params."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    if C.optim.name == "Adam":
        return torch.optim.Adam(
            trainable,
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    if C.optim.name == "SGD":
        return torch.optim.SGD(
            trainable,
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    raise NotImplementedError(f"Unsupported optimizer: {C.optim.name!r}")


def main():
    """Configure, build and train the model described by the config below."""
    config = {
        # Dataset settings.
        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
        'config_file': 'config/wireframe.yaml',  # YAML configuration file
        # GPU settings.
        'devices': '0',  # value written to CUDA_VISIBLE_DEVICES
        'identifier': 'fasterrcnn_resnet50',  # suffix of the output directory
        # Path of a checkpoint to start from, or None to skip loading,
        # e.g. './190418-201834-f8934c6-lr4d10-312k.pth'.
        'pretrained_model': None,
        # Set to True to apply ``freeze_config`` (and print the model
        # structure before and after) prior to building the optimizer.
        'apply_freeze': False,
        # Fine-grained freezing flags: True freezes, False keeps trainable.
        'freeze_config': {
            'backbone': False,
            'fc1': False,
            'fc2': False,
            'fc2_submodules': {
                '0': False,  # first child of fc2
                '2': False,  # third child of fc2
                '4': False,  # fifth child of fc2
            },
            'pooling': False,
            'loss': False,
        },
    }

    # Load the YAML configuration into the global config objects.
    C.update(C.from_yaml(filename=config['config_file']))
    M.update(C.model)

    _seed_everything(SEED)
    device = _select_device(config['devices'])

    train_loader, val_loader = _build_loaders(config['datadir'])

    model = _build_model()
    _load_pretrained(model, config['pretrained_model'], device)

    # Freezing has to happen before the optimizer is built, because the
    # optimizer only receives parameters with requires_grad=True.
    if config['apply_freeze']:
        print_model_structure(model)
        freeze_params(model, freeze_config=config['freeze_config'])
        verify_freeze_params(model, config['freeze_config'])
        print("\n========= After Freezing Backbone =========")
        print_model_structure(model)

    model = model.to(device)
    optim = _build_optimizer(model)

    # Output directory: <logdir>/<timestamp>-<identifier>.
    outdir = osp.join(
        osp.expanduser(C.io.logdir),
        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
    )
    os.makedirs(outdir, exist_ok=True)

    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        print("Starting training...")
        trainer.train()
        print("Training completed.")
    except BaseException:
        # Remove the output directory when the run ended (including via
        # Ctrl-C) with at most one entry under viz/, then re-raise.
        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
            shutil.rmtree(outdir)
        raise


if __name__ == "__main__":
    main()