|
@@ -10,6 +10,7 @@ import torch.nn.functional as F
|
|
|
# from torchinfo import summary
|
|
|
from torchvision.io import read_image
|
|
|
from torchvision.models import resnet50, ResNet50_Weights
|
|
|
+from torchvision.models import resnet18, ResNet18_Weights
|
|
|
from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
|
|
|
from torchvision.models.detection._utils import overwrite_eps
|
|
|
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
@@ -522,26 +523,55 @@ class WirepointPredictor(nn.Module):
|
|
|
jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
|
return line, label.float(), feat, jcs
|
|
|
|
|
|
+# def wirepointrcnn_resnet50_fpn(
|
|
|
+# *,
|
|
|
+# weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
|
|
|
+# progress: bool = True,
|
|
|
+# num_classes: Optional[int] = None,
|
|
|
+# num_keypoints: Optional[int] = None,
|
|
|
+# weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
|
|
|
+# trainable_backbone_layers: Optional[int] = None,
|
|
|
+# **kwargs: Any,
|
|
|
+# ) -> WirepointRCNN:
|
|
|
+# weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
|
|
|
+# weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
+#
|
|
|
+# is_trained = weights is not None or weights_backbone is not None
|
|
|
+# trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
|
|
|
+#
|
|
|
+# norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
+#
|
|
|
+# backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
|
|
|
+# backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
+# model = WirepointRCNN(backbone, num_classes=5, **kwargs)
|
|
|
+#
|
|
|
+# if weights is not None:
|
|
|
+# model.load_state_dict(weights.get_state_dict(progress=progress))
|
|
|
+# if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
|
|
|
+# overwrite_eps(model, 0.0)
|
|
|
+#
|
|
|
+# return model
|
|
|
|
|
|
-def wirepointrcnn_resnet50_fpn(
|
|
|
+
|
|
|
+def wirepointrcnn_resnet18_fpn(
|
|
|
*,
|
|
|
weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
|
|
|
progress: bool = True,
|
|
|
num_classes: Optional[int] = None,
|
|
|
num_keypoints: Optional[int] = None,
|
|
|
- weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
|
|
|
+ weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
|
|
|
trainable_backbone_layers: Optional[int] = None,
|
|
|
**kwargs: Any,
|
|
|
) -> WirepointRCNN:
|
|
|
weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
|
|
|
- weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
+ weights_backbone = ResNet18_Weights.verify(weights_backbone)
|
|
|
|
|
|
is_trained = weights is not None or weights_backbone is not None
|
|
|
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
|
|
|
|
|
|
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
|
|
|
- backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
|
|
|
+ backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
|
|
|
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
model = WirepointRCNN(backbone, num_classes=5, **kwargs)
|
|
|
|
|
@@ -577,27 +607,128 @@ sm.set_array([])
|
|
|
def c(x):
|
|
|
return sm.to_rgba(x)
|
|
|
|
|
|
+
|
|
|
+def imshow(im):
|
|
|
+ plt.close()
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.colorbar(sm, fraction=0.046)
|
|
|
+ plt.xlim([0, im.shape[0]])
|
|
|
+ plt.ylim([im.shape[0], 0])
|
|
|
+ # plt.show()
|
|
|
+
|
|
|
+
|
|
|
+# def _plot_samples(img, i, result, prefix, epoch):
|
|
|
+# print(f"prefix:{prefix}")
|
|
|
+# def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
+# directory = os.path.dirname(fn)
|
|
|
+# if not os.path.exists(directory):
|
|
|
+# os.makedirs(directory)
|
|
|
+# imshow(img.permute(1, 2, 0))
|
|
|
+# if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+# for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+# if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+# break
|
|
|
+# plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
|
|
|
+# if not (juncs[0] == 0).all():
|
|
|
+# for i, j in enumerate(juncs):
|
|
|
+# if i > 0 and (i == juncs[0]).all():
|
|
|
+# break
|
|
|
+# plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
|
|
|
+# if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
+# for i, j in enumerate(junts):
|
|
|
+# if i > 0 and (i == junts[0]).all():
|
|
|
+# break
|
|
|
+# plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
+# plt.savefig(fn), plt.close()
|
|
|
+#
|
|
|
+# rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
+# rjunts = None
|
|
|
+# if "junts" in result:
|
|
|
+# rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
+#
|
|
|
+# vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
+# score = result["score"][i].cpu().numpy()
|
|
|
#
|
|
|
-# def imshow(im):
|
|
|
-# plt.close()
|
|
|
-# plt.tight_layout()
|
|
|
-# plt.imshow(im)
|
|
|
-# plt.colorbar(sm, fraction=0.046)
|
|
|
-# plt.xlim([0, im.shape[0]])
|
|
|
-# plt.ylim([im.shape[0], 0])
|
|
|
-# # plt.show()
|
|
|
+# draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
+#
|
|
|
+# img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
+# writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
|
|
|
+
|
|
|
+def _plot_samples(img, i, result, prefix, epoch, writer):
|
|
|
+ # print(f"prefix:{prefix}")
|
|
|
+
|
|
|
+ def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
+ # 确保目录存在
|
|
|
+ directory = os.path.dirname(fn)
|
|
|
+ if not os.path.exists(directory):
|
|
|
+ os.makedirs(directory)
|
|
|
+
|
|
|
+ # 绘制图像
|
|
|
+ plt.figure()
|
|
|
+ plt.imshow(img.permute(1, 2, 0).cpu().numpy())
|
|
|
+ plt.axis('off') # 可选:关闭坐标轴
|
|
|
+
|
|
|
+ if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+ for idx, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if idx > 0 and (lines[idx] == lines[0]).all():
|
|
|
+ break
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
|
|
|
+
|
|
|
+ if not (juncs[0] == 0).all():
|
|
|
+ for idx, j in enumerate(juncs):
|
|
|
+ if idx > 0 and (j == juncs[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
|
|
|
+
|
|
|
+ if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
+ for idx, j in enumerate(junts):
|
|
|
+ if idx > 0 and (j == junts[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)
|
|
|
+
|
|
|
+ # plt.show()
|
|
|
+
|
|
|
+ # 将matplotlib图像转换为numpy数组
|
|
|
+ plt.tight_layout()
|
|
|
+ fig = plt.gcf()
|
|
|
+ fig.canvas.draw()
|
|
|
+ image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
|
|
|
+ fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
+ plt.close()
|
|
|
|
|
|
+ return image_from_plot
|
|
|
|
|
|
-def show_line(img, pred, epoch, write):
|
|
|
- im = img.permute(1, 2, 0)
|
|
|
- writer.add_image("ori", im, epoch, dataformats="HWC")
|
|
|
+ # 获取结果数据并转换为numpy数组
|
|
|
+ rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
+ rjunts = None
|
|
|
+ if "junts" in result:
|
|
|
+ rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
|
|
|
- boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
|
|
|
- colors="yellow", width=1)
|
|
|
- writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
+ vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
+ score = result["score"][i].cpu().numpy()
|
|
|
|
|
|
+ # 调用绘图函数并获取图像
|
|
|
+ image_path = f"{prefix}_vecl_b.jpg"
|
|
|
+ image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
|
|
|
+
|
|
|
+ # 将numpy数组转换为torch tensor,并写入TensorBoard
|
|
|
+ image_tensor = transforms.ToTensor()(image_array)
|
|
|
+ writer.add_image(f'output_epoch', image_tensor, global_step=epoch)
|
|
|
+ writer.add_image(f'ori_epoch', img, global_step=epoch)
|
|
|
+
|
|
|
+
|
|
|
+def show_line(img, pred, prefix, epoch, write):
|
|
|
+ fn = f"{prefix}_line.jpg"
|
|
|
+ directory = os.path.dirname(fn)
|
|
|
+ if not os.path.exists(directory):
|
|
|
+ os.makedirs(directory)
|
|
|
+ print(fn)
|
|
|
PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
|
|
|
- H = pred[1]['wires']
|
|
|
+ H = pred
|
|
|
+
|
|
|
+ im = img.permute(1, 2, 0)
|
|
|
+
|
|
|
lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
|
|
|
scores = H["score"][0].cpu().numpy()
|
|
|
for i in range(1, len(lines)):
|
|
@@ -623,14 +754,14 @@ def show_line(img, pred, epoch, write):
|
|
|
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
|
|
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
|
|
plt.imshow(im)
|
|
|
- plt.tight_layout()
|
|
|
- fig = plt.gcf()
|
|
|
- fig.canvas.draw()
|
|
|
- image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
|
|
|
- fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
+ plt.savefig(fn, bbox_inches="tight")
|
|
|
+ plt.show()
|
|
|
plt.close()
|
|
|
- img2 = transforms.ToTensor()(image_from_plot)
|
|
|
|
|
|
+ img2 = cv2.imread(fn) # 预测图
|
|
|
+ # img1 = im.resize(img2.shape) # 原图
|
|
|
+
|
|
|
+ # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC')
|
|
|
writer.add_image("output", img2, epoch)
|
|
|
|
|
|
|
|
@@ -669,7 +800,11 @@ if __name__ == '__main__':
|
|
|
dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
|
|
|
)
|
|
|
|
|
|
- model = wirepointrcnn_resnet50_fpn().to(device)
|
|
|
+ model = wirepointrcnn_resnet18_fpn().to(device)
|
|
|
+ # print(model)
|
|
|
+
|
|
|
+ # model1 = wirepointrcnn_resnet50_fpn().to(device)
|
|
|
+ # print(model1)
|
|
|
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
writer = SummaryWriter(cfg['io']['logdir'])
|
|
@@ -709,23 +844,27 @@ if __name__ == '__main__':
|
|
|
print(f"epoch:{epoch}")
|
|
|
model.train()
|
|
|
|
|
|
- # for imgs, targets in data_loader_train:
|
|
|
- # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
- # loss = _loss(losses)
|
|
|
- # print(loss)
|
|
|
- # optimizer.zero_grad()
|
|
|
- # loss.backward()
|
|
|
- # optimizer.step()
|
|
|
- # writer_loss(writer, losses, epoch)
|
|
|
-
|
|
|
- model.eval()
|
|
|
- with torch.no_grad():
|
|
|
- for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
- pred = model(move_to_device(imgs, device)) # # pred[0].keys() ['boxes', 'labels', 'scores']
|
|
|
- # print(f"pred:{pred}")
|
|
|
-
|
|
|
- if batch_idx == 0:
|
|
|
- show_line(imgs[0], pred, epoch, writer)
|
|
|
+ for imgs, targets in data_loader_train:
|
|
|
+ losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
+ loss = _loss(losses)
|
|
|
+ print(f"loss:{loss}")
|
|
|
+ optimizer.zero_grad()
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ writer_loss(writer, losses, epoch)
|
|
|
+
|
|
|
+ # model.eval()
|
|
|
+ # with torch.no_grad():
|
|
|
+ # for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
+ # pred = model(move_to_device(imgs, device))
|
|
|
+ # print(f"pred:{pred}")
|
|
|
+ #
|
|
|
+ # if batch_idx == 0:
|
|
|
+ # result = pred[1]['wires'] # pred[0].keys() ['boxes', 'labels', 'scores']
|
|
|
+ # print(imgs[0].shape) # [3,512,512]
|
|
|
+ # # imshow(imgs[0].permute(1, 2, 0)) # 改为(512, 512, 3)
|
|
|
+ # _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
|
|
|
+ # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
|
|
|
|
|
|
# imgs, targets = next(iter(data_loader))
|
|
|
#
|