|
@@ -1,6 +1,7 @@
|
|
|
import os
|
|
|
from typing import Optional, Any
|
|
|
|
|
|
+import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from tensorboardX import SummaryWriter
|
|
@@ -27,9 +28,17 @@ from models.wirenet.wirepoint_dataset import WirePointDataset
|
|
|
from tools import utils
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import matplotlib as mpl
|
|
|
+from skimage import io
|
|
|
+import os.path as osp
|
|
|
+
|
|
|
|
|
|
FEATURE_DIM = 8
|
|
|
|
|
|
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
+print(f"Using device: {device}")
|
|
|
+
|
|
|
|
|
|
def non_maximum_suppression(a):
|
|
|
ap = F.max_pool2d(a, 3, stride=1, padding=1)
|
|
@@ -124,7 +133,7 @@ class WirepointRCNN(FasterRCNN):
|
|
|
|
|
|
if wirepoint_head is None:
|
|
|
keypoint_layers = tuple(512 for _ in range(8))
|
|
|
- print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
|
|
|
+ # print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
|
|
|
wirepoint_head = WirepointHead(out_channels, keypoint_layers)
|
|
|
|
|
|
if wirepoint_predictor is None:
|
|
@@ -291,7 +300,7 @@ class WirepointPredictor(nn.Module):
|
|
|
# print(f'out:{out.shape}')
|
|
|
# outputs=merge_features(outputs,100)
|
|
|
batch, channel, row, col = inputs.shape
|
|
|
- print(f'outputs:{inputs.shape}')
|
|
|
+ # print(f'outputs:{inputs.shape}')
|
|
|
# print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
|
|
|
|
|
|
if targets is not None:
|
|
@@ -316,18 +325,18 @@ class WirepointPredictor(nn.Module):
|
|
|
else:
|
|
|
self.training = False
|
|
|
t = {
|
|
|
- "junc_coords": torch.zeros(1, 2),
|
|
|
- "jtyp": torch.zeros(1, dtype=torch.uint8),
|
|
|
- "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
- "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
+ "junc_coords": torch.zeros(1, 2).to(device),
|
|
|
+ "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
+ "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
}
|
|
|
wires_targets = [t for b in range(inputs.size(0))]
|
|
|
|
|
|
wires_meta = {
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
}
|
|
|
|
|
|
T = wires_meta.copy()
|
|
@@ -399,7 +408,6 @@ class WirepointPredictor(nn.Module):
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
|
x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
- print(f"pstest{ps}")
|
|
|
x = torch.cat([x, f], 1)
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
x = self.fc2(x).flatten()
|
|
@@ -443,6 +451,9 @@ class WirepointPredictor(nn.Module):
|
|
|
xy_ = xy[..., None, :]
|
|
|
del x, y, index
|
|
|
|
|
|
+ # print(f"xy_.is_cuda: {xy_.is_cuda}")
|
|
|
+ # print(f"junc.is_cuda: {junc.is_cuda}")
|
|
|
+
|
|
|
# dist: [N_TYPE, K, N]
|
|
|
dist = torch.sum((xy_ - junc) ** 2, -1)
|
|
|
cost, match = torch.min(dist, -1)
|
|
@@ -555,6 +566,72 @@ def _loss(losses):
|
|
|
return total_loss
|
|
|
|
|
|
|
|
|
+cmap = plt.get_cmap("jet")
|
|
|
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
+sm.set_array([])
|
|
|
+
|
|
|
+
|
|
|
+def c(x):
|
|
|
+ return sm.to_rgba(x)
|
|
|
+
|
|
|
+
|
|
|
+def imshow(im):
|
|
|
+ plt.close()
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.colorbar(sm, fraction=0.046)
|
|
|
+ plt.xlim([0, im.shape[0]])
|
|
|
+ plt.ylim([im.shape[0], 0])
|
|
|
+
|
|
|
+
|
|
|
+def _plot_samples(self, i, index, result, targets, prefix):
|
|
|
+ fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
|
|
|
+ img = io.imread(fn)
|
|
|
+ imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
|
|
|
+
|
|
|
+ def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
+ imshow(img)
|
|
|
+ if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+ for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+ break
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
|
|
|
+ if not (juncs[0] == 0).all():
|
|
|
+ for i, j in enumerate(juncs):
|
|
|
+ if i > 0 and (i == juncs[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
|
|
|
+ if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
+ for i, j in enumerate(junts):
|
|
|
+ if i > 0 and (i == junts[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
+ plt.savefig(fn), plt.close()
|
|
|
+
|
|
|
+ junc = targets[i]["junc"].cpu().numpy() * 4
|
|
|
+ jtyp = targets[i]["jtyp"].cpu().numpy()
|
|
|
+ juncs = junc[jtyp == 0]
|
|
|
+ junts = junc[jtyp == 1]
|
|
|
+ rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
+ rjunts = None
|
|
|
+ if "junts" in result:
|
|
|
+ rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
+
|
|
|
+ lpre = targets[i]["lpre"].cpu().numpy() * 4
|
|
|
+ vecl_target = targets[i]["lpre_label"].cpu().numpy()
|
|
|
+ vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
+ score = result["score"][i].cpu().numpy()
|
|
|
+ lpre = lpre[vecl_target == 1]
|
|
|
+
|
|
|
+ draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
|
|
|
+ draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
+
|
|
|
+ img = cv2.imread(f"{prefix}_vecl_a.jpg")
|
|
|
+ img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
+ self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
cfg = 'wirenet.yaml'
|
|
|
cfg = read_yaml(cfg)
|
|
@@ -562,15 +639,15 @@ if __name__ == '__main__':
|
|
|
print(cfg['model']['n_dyn_negl'])
|
|
|
# net = WirepointPredictor()
|
|
|
|
|
|
- if torch.cuda.is_available():
|
|
|
- device_name = "cuda"
|
|
|
- torch.backends.cudnn.deterministic = True
|
|
|
- torch.cuda.manual_seed(0)
|
|
|
- print("Let's use", torch.cuda.device_count(), "GPU(s)!")
|
|
|
- else:
|
|
|
- print("CUDA is not available")
|
|
|
-
|
|
|
- device = torch.device(device_name)
|
|
|
+ # if torch.cuda.is_available():
|
|
|
+ # device_name = "cuda"
|
|
|
+ # torch.backends.cudnn.deterministic = True
|
|
|
+ # torch.cuda.manual_seed(0)
|
|
|
+ # print("Let's use", torch.cuda.device_count(), "GPU(s)!")
|
|
|
+ # else:
|
|
|
+ # print("CUDA is not available")
|
|
|
+ #
|
|
|
+ # device = torch.device(device_name)
|
|
|
|
|
|
dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
|
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
@@ -607,17 +684,27 @@ if __name__ == '__main__':
|
|
|
return data # 对于非张量类型的数据不做任何改变
|
|
|
|
|
|
|
|
|
- def writer_loss(writer, losses):
|
|
|
- # 记录每个损失项到TensorBoard
|
|
|
- for key, value in losses.items():
|
|
|
- if isinstance(value, dict): # 如果value本身也是一个字典(例如'loss_wirepoint')
|
|
|
- for subkey, subvalue in value['losses'][0].items():
|
|
|
- writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch)
|
|
|
- else:
|
|
|
- writer.add_scalar(key, value.item(), epoch)
|
|
|
+ def writer_loss(writer, losses, epoch):
|
|
|
+ # ??????
|
|
|
+ try:
|
|
|
+ for key, value in losses.items():
|
|
|
+ if key == 'loss_wirepoint':
|
|
|
+ # ?? wirepoint ??????
|
|
|
+ for subdict in losses['loss_wirepoint']['losses']:
|
|
|
+ for subkey, subvalue in subdict.items():
|
|
|
+ # ?? .item() ?????
|
|
|
+ writer.add_scalar(f'loss_wirepoint/{subkey}',
|
|
|
+ subvalue.item() if hasattr(subvalue, 'item') else subvalue,
|
|
|
+ epoch)
|
|
|
+ elif isinstance(value, torch.Tensor):
|
|
|
+ # ????????
|
|
|
+ writer.add_scalar(key, value.item(), epoch)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"TensorBoard logging error: {e}")
|
|
|
|
|
|
|
|
|
for epoch in range(cfg['optim']['max_epoch']):
|
|
|
+ print(f"epoch:{epoch}")
|
|
|
model.train()
|
|
|
|
|
|
for imgs, targets in data_loader_train:
|
|
@@ -627,14 +714,18 @@ if __name__ == '__main__':
|
|
|
optimizer.zero_grad()
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
- writer_loss(writer, losses)
|
|
|
+ writer_loss(writer, losses, epoch)
|
|
|
|
|
|
model.eval()
|
|
|
with torch.no_grad():
|
|
|
- for imgs, targets in data_loader_val:
|
|
|
- print(111)
|
|
|
+ for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
pred = model(move_to_device(imgs, device))
|
|
|
- print(f"pred:{pred}")
|
|
|
+ print(f"perd:{pred}")
|
|
|
+
|
|
|
+ # if batch_idx == 0:
|
|
|
+ # viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
|
|
|
+ # H = pred["wires"]
|
|
|
+ # _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
|
|
|
|
|
|
# imgs, targets = next(iter(data_loader))
|
|
|
#
|
|
@@ -645,83 +736,3 @@ if __name__ == '__main__':
|
|
|
# result, losses = model(imgs, targets)
|
|
|
# print(f'result:{result}')
|
|
|
# print(f'pred:{losses}')
|
|
|
-'''
|
|
|
-########### predict#############
|
|
|
-
|
|
|
- img_path=r"I:\wirenet_dateset\images\train\00030078_2.png"
|
|
|
- transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
|
|
|
- img = read_image(img_path)
|
|
|
- img = transforms(img)
|
|
|
-
|
|
|
- img = torch.ones((2, 3, 512, 512))
|
|
|
- # print(f'img shape:{img.shape}')
|
|
|
- model.eval()
|
|
|
- onnx_file_path = "./wirenet.onnx"
|
|
|
-
|
|
|
- # 导出模型为ONNX格式
|
|
|
- # torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'],
|
|
|
- # output_names=['output'])
|
|
|
- # torch.save(model,'./wirenet.pt')
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- # 5. 指定输出的 ONNX 文件名
|
|
|
- # onnx_file_path = "./wirepoint_rcnn.onnx"
|
|
|
-
|
|
|
- # 准备一个示例输入:Mask R-CNN 需要一个图像列表作为输入,每个图像张量的形状应为 [C, H, W]
|
|
|
- img = [torch.ones((3, 800, 800))] # 示例输入图像大小为 800x800,3个通道
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- # 指定输出的 ONNX 文件名
|
|
|
- # onnx_file_path = "./mask_rcnn.onnx"
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- # model_scripted = torch.jit.script(model)
|
|
|
- # torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"],
|
|
|
- # output_names=["output"])
|
|
|
- #
|
|
|
- # print(f"Model has been converted to ONNX and saved to {onnx_file_path}")
|
|
|
-
|
|
|
- pred=model(img)
|
|
|
- #
|
|
|
- print(f'pred:{pred}')
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-################################################## end predict
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-########## traing ###################################
|
|
|
- # imgs, targets = next(iter(data_loader))
|
|
|
-
|
|
|
- # model.train()
|
|
|
- # pred = model(imgs, targets)
|
|
|
-
|
|
|
- # class WrapperModule(torch.nn.Module):
|
|
|
- # def __init__(self, model):
|
|
|
- # super(WrapperModule, self).__init__()
|
|
|
- # self.model = model
|
|
|
- #
|
|
|
- # def forward(self,img, targets):
|
|
|
- # # 在这里处理复杂的输入结构,将其转换为适合追踪的形式
|
|
|
- # return self.model(img,targets)
|
|
|
-
|
|
|
- # torch.save(model.state_dict(),'./wire.pt')
|
|
|
- # 包装原始模型
|
|
|
- # wrapped_model = WrapperModule(model)
|
|
|
- # # model_scripted = torch.jit.trace(wrapped_model,img)
|
|
|
- # writer = SummaryWriter('./')
|
|
|
- # writer.add_graph(wrapped_model, (imgs,targets))
|
|
|
- # writer.close()
|
|
|
-
|
|
|
-
|
|
|
- #
|
|
|
- # print(f'pred:{pred}')
|
|
|
-########## end traing ###################################
|
|
|
- # for imgs,targets in data_loader:
|
|
|
- # print(f'imgs:{imgs}')
|
|
|
- # print(f'targets:{targets}')
|
|
|
-'''
|