train.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #!/usr/bin/env python3
  2. """Train L-CNN
  3. Usage:
  4. train.py [options] <yaml-config>
  5. train.py (-h | --help )
  6. Arguments:
  7. <yaml-config> Path to the yaml hyper-parameter file
  8. Options:
  9. -h --help Show this screen.
  10. -d --devices <devices> Comma separated GPU devices [default: 0]
  11. -i --identifier <identifier> Folder identifier [default: default-identifier]
  12. """
  13. import datetime
  14. import glob
  15. import os
  16. import os.path as osp
  17. import platform
  18. import pprint
  19. import random
  20. import shlex
  21. import shutil
  22. import signal
  23. import subprocess
  24. import sys
  25. import threading
  26. import numpy as np
  27. import torch
  28. import yaml
  29. from docopt import docopt
  30. import lcnn
  31. from lcnn.config import C, M
  32. from lcnn.datasets import WireframeDataset, collate
  33. from lcnn.models.line_vectorizer import LineVectorizer
  34. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  35. def git_hash():
  36. cmd = 'git log -n 1 --pretty="%h"'
  37. ret = subprocess.check_output(shlex.split(cmd)).strip()
  38. if isinstance(ret, bytes):
  39. ret = ret.decode()
  40. return ret
  41. def get_outdir(identifier):
  42. # load config
  43. name = str(datetime.datetime.now().strftime("%y%m%d-%H%M%S"))
  44. name += "-%s" % git_hash()
  45. name += "-%s" % identifier
  46. outdir = osp.join(osp.expanduser(C.io.logdir), name)
  47. if not osp.exists(outdir):
  48. os.makedirs(outdir)
  49. C.io.resume_from = outdir
  50. C.to_yaml(osp.join(outdir, "config.yaml"))
  51. os.system(f"git diff HEAD > {outdir}/gitdiff.patch")
  52. return outdir
  53. def main():
  54. args = docopt(__doc__)
  55. config_file = args["<yaml-config>"] or "config/wireframe.yaml"
  56. C.update(C.from_yaml(filename=config_file))
  57. M.update(C.model)
  58. pprint.pprint(C, indent=4)
  59. resume_from = C.io.resume_from
  60. # WARNING: L-CNN is still not deterministic
  61. random.seed(0)
  62. np.random.seed(0)
  63. torch.manual_seed(0)
  64. device_name = "cpu"
  65. os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
  66. if torch.cuda.is_available():
  67. device_name = "cuda"
  68. torch.backends.cudnn.deterministic = True
  69. torch.cuda.manual_seed(0)
  70. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  71. else:
  72. print("CUDA is not available")
  73. device = torch.device(device_name)
  74. # 1. dataset
  75. # uncomment for debug DataLoader
  76. # wireframe.datasets.WireframeDataset(datadir, split="train")[0]
  77. # sys.exit(0)
  78. datadir = C.io.datadir
  79. kwargs = {
  80. "collate_fn": collate,
  81. "num_workers": C.io.num_workers if os.name != "nt" else 0,
  82. "pin_memory": True,
  83. }
  84. train_loader = torch.utils.data.DataLoader(
  85. WireframeDataset(datadir, split="train"),
  86. shuffle=True,
  87. batch_size=M.batch_size,
  88. **kwargs,
  89. )
  90. val_loader = torch.utils.data.DataLoader(
  91. WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
  92. )
  93. epoch_size = len(train_loader)
  94. # print("epoch_size (train):", epoch_size)
  95. # print("epoch_size (valid):", len(val_loader))
  96. if resume_from:
  97. checkpoint = torch.load(osp.join(resume_from, "checkpoint_latest.pth.tar"))
  98. # 2. model
  99. if M.backbone == "stacked_hourglass":
  100. model = lcnn.models.hg(
  101. depth=M.depth,
  102. head=MultitaskHead,
  103. num_stacks=M.num_stacks,
  104. num_blocks=M.num_blocks,
  105. num_classes=sum(sum(M.head_size, [])),
  106. )
  107. else:
  108. raise NotImplementedError
  109. model = MultitaskLearner(model)
  110. model = LineVectorizer(model)
  111. if resume_from:
  112. model.load_state_dict(checkpoint["model_state_dict"])
  113. model = model.to(device)
  114. # 3. optimizer
  115. if C.optim.name == "Adam":
  116. optim = torch.optim.Adam(
  117. model.parameters(),
  118. lr=C.optim.lr,
  119. weight_decay=C.optim.weight_decay,
  120. amsgrad=C.optim.amsgrad,
  121. )
  122. elif C.optim.name == "SGD":
  123. optim = torch.optim.SGD(
  124. model.parameters(),
  125. lr=C.optim.lr,
  126. weight_decay=C.optim.weight_decay,
  127. momentum=C.optim.momentum,
  128. )
  129. else:
  130. raise NotImplementedError
  131. if resume_from:
  132. optim.load_state_dict(checkpoint["optim_state_dict"])
  133. outdir = resume_from or get_outdir(args["--identifier"])
  134. print("outdir:", outdir)
  135. try:
  136. trainer = lcnn.trainer.Trainer(
  137. device=device,
  138. model=model,
  139. optimizer=optim,
  140. train_loader=train_loader,
  141. val_loader=val_loader,
  142. out=outdir,
  143. )
  144. if resume_from:
  145. trainer.iteration = checkpoint["iteration"]
  146. if trainer.iteration % epoch_size != 0:
  147. print("WARNING: iteration is not a multiple of epoch_size, reset it")
  148. trainer.iteration -= trainer.iteration % epoch_size
  149. trainer.best_mean_loss = checkpoint["best_mean_loss"]
  150. del checkpoint
  151. trainer.train()
  152. except BaseException:
  153. if len(glob.glob(f"{outdir}/viz/*")) <= 1:
  154. shutil.rmtree(outdir)
  155. raise
  156. if __name__ == "__main__":
  157. main()