123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- #!/usr/bin/env python3
- import datetime
- import glob
- import os
- import os.path as osp
- import platform
- import pprint
- import random
- import shlex
- import shutil
- import subprocess
- import sys
- import numpy as np
- import torch
- import torchvision
- import yaml
- import lcnn
- from lcnn.config import C, M
- from lcnn.datasets import WireframeDataset, collate
- from lcnn.models.line_vectorizer import LineVectorizer
- from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
- from torchvision.models import resnet50
def print_model_structure(model):
    """Print parameter statistics for a model, its children and grandchildren.

    The report contains the model's type, the overall total / trainable /
    non-trainable parameter counts, and then total and trainable counts for
    every direct child module and for each of that child's own children.

    Args:
        model: Object exposing ``parameters()`` and ``named_children()``
            (the calls made here match the ``torch.nn.Module`` API).
    """

    def _count(mod):
        # One pass over the parameters yields both figures.
        total = trainable = 0
        for tensor in mod.parameters():
            size = tensor.numel()
            total += size
            if tensor.requires_grad:
                trainable += size
        return total, trainable

    print("\n========= Model Structure =========")
    # Overall model information.
    print("Model Type:", type(model))

    total_params, trainable_params = _count(model)
    frozen_params = total_params - trainable_params
    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Non-trainable Parameters: {frozen_params:,}")

    # Per-module breakdown, two levels deep.
    print("\n===== Detailed Model Components =====")
    for child_name, child in model.named_children():
        child_total, child_trainable = _count(child)
        print(f"\nmodel.named:{child_name}:")
        print(f" Total Parameters: {child_total:,}")
        print(f" Trainable Parameters: {child_trainable:,}")

        for grandchild_name, grandchild in child.named_children():
            grand_total, grand_trainable = _count(grandchild)
            print(f" {grandchild_name}:")
            print(f" Total Parameters: {grand_total:,}")
            print(f" Trainable Parameters: {grand_trainable:,}")
def verify_freeze_params(model, freeze_config):
    """Print ``requires_grad`` for every parameter the config marks as frozen.

    For each direct child of ``model`` whose name maps to a truthy value in
    ``freeze_config``, all of its parameters are listed. For the child named
    ``'fc2'`` the same is done for each of its sub-modules whose name maps to
    a truthy value in ``freeze_config['fc2_submodules']`` (when that key exists).

    Args:
        model: Object exposing ``named_children()``; children must expose
            ``named_parameters()`` and ``named_children()``.
        freeze_config: Mapping of module name to freeze flag, optionally with
            a nested ``'fc2_submodules'`` mapping.
    """

    def _report(header, module):
        # Shared by the module-level and the fc2 sub-module listing.
        print(header)
        for param_name, param in module.named_parameters():
            print(f" {param_name}: requires_grad = {param.requires_grad}")

    print("\n===== Verifying Parameter Freezing =====")
    for name, module in model.named_children():
        if name in freeze_config and freeze_config[name]:
            _report(f"\nChecking module: {name}", module)

        # Only fc2 has per-submodule flags.
        if name != 'fc2' or 'fc2_submodules' not in freeze_config:
            continue
        submodule_flags = freeze_config['fc2_submodules']
        for subname, submodule in module.named_children():
            if subname in submodule_flags and submodule_flags[subname]:
                _report(f"\nChecking fc2 submodule: {subname}", submodule)
def freeze_params(model, freeze_config=None):
    """Freeze or unfreeze parameters of a model's top-level modules.

    Every direct child of ``model`` whose name appears in the merged
    configuration gets ``requires_grad`` set on all of its parameters
    (``True`` in the config means *frozen*). For the child named ``'fc2'``,
    per-submodule flags from ``'fc2_submodules'`` are applied afterwards.

    A sub-module flag that the caller did not supply (or supplied as
    ``None``) inherits the module-level ``'fc2'`` flag. Previously the
    built-in ``False`` defaults always won, so ``{'fc2': True}`` silently
    left ``fc2.0`` / ``fc2.2`` / ``fc2.4`` trainable.

    Args:
        model: Object exposing ``named_children()`` and ``parameters()``.
        freeze_config: Optional mapping overriding the defaults. Keys are
            child-module names mapped to a freeze flag, plus an optional
            ``'fc2_submodules'`` mapping of fc2 child name to flag.

    Returns:
        None. The model is modified in place and a summary is printed.
    """
    # Defaults: nothing frozen. ``None`` for an fc2 sub-module means
    # "follow whatever the 'fc2' flag says".
    config = {
        'backbone': False,
        'fc1': False,
        'fc2': False,
        'fc2_submodules': {
            '0': None,  # first child of fc2
            '2': None,  # third child of fc2
            '4': None,  # fifth child of fc2
        },
        'pooling': False,
        'loss': False,
    }

    # Merge caller overrides. Nested dicts are merged only into an existing
    # nested dict; anything else is assigned, so an unexpected key or value
    # type no longer raises KeyError/AttributeError during the merge.
    if freeze_config is not None:
        for key, value in freeze_config.items():
            if isinstance(value, dict) and isinstance(config.get(key), dict):
                config[key].update(value)
            else:
                config[key] = value

    submodule_flags = config.get('fc2_submodules')

    print("\n===== Parameter Freezing Configuration =====")
    for name, module in model.named_children():
        # Module-level freezing.
        if name in config:
            for param in module.parameters():
                param.requires_grad = not config[name]
            if not config[name]:
                print(f"Module {name} is trainable")
            else:
                print(f"Freezing module: {name}")

        # Finer-grained control over the children of fc2.
        if name == 'fc2' and isinstance(submodule_flags, dict):
            fc2_frozen = config.get('fc2', False)
            for subname, submodule in module.named_children():
                if subname not in submodule_flags:
                    continue
                flag = submodule_flags[subname]
                frozen = fc2_frozen if flag is None else flag
                for param in submodule.parameters():
                    param.requires_grad = not frozen
                if not frozen:
                    print(f"Submodule fc2.{subname} is trainable")
                else:
                    print(f"Freezing submodule: fc2.{subname}")

    # Summary after freezing.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Frozen Parameters: {total_params - trainable_params:,}")
def get_model(num_classes):
    """Build a torchvision Faster R-CNN (ResNet-50 FPN) with a new box predictor.

    The detector is created with ``pretrained=True`` and its
    ``roi_heads.box_predictor`` is replaced by a ``FastRCNNPredictor`` sized
    for ``num_classes`` outputs.

    Args:
        num_classes: Number of classes passed to the replacement predictor.

    Returns:
        The detector with the replaced box predictor.
    """
    detection = torchvision.models.detection
    detector = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # The replacement head must accept the same feature width as the old one.
    feature_width = detector.roi_heads.box_predictor.cls_score.in_features
    new_head = detection.faster_rcnn.FastRCNNPredictor(feature_width, num_classes)
    detector.roi_heads.box_predictor = new_head
    return detector
def _build_model():
    """Instantiate the network selected by ``M.backbone`` and wrap it.

    Returns:
        The backbone wrapped in ``MultitaskLearner`` and then ``LineVectorizer``.

    Raises:
        NotImplementedError: If ``M.backbone`` is not a name handled here.
    """
    backbone = M.backbone
    # These constructors share their name with the backbone key and take
    # only ``num_classes``. They are looked up lazily so that a missing,
    # unselected constructor cannot break the selected one.
    simple_backbones = ("resnet50", "resnet501", "fasterrcnn_resnet50")
    if backbone not in ("stacked_hourglass", "unet") + simple_backbones:
        raise NotImplementedError(f"Unsupported backbone: {backbone!r}")

    print(f"backbone == {backbone}")
    num_classes = sum(sum(M.head_size, []))
    if backbone == "stacked_hourglass":
        net = lcnn.models.hg(
            depth=M.depth,
            head=MultitaskHead,
            num_stacks=M.num_stacks,
            num_blocks=M.num_blocks,
            num_classes=num_classes,
        )
        print(f"model.shape:{net}")
    elif backbone == "unet":
        net = lcnn.models.unet(
            num_classes=num_classes,
            num_stacks=M.num_stacks,
            # The original looked this up in the DataLoader kwargs, which
            # never contain it, so the value was always 64.
            base_channels=64,
        )
    else:
        net = getattr(lcnn.models, backbone)(num_classes=num_classes)

    return LineVectorizer(MultitaskLearner(net))


def _load_pretrained(model, checkpoint_path, device):
    """Best-effort load of pre-trained weights into ``model``.

    Args:
        model: Module receiving the weights via ``load_state_dict``.
        checkpoint_path: Path of the checkpoint file, or a falsy value when
            no checkpoint is configured.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        True if weights were loaded, False otherwise. Failures are reported
        on stdout and never raised, so training continues with the model's
        current weights.
    """
    if not checkpoint_path:
        print("No pre-trained model configured; skipping weight loading.")
        return False

    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Pick the state dict according to the checkpoint layout.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        elif 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # The file is itself the state dict.
            model.load_state_dict(checkpoint)
    except Exception as e:  # deliberate best-effort: report and continue
        print(f"Error loading model weights: {e}")
        return False

    print("Successfully loaded pre-trained model weights.")
    return True


def main():
    """Configure and run a training session with ``lcnn.trainer.Trainer``.

    Steps: load the YAML configuration into ``C`` / ``M``, seed the random
    number generators, pick CUDA or CPU, build the train/val data loaders,
    build the model for ``M.backbone``, optionally load pre-trained weights,
    create the optimizer named by ``C.optim.name`` and start training.

    Raises:
        NotImplementedError: For an unsupported backbone or optimizer name.
    """
    # Run settings.
    config = {
        # Dataset
        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
        'config_file': 'config/wireframe.yaml',  # YAML configuration file
        # GPU
        'devices': '0',  # value exported as CUDA_VISIBLE_DEVICES
        'identifier': 'fasterrcnn_resnet50',  # suffix of the output directory name
        # Pre-trained checkpoint (uncomment to enable loading)
        # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth',
        # Freezing flags for freeze_params (True means frozen). Only used
        # by the optional freezing step that is commented out below.
        'freeze_config': {
            'backbone': False,
            'fc1': False,
            'fc2': False,
            'fc2_submodules': {
                '0': False,  # first child of fc2
                '2': False,  # third child of fc2
                '4': False,  # fifth child of fc2
            },
            'pooling': False,
            'loss': False,
        },
    }

    # Load the configuration.
    C.update(C.from_yaml(filename=config['config_file']))
    M.update(C.model)

    # Seed the random number generators.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Device selection.
    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    # Data loading. Worker processes are disabled on Windows (os.name == "nt").
    kwargs = {
        "collate_fn": collate,
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        "pin_memory": True,
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="val"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **kwargs,
    )

    # Build the model.
    model = _build_model()

    # Load pre-trained weights. ``.get`` is used because the key is optional;
    # indexing it directly raised a KeyError that was reported as a
    # misleading weight-loading error.
    _load_pretrained(model, config.get('pretrained_model'), device)

    # Optional: inspect the model and freeze parameters.
    # print_model_structure(model)
    # freeze_params(model, freeze_config=config['freeze_config'])
    # verify_freeze_params(model, config['freeze_config'])
    # print("\n========= After Freezing Backbone =========")
    # print_model_structure(model)

    # Move to the selected device.
    model = model.to(device)

    # Optimizer over the parameters that still require gradients.
    if C.optim.name == "Adam":
        optim = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    elif C.optim.name == "SGD":
        optim = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    else:
        raise NotImplementedError(f"Unsupported optimizer: {C.optim.name!r}")

    # Output directory: <logdir>/<timestamp>-<identifier>.
    outdir = osp.join(
        osp.expanduser(C.io.logdir),
        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
    )
    os.makedirs(outdir, exist_ok=True)

    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        print("Starting training...")
        trainer.train()
        print("Training completed.")
    except BaseException:
        # On any abort (including Ctrl-C), remove the output directory when
        # it holds at most one entry under viz/, then re-raise.
        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
            shutil.rmtree(outdir)
        raise


if __name__ == "__main__":
    main()
|