train.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #!/usr/bin/env python3
"""Train L-CNN
Usage:
    train.py [options] <yaml-config>
    train.py (-h | --help )

Arguments:
   <yaml-config>                   Path to the yaml hyper-parameter file

Options:
   -h --help                       Show this screen.
   -d --devices <devices>          Comma separated GPU devices [default: 0]
   -i --identifier <identifier>    Folder identifier [default: default-identifier]
"""
  13. import os
  14. import sys
  15. import glob
  16. import shlex
  17. import pprint
  18. import random
  19. import shutil
  20. import signal
  21. import os.path as osp
  22. import datetime
  23. import platform
  24. import threading
  25. import subprocess
  26. import yaml
  27. import numpy as np
  28. import torch
  29. from docopt import docopt
  30. import lcnn
  31. from lcnn.config import C, M
  32. from lcnn.datasets import WireframeDataset, collate
  33. from lcnn.models.line_vectorizer import LineVectorizer
  34. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  35. def git_hash():
  36. cmd = 'git log -n 1 --pretty="%h"'
  37. ret = subprocess.check_output(shlex.split(cmd)).strip()
  38. if isinstance(ret, bytes):
  39. ret = ret.decode()
  40. return ret
  41. def get_outdir(identifier):
  42. # load config
  43. name = str(datetime.datetime.now().strftime("%y%m%d-%H%M%S"))
  44. name += "-%s" % git_hash()
  45. name += "-%s" % identifier
  46. outdir = osp.join(osp.expanduser(C.io.logdir), name)
  47. if not osp.exists(outdir):
  48. os.makedirs(outdir)
  49. C.io.resume_from = outdir
  50. C.to_yaml(osp.join(outdir, "config.yaml"))
  51. os.system(f"git diff HEAD > {outdir}/gitdiff.patch")
  52. return outdir
  53. def main():
  54. args = docopt(__doc__)
  55. config_file = args["<yaml-config>"] or "config/wireframe.yaml"
  56. C.update(C.from_yaml(filename=config_file))
  57. M.update(C.model)
  58. pprint.pprint(C, indent=4)
  59. resume_from = C.io.resume_from
  60. # WARNING: L-CNN is still not deterministic
  61. random.seed(0)
  62. np.random.seed(0)
  63. torch.manual_seed(0)
  64. device_name = "cpu"
  65. os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
  66. if torch.cuda.is_available():
  67. device_name = "cuda"
  68. torch.backends.cudnn.deterministic = True
  69. torch.cuda.manual_seed(0)
  70. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  71. else:
  72. print("CUDA is not available")
  73. device = torch.device(device_name)
  74. # 1. dataset
  75. # uncomment for debug DataLoader
  76. # wireframe.datasets.WireframeDataset(datadir, split="train")[0]
  77. # sys.exit(0)
  78. datadir = C.io.datadir
  79. kwargs = {
  80. "batch_size": M.batch_size,
  81. "collate_fn": collate,
  82. "num_workers": C.io.num_workers,
  83. "pin_memory": True,
  84. }
  85. train_loader = torch.utils.data.DataLoader(
  86. WireframeDataset(datadir, split="train"), shuffle=True, **kwargs
  87. )
  88. val_loader = torch.utils.data.DataLoader(
  89. WireframeDataset(datadir, split="valid"), shuffle=False, **kwargs
  90. )
  91. epoch_size = len(train_loader)
  92. # print("epoch_size (train):", epoch_size)
  93. # print("epoch_size (valid):", len(val_loader))
  94. if resume_from:
  95. checkpoint = torch.load(osp.join(resume_from, "checkpoint_lastest.pth.tar"))
  96. # 2. model
  97. if M.backbone == "stacked_hourglass":
  98. model = lcnn.models.hg(
  99. depth=M.depth,
  100. head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
  101. num_stacks=M.num_stacks,
  102. num_blocks=M.num_blocks,
  103. num_classes=sum(sum(M.head_size, [])),
  104. )
  105. else:
  106. raise NotImplementedError
  107. model = MultitaskLearner(model)
  108. model = LineVectorizer(model)
  109. if resume_from:
  110. model.load_state_dict(checkpoint["model_state_dict"])
  111. model = model.to(device)
  112. # 3. optimizer
  113. if C.optim.name == "Adam":
  114. optim = torch.optim.Adam(
  115. model.parameters(),
  116. lr=C.optim.lr,
  117. weight_decay=C.optim.weight_decay,
  118. amsgrad=C.optim.amsgrad,
  119. )
  120. elif C.optim.name == "SGD":
  121. optim = torch.optim.SGD(
  122. model.parameters(),
  123. lr=C.optim.lr,
  124. weight_decay=C.optim.weight_decay,
  125. momentum=C.optim.momentum,
  126. )
  127. else:
  128. raise NotImplementedError
  129. if resume_from:
  130. optim.load_state_dict(checkpoint["optim_state_dict"])
  131. outdir = resume_from or get_outdir(args["--identifier"])
  132. print("outdir:", outdir)
  133. try:
  134. trainer = lcnn.trainer.Trainer(
  135. device=device,
  136. model=model,
  137. optimizer=optim,
  138. train_loader=train_loader,
  139. val_loader=val_loader,
  140. out=outdir,
  141. )
  142. if resume_from:
  143. trainer.iteration = checkpoint["iteration"]
  144. if trainer.iteration % epoch_size != 0:
  145. print("WARNING: iteration is not a multiple of epoch_size, reset it")
  146. trainer.iteration -= trainer.iteration % epoch_size
  147. trainer.best_mean_loss = checkpoint["best_mean_loss"]
  148. del checkpoint
  149. trainer.train()
  150. except BaseException:
  151. if len(glob.glob(f"{outdir}/viz/*")) <= 1:
  152. shutil.rmtree(outdir)
  153. raise
  154. if __name__ == "__main__":
  155. main()