train——line_rcnn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
# Training script based on LCNN, written 2025/02/07 (the triple-quoted block below is the earlier version, kept for reference)
  2. '''
  3. #!/usr/bin/env python3
  4. import datetime
  5. import glob
  6. import os
  7. import os.path as osp
  8. import platform
  9. import pprint
  10. import random
  11. import shlex
  12. import shutil
  13. import subprocess
  14. import sys
  15. import numpy as np
  16. import torch
  17. import torchvision
  18. import yaml
  19. import lcnn
  20. from lcnn.config import C, M
  21. from lcnn.datasets import WireframeDataset, collate
  22. from lcnn.models.line_vectorizer import LineVectorizer
  23. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  24. from torchvision.models import resnet50
  25. from models.line_detect.line_rcnn import linercnn_resnet50_fpn
  26. def main():
  27. # 训练配置参数
  28. config = {
  29. # 数据集配置
  30. 'datadir': r'D:\python\PycharmProjects\data', # 数据集目录
  31. 'config_file': 'config/wireframe.yaml', # 配置文件路径
  32. # GPU配置
  33. 'devices': '0', # 使用的GPU设备
  34. 'identifier': 'fasterrcnn_resnet50', # 训练标识符 stacked_hourglass unet
  35. # 预训练模型路径
  36. # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth', # 预训练模型路径
  37. }
  38. # 更新配置
  39. C.update(C.from_yaml(filename=config['config_file']))
  40. M.update(C.model)
  41. # 设置随机数种子
  42. random.seed(0)
  43. np.random.seed(0)
  44. torch.manual_seed(0)
  45. # 设备配置
  46. device_name = "cpu"
  47. os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
  48. if torch.cuda.is_available():
  49. device_name = "cuda"
  50. torch.backends.cudnn.deterministic = True
  51. torch.cuda.manual_seed(0)
  52. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  53. else:
  54. print("CUDA is not available")
  55. device = torch.device(device_name)
  56. # 数据加载
  57. kwargs = {
  58. "collate_fn": collate,
  59. "num_workers": C.io.num_workers if os.name != "nt" else 0,
  60. "pin_memory": True,
  61. }
  62. train_loader = torch.utils.data.DataLoader(
  63. WireframeDataset(config['datadir'], dataset_type="train"),
  64. shuffle=True,
  65. batch_size=M.batch_size,
  66. **kwargs,
  67. )
  68. val_loader = torch.utils.data.DataLoader(
  69. WireframeDataset(config['datadir'], dataset_type="val"),
  70. shuffle=False,
  71. batch_size=M.batch_size_eval,
  72. **kwargs,
  73. )
  74. model = linercnn_resnet50_fpn().to(device)
  75. # 加载预训练权重
  76. try:
  77. # 加载模型权重
  78. checkpoint = torch.load(config['pretrained_model'], map_location=device)
  79. # 根据实际的检查点结构选择加载方式
  80. if 'model_state_dict' in checkpoint:
  81. # 如果是完整的检查点
  82. model.load_state_dict(checkpoint['model_state_dict'])
  83. elif 'state_dict' in checkpoint:
  84. # 如果是只有状态字典的检查点
  85. model.load_state_dict(checkpoint['state_dict'])
  86. else:
  87. # 直接加载权重字典
  88. model.load_state_dict(checkpoint)
  89. print("Successfully loaded pre-trained model weights.")
  90. except Exception as e:
  91. print(f"Error loading model weights: {e}")
  92. # 优化器配置
  93. if C.optim.name == "Adam":
  94. optim = torch.optim.Adam(
  95. filter(lambda p: p.requires_grad, model.parameters()),
  96. lr=C.optim.lr,
  97. weight_decay=C.optim.weight_decay,
  98. amsgrad=C.optim.amsgrad,
  99. )
  100. elif C.optim.name == "SGD":
  101. optim = torch.optim.SGD(
  102. filter(lambda p: p.requires_grad, model.parameters()),
  103. lr=C.optim.lr,
  104. weight_decay=C.optim.weight_decay,
  105. momentum=C.optim.momentum,
  106. )
  107. else:
  108. raise NotImplementedError
  109. # 输出目录
  110. outdir = osp.join(
  111. osp.expanduser(C.io.logdir),
  112. f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
  113. )
  114. os.makedirs(outdir, exist_ok=True)
  115. try:
  116. trainer = lcnn.trainer.Trainer(
  117. device=device,
  118. model=model,
  119. optimizer=optim,
  120. train_loader=train_loader,
  121. val_loader=val_loader,
  122. out=outdir,
  123. )
  124. print("Starting training...")
  125. trainer.train()
  126. print("Training completed.")
  127. except BaseException:
  128. if len(glob.glob(f"{outdir}/viz/*")) <= 1:
  129. shutil.rmtree(outdir)
  130. raise
  131. if __name__ == "__main__":
  132. main()
  133. '''
  134. import os
  135. from typing import Optional, Any
  136. import cv2
  137. import numpy as np
  138. import torch
  139. from models.config.config_tool import read_yaml
  140. from models.line_detect.dataset_LD import WirePointDataset
  141. from tools import utils
  142. from torch.utils.tensorboard import SummaryWriter
  143. import matplotlib.pyplot as plt
  144. import matplotlib as mpl
  145. from skimage import io
  146. from models.line_detect.line_rcnn import linercnn_resnet50_fpn
  147. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  148. def _loss(losses):
  149. total_loss = 0
  150. for i in losses.keys():
  151. if i != "loss_wirepoint":
  152. total_loss += losses[i]
  153. else:
  154. loss_labels = losses[i]["losses"]
  155. loss_labels_k = list(loss_labels[0].keys())
  156. for j, name in enumerate(loss_labels_k):
  157. loss = loss_labels[0][name].mean()
  158. total_loss += loss
  159. return total_loss
# Shared matplotlib color mapping: scores in [0.4, 1.0] are mapped onto the
# "jet" colormap; `sm` is also reused by `imshow` below to draw the colorbar.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    # Map a scalar score to an RGBA tuple via the shared jet colormap.
    return sm.to_rgba(x)
  166. def imshow(im):
  167. plt.close()
  168. plt.tight_layout()
  169. plt.imshow(im)
  170. plt.colorbar(sm, fraction=0.046)
  171. plt.xlim([0, im.shape[0]])
  172. plt.ylim([im.shape[0], 0])
  173. def _plot_samples(self, i, index, result, targets, prefix):
  174. fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
  175. img = io.imread(fn)
  176. imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
  177. def draw_vecl(lines, sline, juncs, junts, fn):
  178. imshow(img)
  179. if len(lines) > 0 and not (lines[0] == 0).all():
  180. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  181. if i > 0 and (lines[i] == lines[0]).all():
  182. break
  183. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
  184. if not (juncs[0] == 0).all():
  185. for i, j in enumerate(juncs):
  186. if i > 0 and (i == juncs[0]).all():
  187. break
  188. plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
  189. if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
  190. for i, j in enumerate(junts):
  191. if i > 0 and (i == junts[0]).all():
  192. break
  193. plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
  194. plt.savefig(fn), plt.close()
  195. junc = targets[i]["junc"].cpu().numpy() * 4
  196. jtyp = targets[i]["jtyp"].cpu().numpy()
  197. juncs = junc[jtyp == 0]
  198. junts = junc[jtyp == 1]
  199. rjuncs = result["juncs"][i].cpu().numpy() * 4
  200. rjunts = None
  201. if "junts" in result:
  202. rjunts = result["junts"][i].cpu().numpy() * 4
  203. lpre = targets[i]["lpre"].cpu().numpy() * 4
  204. vecl_target = targets[i]["lpre_label"].cpu().numpy()
  205. vecl_result = result["lines"][i].cpu().numpy() * 4
  206. score = result["score"][i].cpu().numpy()
  207. lpre = lpre[vecl_target == 1]
  208. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
  209. draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
  210. img = cv2.imread(f"{prefix}_vecl_a.jpg")
  211. img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
  212. self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
if __name__ == '__main__':
    # Path to the wireframe training configuration (io/model/optim sections).
    cfg = r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])
    # net = WirepointPredictor()

    # Training data: random order, batch size 1, incomplete batches dropped.
    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
    )

    # Validation data: same setup as training (random sampling retained here too).
    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    val_sampler = torch.utils.data.RandomSampler(dataset_val)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
    val_collate_fn = utils.collate_fn_wirepoint
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
    )

    # Model, optimizer and TensorBoard writer shared by the training loop below.
    model = linercnn_resnet50_fpn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])
  238. def move_to_device(data, device):
  239. if isinstance(data, (list, tuple)):
  240. return type(data)(move_to_device(item, device) for item in data)
  241. elif isinstance(data, dict):
  242. return {key: move_to_device(value, device) for key, value in data.items()}
  243. elif isinstance(data, torch.Tensor):
  244. return data.to(device)
  245. else:
  246. return data # 对于非张量类型的数据不做任何改变
  247. def writer_loss(writer, losses, epoch):
  248. try:
  249. for key, value in losses.items():
  250. if key == 'loss_wirepoint':
  251. # ?? wirepoint ??????
  252. for subdict in losses['loss_wirepoint']['losses']:
  253. for subkey, subvalue in subdict.items():
  254. # ?? .item() ?????
  255. writer.add_scalar(f'loss_wirepoint/{subkey}',
  256. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  257. epoch)
  258. elif isinstance(value, torch.Tensor):
  259. writer.add_scalar(key, value.item(), epoch)
  260. except Exception as e:
  261. print(f"TensorBoard logging error: {e}")
# Main optimization loop: one full pass over the training set per epoch,
# followed by a single smoke-test forward pass on the validation set.
for epoch in range(cfg['optim']['max_epoch']):
    print(f"epoch:{epoch}")
    model.train()

    for imgs, targets in data_loader_train:
        # In train mode the model returns a dict of losses.
        losses = model(move_to_device(imgs, device), move_to_device(targets, device))
        # print(type(losses))
        # print(losses)
        loss = _loss(losses)  # collapse the loss dict into one scalar
        # print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): logged once per batch but tagged only by epoch, so
        # later batches overwrite earlier ones within the same epoch.
        writer_loss(writer, losses, epoch)

    model.eval()
    with torch.no_grad():
        for batch_idx, (imgs, targets) in enumerate(data_loader_val):
            # In eval mode (no targets) the model returns predictions.
            pred = model(move_to_device(imgs, device))
            # print(f"perd:{pred}")
            break  # NOTE: only the first validation batch is evaluated
            # print(f"perd:{pred}")

            # if batch_idx == 0:
            #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
            #     H = pred["wires"]
            #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")

# imgs, targets = next(iter(data_loader))
#
# model.train()
# pred = model(imgs, targets)
# print(f'pred:{pred}')
# result, losses = model(imgs, targets)
# print(f'result:{result}')
# print(f'pred:{losses}')