|
@@ -1,4 +1,3 @@
|
|
|
-
|
|
|
# 根据LCNN写的train 2025/2/7
|
|
|
'''
|
|
|
#!/usr/bin/env python3
|
|
@@ -161,6 +160,8 @@ if __name__ == "__main__":
|
|
|
main()
|
|
|
'''
|
|
|
|
|
|
+
|
|
|
+# 2025/2/9
|
|
|
import os
|
|
|
from typing import Optional, Any
|
|
|
|
|
@@ -178,9 +179,14 @@ import matplotlib as mpl
|
|
|
from skimage import io
|
|
|
|
|
|
from models.line_detect.line_rcnn import linercnn_resnet50_fpn
|
|
|
+from torchvision.utils import draw_bounding_boxes
|
|
|
+from models.wirenet.postprocess import postprocess
|
|
|
+from torchvision import transforms
|
|
|
+from collections import OrderedDict
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
+
|
|
|
def _loss(losses):
|
|
|
total_loss = 0
|
|
|
for i in losses.keys():
|
|
@@ -215,51 +221,50 @@ def imshow(im):
|
|
|
plt.ylim([im.shape[0], 0])
|
|
|
|
|
|
|
|
|
-def _plot_samples(self, i, index, result, targets, prefix):
|
|
|
- fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
|
|
|
- img = io.imread(fn)
|
|
|
- imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
|
|
|
-
|
|
|
- def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
- imshow(img)
|
|
|
- if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
- if i > 0 and (lines[i] == lines[0]).all():
|
|
|
- break
|
|
|
- plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
|
|
|
- if not (juncs[0] == 0).all():
|
|
|
- for i, j in enumerate(juncs):
|
|
|
- if i > 0 and (i == juncs[0]).all():
|
|
|
- break
|
|
|
- plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
|
|
|
- if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
- for i, j in enumerate(junts):
|
|
|
- if i > 0 and (i == junts[0]).all():
|
|
|
- break
|
|
|
- plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
- plt.savefig(fn), plt.close()
|
|
|
-
|
|
|
- junc = targets[i]["junc"].cpu().numpy() * 4
|
|
|
- jtyp = targets[i]["jtyp"].cpu().numpy()
|
|
|
- juncs = junc[jtyp == 0]
|
|
|
- junts = junc[jtyp == 1]
|
|
|
- rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
- rjunts = None
|
|
|
- if "junts" in result:
|
|
|
- rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
-
|
|
|
- lpre = targets[i]["lpre"].cpu().numpy() * 4
|
|
|
- vecl_target = targets[i]["lpre_label"].cpu().numpy()
|
|
|
- vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
- score = result["score"][i].cpu().numpy()
|
|
|
- lpre = lpre[vecl_target == 1]
|
|
|
-
|
|
|
- draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
|
|
|
- draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
-
|
|
|
- img = cv2.imread(f"{prefix}_vecl_a.jpg")
|
|
|
- img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
- self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
|
|
|
+def show_line(img, pred, epoch, writer):
|
|
|
+ im = img.permute(1, 2, 0)
|
|
|
+ writer.add_image("ori", im, epoch, dataformats="HWC")
|
|
|
+
|
|
|
+ boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
|
|
|
+ colors="yellow", width=1)
|
|
|
+ writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
+
|
|
|
+ PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
|
|
|
+ H = pred[1]['wires']
|
|
|
+ lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
|
|
|
+ scores = H["score"][0].cpu().numpy()
|
|
|
+ for i in range(1, len(lines)):
|
|
|
+ if (lines[i] == lines[0]).all():
|
|
|
+ lines = lines[:i]
|
|
|
+ scores = scores[:i]
|
|
|
+ break
|
|
|
+
|
|
|
+ # postprocess lines to remove overlapped lines
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+
|
|
|
+ for i, t in enumerate([0.8]):
|
|
|
+ plt.gca().set_axis_off()
|
|
|
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
|
|
+ plt.margins(0, 0)
|
|
|
+ for (a, b), s in zip(nlines, nscores):
|
|
|
+ if s < t:
|
|
|
+ continue
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
|
|
|
+ plt.scatter(a[1], a[0], **PLTOPTS)
|
|
|
+ plt.scatter(b[1], b[0], **PLTOPTS)
|
|
|
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
|
|
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.tight_layout()
|
|
|
+ fig = plt.gcf()
|
|
|
+ fig.canvas.draw()
|
|
|
+ image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
|
|
|
+ fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
+ plt.close()
|
|
|
+ img2 = transforms.ToTensor()(image_from_plot)
|
|
|
+
|
|
|
+ writer.add_image("output", img2, epoch)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
@@ -292,7 +297,6 @@ if __name__ == '__main__':
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
writer = SummaryWriter(cfg['io']['logdir'])
|
|
|
|
|
|
-
|
|
|
def move_to_device(data, device):
|
|
|
if isinstance(data, (list, tuple)):
|
|
|
return type(data)(move_to_device(item, device) for item in data)
|
|
@@ -304,19 +308,30 @@ if __name__ == '__main__':
|
|
|
return data # 对于非张量类型的数据不做任何改变
|
|
|
|
|
|
|
|
|
+ # def writer_loss(writer, losses, epoch):
|
|
|
+ # try:
|
|
|
+ # for key, value in losses.items():
|
|
|
+ # if key == 'loss_wirepoint':
|
|
|
+ # for subdict in losses['loss_wirepoint']['losses']:
|
|
|
+ # for subkey, subvalue in subdict.items():
|
|
|
+ # writer.add_scalar(f'loss_wirepoint/{subkey}',
|
|
|
+ # subvalue.item() if hasattr(subvalue, 'item') else subvalue,
|
|
|
+ # epoch)
|
|
|
+ # elif isinstance(value, torch.Tensor):
|
|
|
+ # writer.add_scalar(key, value.item(), epoch)
|
|
|
+ # except Exception as e:
|
|
|
+ # print(f"TensorBoard logging error: {e}")
|
|
|
def writer_loss(writer, losses, epoch):
|
|
|
try:
|
|
|
for key, value in losses.items():
|
|
|
if key == 'loss_wirepoint':
|
|
|
- # ?? wirepoint ??????
|
|
|
for subdict in losses['loss_wirepoint']['losses']:
|
|
|
for subkey, subvalue in subdict.items():
|
|
|
- # ?? .item() ?????
|
|
|
- writer.add_scalar(f'loss_wirepoint/{subkey}',
|
|
|
+ writer.add_scalar(f'loss/{subkey}',
|
|
|
subvalue.item() if hasattr(subvalue, 'item') else subvalue,
|
|
|
epoch)
|
|
|
elif isinstance(value, torch.Tensor):
|
|
|
- writer.add_scalar(key, value.item(), epoch)
|
|
|
+ writer.add_scalar(f'loss/{key}', value.item(), epoch)
|
|
|
except Exception as e:
|
|
|
print(f"TensorBoard logging error: {e}")
|
|
|
|
|
@@ -326,37 +341,20 @@ if __name__ == '__main__':
|
|
|
model.train()
|
|
|
|
|
|
for imgs, targets in data_loader_train:
|
|
|
-
|
|
|
losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
- # print(type(losses))
|
|
|
# print(losses)
|
|
|
loss = _loss(losses)
|
|
|
- # print(loss)
|
|
|
optimizer.zero_grad()
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
writer_loss(writer, losses, epoch)
|
|
|
|
|
|
- model.eval()
|
|
|
- with torch.no_grad():
|
|
|
- for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
- pred = model(move_to_device(imgs, device))
|
|
|
- # print(f"perd:{pred}")
|
|
|
- break
|
|
|
-
|
|
|
- # print(f"perd:{pred}")
|
|
|
-
|
|
|
- # if batch_idx == 0:
|
|
|
- # viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
|
|
|
- # H = pred["wires"]
|
|
|
- # _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
|
|
|
-
|
|
|
-# imgs, targets = next(iter(data_loader))
|
|
|
-#
|
|
|
-# model.train()
|
|
|
-# pred = model(imgs, targets)
|
|
|
-# print(f'pred:{pred}')
|
|
|
-
|
|
|
-# result, losses = model(imgs, targets)
|
|
|
-# print(f'result:{result}')
|
|
|
-# print(f'pred:{losses}')
|
|
|
+ model.eval()
|
|
|
+ with torch.no_grad():
|
|
|
+ for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
+ pred = model(move_to_device(imgs, device))
|
|
|
+ if batch_idx == 0:
|
|
|
+ show_line(imgs[0], pred, epoch, writer)
|
|
|
+ break
|
|
|
+
|
|
|
+
|