| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
#!/usr/bin/env python3
"""Train L-CNN

Usage:
    train.py [options] <yaml-config>
    train.py (-h | --help )

Arguments:
    <yaml-config>                   Path to the yaml hyper-parameter file

Options:
    -h --help                       Show this screen.
    -d --devices <devices>          Comma separated GPU devices [default: 0]
    -i --identifier <identifier>    Folder identifier [default: default-identifier]
"""
- import datetime
- import glob
- import os
- import os.path as osp
- import platform
- import pprint
- import random
- import shlex
- import shutil
- import signal
- import subprocess
- import sys
- import threading
- import numpy as np
- import torch
- import yaml
- from docopt import docopt
- import lcnn
- from lcnn.config import C, M
- from lcnn.datasets import WireframeDataset, collate
- from lcnn.models.line_vectorizer import LineVectorizer
- from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
def git_hash():
    """Return the abbreviated hash of the current git ``HEAD`` commit.

    Runs ``git log -n 1 --pretty=%h`` in the current working directory and
    returns its output with surrounding whitespace removed, decoded to ``str``.
    ``subprocess.CalledProcessError`` propagates if git exits non-zero.
    """
    # Same argv that shlex.split('git log -n 1 --pretty="%h"') produces.
    argv = ["git", "log", "-n", "1", "--pretty=%h"]
    output = subprocess.check_output(argv).strip()
    return output.decode() if isinstance(output, bytes) else output
def get_outdir(identifier):
    """Create a new, uniquely named run directory under ``C.io.logdir``.

    The directory is named ``<yymmdd-HHMMSS>-<git hash>-<identifier>``.
    Side effects:

    * ``C.io.resume_from`` is set to the new directory;
    * the current config is written to ``<outdir>/config.yaml``;
    * the working-tree diff against ``HEAD`` is written to
      ``<outdir>/gitdiff.patch`` (best effort: a failing ``git diff`` is
      ignored, the file may then be empty).

    Args:
        identifier: Free-form suffix appended to the directory name.

    Returns:
        Path of the created directory.
    """
    timestamp = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    name = f"{timestamp}-{git_hash()}-{identifier}"
    outdir = osp.join(osp.expanduser(C.io.logdir), name)
    os.makedirs(outdir, exist_ok=True)
    # Record the run directory in the saved config: main() loads
    # checkpoint_latest.pth from C.io.resume_from, so feeding this
    # config.yaml back to the script resumes this run.
    C.io.resume_from = outdir
    C.to_yaml(osp.join(outdir, "config.yaml"))
    # Use an argv list instead of a shell string so spaces or shell
    # metacharacters in logdir/identifier cannot break (or hijack) the
    # command. The exit status is deliberately not checked, matching the
    # previous os.system() call: the patch is informational only.
    with open(osp.join(outdir, "gitdiff.patch"), "wb") as patch:
        subprocess.run(["git", "diff", "HEAD"], stdout=patch)
    return outdir
def _seed_and_select_device(devices):
    """Restrict CUDA to ``devices``, seed all RNGs and return the torch device to use."""
    # CUDA reads this variable only when it initialises, so set it before any
    # torch call that could touch CUDA (including the seeding below).
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    # WARNING: L-CNN is still not deterministic
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return torch.device("cpu")
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(0)
    print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    return torch.device("cuda")


def _build_loaders(datadir, pin_memory):
    """Return the ``(train, valid)`` DataLoaders over the WireframeDataset in ``datadir``."""
    kwargs = {
        "collate_fn": collate,
        # NOTE(review): worker processes are disabled on Windows, presumably
        # because of the spawn start method there -- confirm.
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        "pin_memory": pin_memory,
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(datadir, split="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(datadir, split="valid"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **kwargs,
    )
    return train_loader, val_loader


def _build_model(checkpoint=None):
    """Build the network described by ``M``; restore its weights from ``checkpoint`` if given."""
    if M.backbone != "stacked_hourglass":
        raise NotImplementedError(f"unsupported backbone: {M.backbone!r}")
    model = lcnn.models.hg(
        depth=M.depth,
        head=MultitaskHead,
        num_stacks=M.num_stacks,
        num_blocks=M.num_blocks,
        # M.head_size is a list of lists: flatten it, then total the sizes.
        num_classes=sum(sum(M.head_size, [])),
    )
    model = LineVectorizer(MultitaskLearner(model))
    if checkpoint is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
    return model


def _build_optimizer(model, checkpoint=None):
    """Create the optimizer named by ``C.optim.name``; restore its state from ``checkpoint`` if given."""
    if C.optim.name == "Adam":
        optim = torch.optim.Adam(
            model.parameters(),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    elif C.optim.name == "SGD":
        optim = torch.optim.SGD(
            model.parameters(),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    else:
        raise NotImplementedError(f"unsupported optimizer: {C.optim.name!r}")
    if checkpoint is not None:
        optim.load_state_dict(checkpoint["optim_state_dict"])
    return optim


def main():
    """Script entry point: parse the CLI, build data/model/optimizer and train.

    If the loaded config sets ``C.io.resume_from``, training resumes from
    ``<resume_from>/checkpoint_latest.pth`` and writes into that directory;
    otherwise a fresh run directory is created with ``get_outdir``.
    """
    args = docopt(__doc__)
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)
    pprint.pprint(C, indent=4)
    resume_from = C.io.resume_from

    device = _seed_and_select_device(args["--devices"])

    # 1. dataset -- pinned host memory only benefits host-to-GPU copies.
    train_loader, val_loader = _build_loaders(
        C.io.datadir, pin_memory=device.type == "cuda"
    )
    epoch_size = len(train_loader)

    checkpoint = None
    if resume_from:
        # map_location lets a checkpoint saved on a GPU host resume on CPU.
        checkpoint = torch.load(
            osp.join(resume_from, "checkpoint_latest.pth"), map_location=device
        )

    # 2. model
    model = _build_model(checkpoint).to(device)

    # 3. optimizer (created after .to(device) so it references the moved parameters)
    optim = _build_optimizer(model, checkpoint)

    outdir = resume_from or get_outdir(args["--identifier"])
    print("outdir:", outdir)
    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        if checkpoint is not None:
            trainer.iteration = checkpoint["iteration"]
            if trainer.iteration % epoch_size != 0:
                print("WARNING: iteration is not a multiple of epoch_size, reset it")
                trainer.iteration -= trainer.iteration % epoch_size
            trainer.best_mean_loss = checkpoint["best_mean_loss"]
            # Drop the reference so the loaded tensors can be freed before training.
            checkpoint = None
        trainer.train()
    except BaseException:
        # Discard a *newly created* run directory whose viz/ folder has at
        # most one entry. Never delete a directory being resumed: it holds the
        # checkpoints this run was restarted from.
        viz_entries = glob.glob(osp.join(glob.escape(outdir), "viz", "*"))
        if not resume_from and len(viz_entries) <= 1:
            # ignore_errors so a cleanup failure cannot mask the real exception.
            shutil.rmtree(outdir, ignore_errors=True)
        raise
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|