import datetime
import glob
import os
import os.path as osp
import random
import shutil

import numpy as np
import torch

import lcnn
from lcnn.config import C, M
from lcnn.datasets import WireframeDataset, collate

from models.line_detect.line_rcnn import linercnn_resnet50_fpn


def main():
    # Paths and run settings for this training run.
    config = {
        'datadir': r'D:\python\PycharmProjects\data',
        'config_file': 'config/wireframe.yaml',
        'devices': '0',
        'identifier': 'fasterrcnn_resnet50',
        # Optional checkpoint to initialize from; None trains from scratch.
        'pretrained_model': None,
    }

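    # Read the YAML experiment configuration; C holds the full config tree and M its model section.
    # The script relies on at least: io.logdir, io.num_workers, model.batch_size,
    # model.batch_size_eval, optim.name, optim.lr, optim.weight_decay, and
    # optim.amsgrad (Adam) or optim.momentum (SGD).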
    C.update(C.from_yaml(filename=config['config_file']))
    M.update(C.model)

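    # Seed Python, NumPy, and PyTorch RNGs so runs are reproducible.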
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

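    # Select the device: CUDA (with deterministic cuDNN) when available, otherwise CPU.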
    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

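    # Build the training and validation loaders; worker processes are disabled on Windows (os.name == "nt").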
    kwargs = {
        "collate_fn": collate,
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        "pin_memory": True,
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="val"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **kwargs,
    )
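    # Instantiate the LineRCNN model (ResNet-50 + FPN backbone) and move it to the selected device.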
    model = linercnn_resnet50_fpn().to(device)

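    # Optionally restore pre-trained weights; checkpoints may store the state dict under different keys.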
    if config.get('pretrained_model'):
        try:
            checkpoint = torch.load(config['pretrained_model'], map_location=device)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            elif 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
            print("Successfully loaded pre-trained model weights.")
        except Exception as e:
            print(f"Error loading model weights: {e}")

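    # Create the optimizer named in the config (Adam or SGD), over trainable parameters only.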
    if C.optim.name == "Adam":
        optim = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    elif C.optim.name == "SGD":
        optim = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    else:
        raise NotImplementedError(f"Unsupported optimizer: {C.optim.name}")

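    # Create a timestamped output directory for logs, checkpoints, and visualizations.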
    outdir = osp.join(
        osp.expanduser(C.io.logdir),
        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}",
    )
    os.makedirs(outdir, exist_ok=True)
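    # Run training via lcnn's Trainer; if it aborts before producing any visualizations,
    # the (essentially empty) output directory is removed before re-raising.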
    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        print("Starting training...")
        trainer.train()
        print("Training completed.")
    except BaseException:
        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
            shutil.rmtree(outdir)
        raise


if __name__ == "__main__":
    main()