trainer.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. import atexit
  2. import os
  3. import os.path as osp
  4. import shutil
  5. import signal
  6. import subprocess
  7. import threading
  8. import time
  9. from timeit import default_timer as timer
  10. import matplotlib as mpl
  11. import matplotlib.pyplot as plt
  12. import numpy as np
  13. import torch
  14. import torch.nn.functional as F
  15. from skimage import io
  16. from tensorboardX import SummaryWriter
  17. from lcnn.config import C, M
  18. from lcnn.utils import recursive_to
  19. import matplotlib
  20. from 冻结参数训练 import verify_freeze_params
  21. import os
  22. from torchvision.utils import draw_bounding_boxes
  23. from torchvision import transforms
  24. from .postprocess import postprocess
  25. os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
  26. # matplotlib.use('Agg') # 使用无窗口后端
  27. # 绘图
  28. def show_line(img, pred, epoch, writer):
  29. im = img.permute(1, 2, 0)
  30. writer.add_image("ori", im, epoch, dataformats="HWC")
  31. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["box"][0]["boxes"],
  32. colors="yellow", width=1)
  33. writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  34. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  35. H = pred["preds"]
  36. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  37. scores = H["score"][0].cpu().numpy()
  38. for i in range(1, len(lines)):
  39. if (lines[i] == lines[0]).all():
  40. lines = lines[:i]
  41. scores = scores[:i]
  42. break
  43. # postprocess lines to remove overlapped lines
  44. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  45. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  46. for i, t in enumerate([0.7]):
  47. plt.gca().set_axis_off()
  48. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  49. plt.margins(0, 0)
  50. for (a, b), s in zip(nlines, nscores):
  51. if s < t:
  52. continue
  53. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  54. plt.scatter(a[1], a[0], **PLTOPTS)
  55. plt.scatter(b[1], b[0], **PLTOPTS)
  56. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  57. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  58. plt.imshow(im)
  59. plt.tight_layout()
  60. fig = plt.gcf()
  61. fig.canvas.draw()
  62. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  63. fig.canvas.get_width_height()[::-1] + (3,))
  64. plt.close()
  65. img2 = transforms.ToTensor()(image_from_plot)
  66. writer.add_image("output", img2, epoch)
  67. class Trainer(object):
  68. def __init__(self, device, model, optimizer, train_loader, val_loader, out):
  69. self.device = device
  70. self.model = model
  71. self.optim = optimizer
  72. self.train_loader = train_loader
  73. self.val_loader = val_loader
  74. self.batch_size = C.model.batch_size
  75. self.validation_interval = C.io.validation_interval
  76. self.out = out
  77. if not osp.exists(self.out):
  78. os.makedirs(self.out)
  79. # self.run_tensorboard()
  80. self.writer = SummaryWriter('logs/')
  81. time.sleep(1)
  82. self.epoch = 0
  83. self.iteration = 0
  84. self.max_epoch = C.optim.max_epoch
  85. self.lr_decay_epoch = C.optim.lr_decay_epoch
  86. self.num_stacks = C.model.num_stacks
  87. self.mean_loss = self.best_mean_loss = 1e1000
  88. self.loss_labels = None
  89. self.avg_metrics = None
  90. self.metrics = np.zeros(0)
  91. self.show_line = show_line
  92. # def run_tensorboard(self):
  93. # board_out = osp.join(self.out, "tensorboard")
  94. # if not osp.exists(board_out):
  95. # os.makedirs(board_out)
  96. # self.writer = SummaryWriter(board_out)
  97. # os.environ["CUDA_VISIBLE_DEVICES"] = ""
  98. # p = subprocess.Popen(
  99. # ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
  100. # )
  101. #
  102. # def killme():
  103. # os.kill(p.pid, signal.SIGTERM)
  104. #
  105. # atexit.register(killme)
  106. def _loss(self, result):
  107. losses = result["losses"]
  108. # Don't move loss label to other place.
  109. # If I want to change the loss, I just need to change this function.
  110. if self.loss_labels is None:
  111. self.loss_labels = ["sum"] + list(losses[0].keys())
  112. self.metrics = np.zeros([self.num_stacks, len(self.loss_labels)])
  113. print()
  114. print(
  115. "| ".join(
  116. ["progress "]
  117. + list(map("{:7}".format, self.loss_labels))
  118. + ["speed"]
  119. )
  120. )
  121. with open(f"{self.out}/loss.csv", "a") as fout:
  122. print(",".join(["progress"] + self.loss_labels), file=fout)
  123. total_loss = 0
  124. for i in range(self.num_stacks):
  125. for j, name in enumerate(self.loss_labels):
  126. if name == "sum":
  127. continue
  128. if name not in losses[i]:
  129. assert i != 0
  130. continue
  131. loss = losses[i][name].mean()
  132. self.metrics[i, 0] += loss.item()
  133. self.metrics[i, j] += loss.item()
  134. total_loss += loss
  135. return total_loss
  136. def validate(self):
  137. tprint("Running validation...", " " * 75)
  138. training = self.model.training
  139. self.model.eval()
  140. # viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
  141. # npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
  142. # osp.exists(viz) or os.makedirs(viz)
  143. # osp.exists(npz) or os.makedirs(npz)
  144. total_loss = 0
  145. self.metrics[...] = 0
  146. with torch.no_grad():
  147. for batch_idx, (image, meta, target, target_b) in enumerate(self.val_loader):
  148. input_dict = {
  149. "image": recursive_to(image, self.device),
  150. "meta": recursive_to(meta, self.device),
  151. "target": recursive_to(target, self.device),
  152. "target_b": recursive_to(target_b, self.device),
  153. "mode": "validation",
  154. }
  155. result = self.model(input_dict)
  156. # print(f'image:{image.shape}')
  157. # print(result["box"])
  158. # total_loss += self._loss(result)
  159. print(f'self.epoch:{self.epoch}')
  160. # print(result.keys())
  161. self.show_line(image[0], result, self.epoch, self.writer)
  162. # H = result["preds"]
  163. # for i in range(H["jmap"].shape[0]):
  164. # index = batch_idx * M.batch_size_eval + i
  165. # np.savez(
  166. # f"{npz}/{index:06}.npz",
  167. # **{k: v[i].cpu().numpy() for k, v in H.items()},
  168. # )
  169. # if index >= 20:
  170. # continue
  171. # self._plot_samples(i, index, H, meta, target, f"{viz}/{index:06}")
  172. # self._write_metrics(len(self.val_loader), total_loss, "validation", True)
  173. # self.mean_loss = total_loss / len(self.val_loader)
  174. torch.save(
  175. {
  176. "iteration": self.iteration,
  177. "arch": self.model.__class__.__name__,
  178. "optim_state_dict": self.optim.state_dict(),
  179. "model_state_dict": self.model.state_dict(),
  180. "best_mean_loss": self.best_mean_loss,
  181. },
  182. osp.join(self.out, "checkpoint_latest.pth"),
  183. )
  184. # shutil.copy(
  185. # osp.join(self.out, "checkpoint_latest.pth"),
  186. # osp.join(npz, "checkpoint.pth"),
  187. # )
  188. if self.mean_loss < self.best_mean_loss:
  189. self.best_mean_loss = self.mean_loss
  190. shutil.copy(
  191. osp.join(self.out, "checkpoint_latest.pth"),
  192. osp.join(self.out, "checkpoint_best.pth"),
  193. )
  194. if training:
  195. self.model.train1()
  196. def verify_freeze_params(model, freeze_config):
  197. """
  198. 验证参数冻结是否生效
  199. """
  200. print("\n===== Verifying Parameter Freezing =====")
  201. for name, module in model.named_children():
  202. if name in freeze_config:
  203. if freeze_config[name]:
  204. print(f"\nChecking module: {name}")
  205. for param_name, param in module.named_parameters():
  206. print(f" {param_name}: requires_grad = {param.requires_grad}")
  207. # 特别处理fc2子模块
  208. if name == 'fc2' and 'fc2_submodules' in freeze_config:
  209. for subname, submodule in module.named_children():
  210. if subname in freeze_config['fc2_submodules']:
  211. if freeze_config['fc2_submodules'][subname]:
  212. print(f"\nChecking fc2 submodule: {subname}")
  213. for param_name, param in submodule.named_parameters():
  214. print(f" {param_name}: requires_grad = {param.requires_grad}")
  215. def train_epoch(self):
  216. self.model.train1()
  217. time = timer()
  218. for batch_idx, (image, meta, target, target_b) in enumerate(self.train_loader):
  219. self.optim.zero_grad()
  220. self.metrics[...] = 0
  221. input_dict = {
  222. "image": recursive_to(image, self.device),
  223. "meta": recursive_to(meta, self.device),
  224. "target": recursive_to(target, self.device),
  225. "target_b": recursive_to(target_b, self.device),
  226. "mode": "training",
  227. }
  228. result = self.model(input_dict)
  229. loss = self._loss(result)
  230. if np.isnan(loss.item()):
  231. raise ValueError("loss is nan while training")
  232. loss.backward()
  233. self.optim.step()
  234. if self.avg_metrics is None:
  235. self.avg_metrics = self.metrics
  236. else:
  237. self.avg_metrics = self.avg_metrics * 0.9 + self.metrics * 0.1
  238. self.iteration += 1
  239. self._write_metrics(1, loss.item(), "training", do_print=False)
  240. if self.iteration % 4 == 0:
  241. tprint(
  242. f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
  243. + "| ".join(map("{:.5f}".format, self.avg_metrics[0]))
  244. + f"| {4 * self.batch_size / (timer() - time):04.1f} "
  245. )
  246. time = timer()
  247. num_images = self.batch_size * self.iteration
  248. # if num_images % self.validation_interval == 0 or num_images == 4:
  249. # self.validate()
  250. # time = timer()
  251. self.validate()
  252. # verify_freeze_params()
  253. def _write_metrics(self, size, total_loss, prefix, do_print=False):
  254. for i, metrics in enumerate(self.metrics):
  255. for label, metric in zip(self.loss_labels, metrics):
  256. self.writer.add_scalar(
  257. f"{prefix}/{i}/{label}", metric / size, self.iteration
  258. )
  259. if i == 0 and do_print:
  260. csv_str = (
  261. f"{self.epoch:03}/{self.iteration * self.batch_size:07},"
  262. + ",".join(map("{:.11f}".format, metrics / size))
  263. )
  264. prt_str = (
  265. f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
  266. + "| ".join(map("{:.5f}".format, metrics / size))
  267. )
  268. with open(f"{self.out}/loss.csv", "a") as fout:
  269. print(csv_str, file=fout)
  270. pprint(prt_str, " " * 7)
  271. self.writer.add_scalar(
  272. f"{prefix}/total_loss", total_loss / size, self.iteration
  273. )
  274. return total_loss
  275. def _plot_samples(self, i, index, result, meta, target, prefix):
  276. fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
  277. img = io.imread(fn)
  278. imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
  279. mask_result = result["jmap"][i].cpu().numpy()
  280. mask_target = target["jmap"][i].cpu().numpy()
  281. for ch, (ia, ib) in enumerate(zip(mask_target, mask_result)):
  282. imshow(ia), plt.savefig(f"{prefix}_mask_{ch}a.jpg"), plt.close()
  283. imshow(ib), plt.savefig(f"{prefix}_mask_{ch}b.jpg"), plt.close()
  284. line_result = result["lmap"][i].cpu().numpy()
  285. line_target = target["lmap"][i].cpu().numpy()
  286. imshow(line_target), plt.savefig(f"{prefix}_line_a.jpg"), plt.close()
  287. imshow(line_result), plt.savefig(f"{prefix}_line_b.jpg"), plt.close()
  288. def draw_vecl(lines, sline, juncs, junts, fn):
  289. imshow(img)
  290. if len(lines) > 0 and not (lines[0] == 0).all():
  291. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  292. if i > 0 and (lines[i] == lines[0]).all():
  293. break
  294. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
  295. if not (juncs[0] == 0).all():
  296. for i, j in enumerate(juncs):
  297. if i > 0 and (i == juncs[0]).all():
  298. break
  299. plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
  300. if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
  301. for i, j in enumerate(junts):
  302. if i > 0 and (i == junts[0]).all():
  303. break
  304. plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
  305. plt.savefig(fn), plt.close()
  306. junc = meta[i]["junc"].cpu().numpy() * 4
  307. jtyp = meta[i]["jtyp"].cpu().numpy()
  308. juncs = junc[jtyp == 0]
  309. junts = junc[jtyp == 1]
  310. rjuncs = result["juncs"][i].cpu().numpy() * 4
  311. rjunts = None
  312. if "junts" in result:
  313. rjunts = result["junts"][i].cpu().numpy() * 4
  314. lpre = meta[i]["lpre"].cpu().numpy() * 4
  315. vecl_target = meta[i]["lpre_label"].cpu().numpy()
  316. vecl_result = result["lines"][i].cpu().numpy() * 4
  317. score = result["score"][i].cpu().numpy()
  318. lpre = lpre[vecl_target == 1]
  319. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
  320. draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
  321. def train(self):
  322. plt.rcParams["figure.figsize"] = (24, 24)
  323. # if self.iteration == 0:
  324. # self.validate()
  325. epoch_size = len(self.train_loader)
  326. start_epoch = self.iteration // epoch_size
  327. for self.epoch in range(start_epoch, self.max_epoch):
  328. print(f"Epoch {self.epoch}/{C.optim.max_epoch} - Iteration {self.iteration}/{epoch_size}")
  329. if self.epoch == self.lr_decay_epoch:
  330. self.optim.param_groups[0]["lr"] /= 10
  331. self.train_epoch()
  332. cmap = plt.get_cmap("jet")
  333. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  334. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  335. sm.set_array([])
  336. def c(x):
  337. return sm.to_rgba(x)
  338. def imshow(im):
  339. plt.close()
  340. plt.tight_layout()
  341. plt.imshow(im)
  342. plt.colorbar(sm, fraction=0.046)
  343. plt.xlim([0, im.shape[0]])
  344. plt.ylim([im.shape[0], 0])
  345. def tprint(*args):
  346. """Temporarily prints things on the screen"""
  347. print("\r", end="")
  348. print(*args, end="")
  349. def pprint(*args):
  350. """Permanently prints things on the screen"""
  351. print("\r", end="")
  352. print(*args)
  353. def _launch_tensorboard(board_out, port, out):
  354. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  355. p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
  356. def kill():
  357. os.kill(p.pid, signal.SIGTERM)
  358. atexit.register(kill)