|
@@ -1,6 +1,7 @@
|
|
|
import os
|
|
|
from typing import Optional, Any
|
|
|
|
|
|
+import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from tensorboardX import SummaryWriter
|
|
@@ -27,9 +28,17 @@ from models.wirenet.wirepoint_dataset import WirePointDataset
|
|
|
from tools import utils
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import matplotlib as mpl
|
|
|
+from skimage import io
|
|
|
+import os.path as osp
|
|
|
+
|
|
|
|
|
|
FEATURE_DIM = 8
|
|
|
|
|
|
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
+print(f"Using device: {device}")
|
|
|
+
|
|
|
|
|
|
def non_maximum_suppression(a):
|
|
|
ap = F.max_pool2d(a, 3, stride=1, padding=1)
|
|
@@ -124,7 +133,7 @@ class WirepointRCNN(FasterRCNN):
|
|
|
|
|
|
if wirepoint_head is None:
|
|
|
keypoint_layers = tuple(512 for _ in range(8))
|
|
|
- print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
|
|
|
+
|
|
|
wirepoint_head = WirepointHead(out_channels, keypoint_layers)
|
|
|
|
|
|
if wirepoint_predictor is None:
|
|
@@ -291,7 +300,7 @@ class WirepointPredictor(nn.Module):
|
|
|
|
|
|
|
|
|
batch, channel, row, col = inputs.shape
|
|
|
- print(f'outputs:{inputs.shape}')
|
|
|
+
|
|
|
|
|
|
|
|
|
if targets is not None:
|
|
@@ -316,18 +325,18 @@ class WirepointPredictor(nn.Module):
|
|
|
else:
|
|
|
self.training = False
|
|
|
t = {
|
|
|
- "junc_coords": torch.zeros(1, 2),
|
|
|
- "jtyp": torch.zeros(1, dtype=torch.uint8),
|
|
|
- "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
- "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
+ "junc_coords": torch.zeros(1, 2).to(device),
|
|
|
+ "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
+ "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
}
|
|
|
wires_targets = [t for b in range(inputs.size(0))]
|
|
|
|
|
|
wires_meta = {
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
}
|
|
|
|
|
|
T = wires_meta.copy()
|
|
@@ -399,7 +408,6 @@ class WirepointPredictor(nn.Module):
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
|
x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
- print(f"pstest{ps}")
|
|
|
x = torch.cat([x, f], 1)
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
x = self.fc2(x).flatten()
|
|
@@ -443,6 +451,9 @@ class WirepointPredictor(nn.Module):
|
|
|
xy_ = xy[..., None, :]
|
|
|
del x, y, index
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
dist = torch.sum((xy_ - junc) ** 2, -1)
|
|
|
cost, match = torch.min(dist, -1)
|
|
@@ -555,6 +566,72 @@ def _loss(losses):
|
|
|
return total_loss
|
|
|
|
|
|
|
|
|
+cmap = plt.get_cmap("jet")
|
|
|
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
+sm.set_array([])
|
|
|
+
|
|
|
+
|
|
|
+def c(x):
|
|
|
+ return sm.to_rgba(x)
|
|
|
+
|
|
|
+
|
|
|
+def imshow(im):
|
|
|
+ plt.close()
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.colorbar(sm, fraction=0.046)
|
|
|
+ plt.xlim([0, im.shape[0]])
|
|
|
+ plt.ylim([im.shape[0], 0])
|
|
|
+
|
|
|
+
|
|
|
+def _plot_samples(self, i, index, result, targets, prefix):
|
|
|
+ fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
|
|
|
+ img = io.imread(fn)
|
|
|
+ imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
|
|
|
+
|
|
|
+ def draw_vecl(lines, sline, juncs, junts, fn):
|
|
|
+ imshow(img)
|
|
|
+ if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+ for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+ break
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
|
|
|
+ if not (juncs[0] == 0).all():
|
|
|
+ for i, j in enumerate(juncs):
|
|
|
+ if i > 0 and (i == juncs[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
|
|
|
+ if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
|
|
|
+ for i, j in enumerate(junts):
|
|
|
+ if i > 0 and (i == junts[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
|
|
|
+ plt.savefig(fn), plt.close()
|
|
|
+
|
|
|
+ junc = targets[i]["junc"].cpu().numpy() * 4
|
|
|
+ jtyp = targets[i]["jtyp"].cpu().numpy()
|
|
|
+ juncs = junc[jtyp == 0]
|
|
|
+ junts = junc[jtyp == 1]
|
|
|
+ rjuncs = result["juncs"][i].cpu().numpy() * 4
|
|
|
+ rjunts = None
|
|
|
+ if "junts" in result:
|
|
|
+ rjunts = result["junts"][i].cpu().numpy() * 4
|
|
|
+
|
|
|
+ lpre = targets[i]["lpre"].cpu().numpy() * 4
|
|
|
+ vecl_target = targets[i]["lpre_label"].cpu().numpy()
|
|
|
+ vecl_result = result["lines"][i].cpu().numpy() * 4
|
|
|
+ score = result["score"][i].cpu().numpy()
|
|
|
+ lpre = lpre[vecl_target == 1]
|
|
|
+
|
|
|
+ draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
|
|
|
+ draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
|
|
|
+
|
|
|
+ img = cv2.imread(f"{prefix}_vecl_a.jpg")
|
|
|
+ img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
|
|
|
+ self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
cfg = 'wirenet.yaml'
|
|
|
cfg = read_yaml(cfg)
|
|
@@ -562,15 +639,15 @@ if __name__ == '__main__':
|
|
|
print(cfg['model']['n_dyn_negl'])
|
|
|
|
|
|
|
|
|
- if torch.cuda.is_available():
|
|
|
- device_name = "cuda"
|
|
|
- torch.backends.cudnn.deterministic = True
|
|
|
- torch.cuda.manual_seed(0)
|
|
|
- print("Let's use", torch.cuda.device_count(), "GPU(s)!")
|
|
|
- else:
|
|
|
- print("CUDA is not available")
|
|
|
-
|
|
|
- device = torch.device(device_name)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
|
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
@@ -607,17 +684,27 @@ if __name__ == '__main__':
|
|
|
return data
|
|
|
|
|
|
|
|
|
- def writer_loss(writer, losses):
|
|
|
-
|
|
|
- for key, value in losses.items():
|
|
|
- if isinstance(value, dict):
|
|
|
- for subkey, subvalue in value['losses'][0].items():
|
|
|
- writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch)
|
|
|
- else:
|
|
|
- writer.add_scalar(key, value.item(), epoch)
|
|
|
+ def writer_loss(writer, losses, epoch):
|
|
|
+
|
|
|
+ try:
|
|
|
+ for key, value in losses.items():
|
|
|
+ if key == 'loss_wirepoint':
|
|
|
+
|
|
|
+ for subdict in losses['loss_wirepoint']['losses']:
|
|
|
+ for subkey, subvalue in subdict.items():
|
|
|
+
|
|
|
+ writer.add_scalar(f'loss_wirepoint/{subkey}',
|
|
|
+ subvalue.item() if hasattr(subvalue, 'item') else subvalue,
|
|
|
+ epoch)
|
|
|
+ elif isinstance(value, torch.Tensor):
|
|
|
+
|
|
|
+ writer.add_scalar(key, value.item(), epoch)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"TensorBoard logging error: {e}")
|
|
|
|
|
|
|
|
|
for epoch in range(cfg['optim']['max_epoch']):
|
|
|
+ print(f"epoch:{epoch}")
|
|
|
model.train()
|
|
|
|
|
|
for imgs, targets in data_loader_train:
|
|
@@ -627,14 +714,18 @@ if __name__ == '__main__':
|
|
|
optimizer.zero_grad()
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
- writer_loss(writer, losses)
|
|
|
+ writer_loss(writer, losses, epoch)
|
|
|
|
|
|
model.eval()
|
|
|
with torch.no_grad():
|
|
|
- for imgs, targets in data_loader_val:
|
|
|
- print(111)
|
|
|
+ for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
pred = model(move_to_device(imgs, device))
|
|
|
- print(f"pred:{pred}")
|
|
|
+ print(f"perd:{pred}")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
|
|
|
|
|
@@ -645,83 +736,3 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
-'''
|
|
|
-
|
|
|
-
|
|
|
- img_path=r"I:\wirenet_dateset\images\train\00030078_2.png"
|
|
|
- transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
|
|
|
- img = read_image(img_path)
|
|
|
- img = transforms(img)
|
|
|
-
|
|
|
- img = torch.ones((2, 3, 512, 512))
|
|
|
-
|
|
|
- model.eval()
|
|
|
- onnx_file_path = "./wirenet.onnx"
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- img = [torch.ones((3, 800, 800))]
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- pred=model(img)
|
|
|
-
|
|
|
- print(f'pred:{pred}')
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-'''
|