# train_line_rcnn.py
# Training script based on LCNN, 2025/2/7.
# NOTE(review): the entire legacy script below is disabled by wrapping it in a
# module-level string literal (it becomes the module docstring). It is kept
# for reference only and is superseded by the active code that follows.
'''
#!/usr/bin/env python3
import datetime
import glob
import os
import os.path as osp
import platform
import pprint
import random
import shlex
import shutil
import subprocess
import sys
import numpy as np
import torch
import torchvision
import yaml
import lcnn
from lcnn.config import C, M
from lcnn.datasets import WireframeDataset, collate
from lcnn.models.line_vectorizer import LineVectorizer
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
from torchvision.models import resnet50
from models.line_detect.line_rcnn import linercnn_resnet50_fpn


def main():
    # Training configuration
    config = {
        # Dataset settings
        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
        'config_file': 'config/wireframe.yaml',  # config file path
        # GPU settings
        'devices': '0',  # GPU devices to use
        'identifier': 'fasterrcnn_resnet50',  # run identifier (stacked_hourglass / unet)
        # Pretrained model path
        # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth',
    }

    # Update global config
    C.update(C.from_yaml(filename=config['config_file']))
    M.update(C.model)

    # Fix random seeds
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Device setup
    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    # Data loading
    kwargs = {
        "collate_fn": collate,
        "num_workers": C.io.num_workers if os.name != "nt" else 0,
        "pin_memory": True,
    }
    train_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="train"),
        shuffle=True,
        batch_size=M.batch_size,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        WireframeDataset(config['datadir'], dataset_type="val"),
        shuffle=False,
        batch_size=M.batch_size_eval,
        **kwargs,
    )

    model = linercnn_resnet50_fpn().to(device)

    # Load pretrained weights
    try:
        checkpoint = torch.load(config['pretrained_model'], map_location=device)
        # Pick the loading strategy matching the checkpoint layout
        if 'model_state_dict' in checkpoint:
            # Full training checkpoint
            model.load_state_dict(checkpoint['model_state_dict'])
        elif 'state_dict' in checkpoint:
            # Checkpoint containing only a state dict
            model.load_state_dict(checkpoint['state_dict'])
        else:
            # Raw weight dict
            model.load_state_dict(checkpoint)
        print("Successfully loaded pre-trained model weights.")
    except Exception as e:
        print(f"Error loading model weights: {e}")

    # Optimizer setup
    if C.optim.name == "Adam":
        optim = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            amsgrad=C.optim.amsgrad,
        )
    elif C.optim.name == "SGD":
        optim = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=C.optim.lr,
            weight_decay=C.optim.weight_decay,
            momentum=C.optim.momentum,
        )
    else:
        raise NotImplementedError

    # Output directory
    outdir = osp.join(
        osp.expanduser(C.io.logdir),
        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
    )
    os.makedirs(outdir, exist_ok=True)

    try:
        trainer = lcnn.trainer.Trainer(
            device=device,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            val_loader=val_loader,
            out=outdir,
        )
        print("Starting training...")
        trainer.train()
        print("Training completed.")
    except BaseException:
        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
            shutil.rmtree(outdir)
        raise


if __name__ == "__main__":
    main()
'''
  134. # 2025/2/9
# 2025/2/9 — active training script (replaces the legacy LCNN script above).
import os
from typing import Optional, Any
import cv2
import numpy as np
import torch
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from tools import utils
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import matplotlib as mpl
from skimage import io
from models.line_detect.line_rcnn import linercnn_resnet50_fpn
from torchvision.utils import draw_bounding_boxes
from models.wirenet.postprocess import postprocess
from torchvision import transforms
from collections import OrderedDict

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  153. def _loss(losses):
  154. total_loss = 0
  155. for i in losses.keys():
  156. if i != "loss_wirepoint":
  157. total_loss += losses[i]
  158. else:
  159. loss_labels = losses[i]["losses"]
  160. loss_labels_k = list(loss_labels[0].keys())
  161. for j, name in enumerate(loss_labels_k):
  162. loss = loss_labels[0][name].mean()
  163. total_loss += loss
  164. return total_loss
# Shared colormap state: line scores in [0.4, 1.0] are mapped onto the
# "jet" colormap; `sm` is also reused as the mappable for colorbars.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
def c(x):
    """Map a score ``x`` to an RGBA colour via the shared "jet" colormap."""
    return sm.to_rgba(x)
  171. def imshow(im):
  172. plt.close()
  173. plt.tight_layout()
  174. plt.imshow(im)
  175. plt.colorbar(sm, fraction=0.046)
  176. plt.xlim([0, im.shape[0]])
  177. plt.ylim([im.shape[0], 0])
def show_line(img, pred, epoch, writer):
    """Log one sample's image, predicted boxes and wireframe lines to TensorBoard.

    Parameters:
        img: image tensor, permuted (1, 2, 0) below so presumably (C, H, W);
            multiplied by 255 for box drawing, so assumed to be in [0, 1]
            — TODO confirm against the dataset.
        pred: model output; ``pred[0]["boxes"]`` holds detection boxes and
            ``pred[1]['wires']`` the wireframe head output.
        epoch: global step used for all TensorBoard writes.
        writer: a ``SummaryWriter``.
    """
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred[1]['wires']
    # Lines appear to be predicted on a 128x128 grid; rescale to the image
    # resolution (NOTE(review): assumes H/W order matches line coords — confirm).
    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()
    # The prediction list is padded by repeating the first line; truncate at
    # the first repetition.
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # Render lines above each score threshold (currently only 0.8) and push
    # the rasterized figure to TensorBoard.
    for i, t in enumerate([0.8]):
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            # Coordinates are (row, col), so index 1 is x and index 0 is y.
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.tight_layout()
        fig = plt.gcf()
        fig.canvas.draw()
        # Grab the rendered canvas as an HxWx3 uint8 array.
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
            fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()
        img2 = transforms.ToTensor()(image_from_plot)
        writer.add_image("output", img2, epoch)
if __name__ == '__main__':
    # Load the run configuration.
    cfg = r'./config/wireframe.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])

    # Training data: random sampling, batch size 1, incomplete batches dropped.
    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
    )

    # Validation data (NOTE(review): also randomly sampled, not sequential).
    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    val_sampler = torch.utils.data.RandomSampler(dataset_val)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
    val_collate_fn = utils.collate_fn_wirepoint
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
    )

    model = linercnn_resnet50_fpn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])

    def move_to_device(data, device):
        # Recursively move tensors nested in lists/tuples/dicts to `device`;
        # non-tensor values are returned unchanged.
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data

    def writer_loss(writer, losses, epoch):
        # Log every scalar loss under the "loss/" namespace; the nested
        # "loss_wirepoint" sub-losses are flattened to "loss/<subkey>".
        # Logging failures are reported but never interrupt training.
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(f'loss/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    for epoch in range(cfg['optim']['max_epoch']):
        print(f"epoch:{epoch}")
        model.train()

        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            loss = _loss(losses)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): logged once per batch with the epoch as the
            # global step, so later batches overwrite earlier ones — confirm
            # whether a per-iteration step counter is intended.
            writer_loss(writer, losses, epoch)

        model.eval()
        with torch.no_grad():
            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                pred = model(move_to_device(imgs, device))
                if batch_idx == 0:
                    # Visualize only the first validation sample each epoch.
                    show_line(imgs[0], pred, epoch, writer)
                # NOTE(review): unconditional break — only one validation
                # batch is ever evaluated; no validation metric is computed.
                break