trainer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. import atexit
  2. import os
  3. import os.path as osp
  4. import shutil
  5. import signal
  6. import subprocess
  7. import threading
  8. import time
  9. from timeit import default_timer as timer
  10. import matplotlib as mpl
  11. import matplotlib.pyplot as plt
  12. import numpy as np
  13. import torch
  14. import torch.nn.functional as F
  15. from skimage import io
  16. from tensorboardX import SummaryWriter
  17. from lcnn.config import C, M
  18. from lcnn.utils import recursive_to
  19. class Trainer(object):
  20. def __init__(self, device, model, optimizer, train_loader, val_loader, out):
  21. self.device = device
  22. self.model = model
  23. self.optim = optimizer
  24. self.train_loader = train_loader
  25. self.val_loader = val_loader
  26. self.batch_size = C.model.batch_size
  27. self.validation_interval = C.io.validation_interval
  28. self.out = out
  29. if not osp.exists(self.out):
  30. os.makedirs(self.out)
  31. self.run_tensorboard()
  32. time.sleep(1)
  33. self.epoch = 0
  34. self.iteration = 0
  35. self.max_epoch = C.optim.max_epoch
  36. self.lr_decay_epoch = C.optim.lr_decay_epoch
  37. self.num_stacks = C.model.num_stacks
  38. self.mean_loss = self.best_mean_loss = 1e1000
  39. self.loss_labels = None
  40. self.avg_metrics = None
  41. self.metrics = np.zeros(0)
  42. def run_tensorboard(self):
  43. board_out = osp.join(self.out, "tensorboard")
  44. if not osp.exists(board_out):
  45. os.makedirs(board_out)
  46. self.writer = SummaryWriter(board_out)
  47. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  48. p = subprocess.Popen(
  49. ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
  50. )
  51. def killme():
  52. os.kill(p.pid, signal.SIGTERM)
  53. atexit.register(killme)
  54. def _loss(self, result):
  55. losses = result["losses"]
  56. # Don't move loss label to other place.
  57. # If I want to change the loss, I just need to change this function.
  58. if self.loss_labels is None:
  59. self.loss_labels = ["sum"] + list(losses[0].keys())
  60. self.metrics = np.zeros([self.num_stacks, len(self.loss_labels)])
  61. print()
  62. print(
  63. "| ".join(
  64. ["progress "]
  65. + list(map("{:7}".format, self.loss_labels))
  66. + ["speed"]
  67. )
  68. )
  69. with open(f"{self.out}/loss.csv", "a") as fout:
  70. print(",".join(["progress"] + self.loss_labels), file=fout)
  71. total_loss = 0
  72. for i in range(self.num_stacks):
  73. for j, name in enumerate(self.loss_labels):
  74. if name == "sum":
  75. continue
  76. if name not in losses[i]:
  77. assert i != 0
  78. continue
  79. loss = losses[i][name].mean()
  80. self.metrics[i, 0] += loss.item()
  81. self.metrics[i, j] += loss.item()
  82. total_loss += loss
  83. return total_loss
  84. def validate(self):
  85. tprint("Running validation...", " " * 75)
  86. training = self.model.training
  87. self.model.eval()
  88. viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
  89. npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
  90. osp.exists(viz) or os.makedirs(viz)
  91. osp.exists(npz) or os.makedirs(npz)
  92. total_loss = 0
  93. self.metrics[...] = 0
  94. with torch.no_grad():
  95. for batch_idx, (image, meta, target) in enumerate(self.val_loader):
  96. input_dict = {
  97. "image": recursive_to(image, self.device),
  98. "meta": recursive_to(meta, self.device),
  99. "target": recursive_to(target, self.device),
  100. "mode": "validation",
  101. }
  102. result = self.model(input_dict)
  103. total_loss += self._loss(result)
  104. H = result["preds"]
  105. for i in range(H["jmap"].shape[0]):
  106. index = batch_idx * M.batch_size_eval + i
  107. np.savez(
  108. f"{npz}/{index:06}.npz",
  109. **{k: v[i].cpu().numpy() for k, v in H.items()},
  110. )
  111. if index >= 20:
  112. continue
  113. self._plot_samples(i, index, H, meta, target, f"{viz}/{index:06}")
  114. self._write_metrics(len(self.val_loader), total_loss, "validation", True)
  115. self.mean_loss = total_loss / len(self.val_loader)
  116. torch.save(
  117. {
  118. "iteration": self.iteration,
  119. "arch": self.model.__class__.__name__,
  120. "optim_state_dict": self.optim.state_dict(),
  121. "model_state_dict": self.model.state_dict(),
  122. "best_mean_loss": self.best_mean_loss,
  123. },
  124. osp.join(self.out, "checkpoint_latest.pth.tar"),
  125. )
  126. shutil.copy(
  127. osp.join(self.out, "checkpoint_latest.pth.tar"),
  128. osp.join(npz, "checkpoint.pth.tar"),
  129. )
  130. if self.mean_loss < self.best_mean_loss:
  131. self.best_mean_loss = self.mean_loss
  132. shutil.copy(
  133. osp.join(self.out, "checkpoint_latest.pth.tar"),
  134. osp.join(self.out, "checkpoint_best.pth.tar"),
  135. )
  136. if training:
  137. self.model.train()
  138. def train_epoch(self):
  139. self.model.train()
  140. time = timer()
  141. for batch_idx, (image, meta, target) in enumerate(self.train_loader):
  142. self.optim.zero_grad()
  143. self.metrics[...] = 0
  144. input_dict = {
  145. "image": recursive_to(image, self.device),
  146. "meta": recursive_to(meta, self.device),
  147. "target": recursive_to(target, self.device),
  148. "mode": "training",
  149. }
  150. result = self.model(input_dict)
  151. loss = self._loss(result)
  152. if np.isnan(loss.item()):
  153. raise ValueError("loss is nan while training")
  154. loss.backward()
  155. self.optim.step()
  156. if self.avg_metrics is None:
  157. self.avg_metrics = self.metrics
  158. else:
  159. self.avg_metrics = self.avg_metrics * 0.9 + self.metrics * 0.1
  160. self.iteration += 1
  161. self._write_metrics(1, loss.item(), "training", do_print=False)
  162. if self.iteration % 4 == 0:
  163. tprint(
  164. f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
  165. + "| ".join(map("{:.5f}".format, self.avg_metrics[0]))
  166. + f"| {4 * self.batch_size / (timer() - time):04.1f} "
  167. )
  168. time = timer()
  169. num_images = self.batch_size * self.iteration
  170. if num_images % self.validation_interval == 0 or num_images == 600:
  171. self.validate()
  172. time = timer()
  173. def _write_metrics(self, size, total_loss, prefix, do_print=False):
  174. for i, metrics in enumerate(self.metrics):
  175. for label, metric in zip(self.loss_labels, metrics):
  176. self.writer.add_scalar(
  177. f"{prefix}/{i}/{label}", metric / size, self.iteration
  178. )
  179. if i == 0 and do_print:
  180. csv_str = (
  181. f"{self.epoch:03}/{self.iteration * self.batch_size:07},"
  182. + ",".join(map("{:.11f}".format, metrics / size))
  183. )
  184. prt_str = (
  185. f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
  186. + "| ".join(map("{:.5f}".format, metrics / size))
  187. )
  188. with open(f"{self.out}/loss.csv", "a") as fout:
  189. print(csv_str, file=fout)
  190. pprint(prt_str, " " * 7)
  191. self.writer.add_scalar(
  192. f"{prefix}/total_loss", total_loss / size, self.iteration
  193. )
  194. return total_loss
  195. def _plot_samples(self, i, index, result, meta, target, prefix):
  196. fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
  197. img = io.imread(fn)
  198. imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
  199. mask_result = result["jmap"][i].cpu().numpy()
  200. mask_target = target["jmap"][i].cpu().numpy()
  201. for ch, (ia, ib) in enumerate(zip(mask_target, mask_result)):
  202. imshow(ia), plt.savefig(f"{prefix}_mask_{ch}a.jpg"), plt.close()
  203. imshow(ib), plt.savefig(f"{prefix}_mask_{ch}b.jpg"), plt.close()
  204. line_result = result["lmap"][i].cpu().numpy()
  205. line_target = target["lmap"][i].cpu().numpy()
  206. imshow(line_target), plt.savefig(f"{prefix}_line_a.jpg"), plt.close()
  207. imshow(line_result), plt.savefig(f"{prefix}_line_b.jpg"), plt.close()
  208. def draw_vecl(lines, sline, juncs, junts, fn):
  209. imshow(img)
  210. if len(lines) > 0 and not (lines[0] == 0).all():
  211. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  212. if i > 0 and (lines[i] == lines[0]).all():
  213. break
  214. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
  215. if not (juncs[0] == 0).all():
  216. for i, j in enumerate(juncs):
  217. if i > 0 and (i == juncs[0]).all():
  218. break
  219. plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
  220. if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
  221. for i, j in enumerate(junts):
  222. if i > 0 and (i == junts[0]).all():
  223. break
  224. plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
  225. plt.savefig(fn), plt.close()
  226. junc = meta[i]["junc"].cpu().numpy() * 4
  227. jtyp = meta[i]["jtyp"].cpu().numpy()
  228. juncs = junc[jtyp == 0]
  229. junts = junc[jtyp == 1]
  230. rjuncs = result["juncs"][i].cpu().numpy() * 4
  231. rjunts = None
  232. if "junts" in result:
  233. rjunts = result["junts"][i].cpu().numpy() * 4
  234. lpre = meta[i]["lpre"].cpu().numpy() * 4
  235. vecl_target = meta[i]["lpre_label"].cpu().numpy()
  236. vecl_result = result["lines"][i].cpu().numpy() * 4
  237. score = result["score"][i].cpu().numpy()
  238. lpre = lpre[vecl_target == 1]
  239. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
  240. draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
  241. def train(self):
  242. plt.rcParams["figure.figsize"] = (24, 24)
  243. # if self.iteration == 0:
  244. # self.validate()
  245. epoch_size = len(self.train_loader)
  246. start_epoch = self.iteration // epoch_size
  247. for self.epoch in range(start_epoch, self.max_epoch):
  248. if self.epoch == self.lr_decay_epoch:
  249. self.optim.param_groups[0]["lr"] /= 10
  250. self.train_epoch()
  251. cmap = plt.get_cmap("jet")
  252. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  253. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  254. sm.set_array([])
  255. def c(x):
  256. return sm.to_rgba(x)
  257. def imshow(im):
  258. plt.close()
  259. plt.tight_layout()
  260. plt.imshow(im)
  261. plt.colorbar(sm, fraction=0.046)
  262. plt.xlim([0, im.shape[0]])
  263. plt.ylim([im.shape[0], 0])
  264. def tprint(*args):
  265. """Temporarily prints things on the screen"""
  266. print("\r", end="")
  267. print(*args, end="")
  268. def pprint(*args):
  269. """Permanently prints things on the screen"""
  270. print("\r", end="")
  271. print(*args)
  272. def _launch_tensorboard(board_out, port, out):
  273. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  274. p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
  275. def kill():
  276. os.kill(p.pid, signal.SIGTERM)
  277. atexit.register(kill)