|
@@ -0,0 +1,136 @@
|
|
|
+import time
|
|
|
+import skimage
|
|
|
+from models.line_detect.postprocess import show_predict
|
|
|
+import os
|
|
|
+import torch
|
|
|
+from PIL import Image
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import numpy as np
|
|
|
+from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
+from torchvision import transforms
|
|
|
+from rtree import index
|
|
|
+import multiprocessing as mp
|
|
|
+
|
|
|
+
|
|
|
# Worker processes are always started with the 'spawn' method; force=True
# overrides any start method that was already selected.
mp.set_start_method('spawn', force=True)

# Use the GPU when torch reports one as available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
+
|
|
|
def load_best_model(model, save_path, device):
    """Load weights from a checkpoint file into ``model`` if the file exists.

    The checkpoint is expected to be a dict holding the weights under
    ``'model_state_dict'``. The ``'epoch'`` and ``'loss'`` entries are only
    used for the log message, so a checkpoint without them still loads.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints that come
    from a trusted source.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        save_path: Path of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object. When ``save_path`` does not exist the model
        is returned unchanged and a message is printed.
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model

    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Metadata is informational only: do not fail after the weights have
    # already been loaded just because it is missing.
    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss')
    loss_text = f"{loss:.4f}" if loss is not None else 'unknown'
    print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss_text}")
    return model
|
|
|
+
|
|
|
def process_box(box, lines, scores):
    """For each box, keep the longest detected line lying entirely inside it.

    Args:
        box: Iterable of boxes, each indexed as ``b[0]..b[3]``. ``b[0]`` and
            ``b[2]`` bound index 1 of a line endpoint, ``b[1]`` and ``b[3]``
            bound index 0 (presumably ``(x1, y1, x2, y2)`` boxes and
            ``(y, x)`` endpoints -- TODO confirm against the model output).
        lines: Tensor read as ``lines[0, j]``, each entry being a pair of
            2-D endpoints. Coordinates are rescaled by ``/ 128 * 512`` before
            comparison with the boxes.
        scores: Not used by this function; kept so existing callers that pass
            three positional arguments keep working.

    Returns:
        Tuple ``(valid_lines, valid_scores)``. ``valid_lines`` holds one
        rescaled endpoint array per box that contains at least one line of
        non-zero length; ``valid_scores`` holds the Euclidean length of that
        line (the length, not a model confidence). Boxes without a matching
        line contribute nothing to either list.
    """
    boxes = list(box)
    if not boxes:
        return [], []

    # Loop-invariant work: transfer to host memory, rescale and measure every
    # line once instead of once per (box, line) pair.
    scaled = lines[0].cpu().numpy() / 128 * 512
    lengths = [np.linalg.norm(segment[0] - segment[1]) for segment in scaled]

    valid_lines = []
    valid_scores = []
    for b in boxes:
        best_index = None
        max_length = 0.0

        for j, segment in enumerate(scaled):
            # Both endpoints must fall inside the box (bounds inclusive).
            inside = all(
                b[0] <= point[1] <= b[2] and b[1] <= point[0] <= b[3]
                for point in segment
            )
            # Strict '>' keeps the first of several equally long lines and
            # never selects a zero-length line.
            if inside and lengths[j] > max_length:
                best_index = j
                max_length = lengths[j]

        if best_index is not None:
            # Copy so the returned array does not alias the shared buffer.
            valid_lines.append(scaled[best_index].copy())
            valid_scores.append(max_length)

    return valid_lines, valid_scores
|
|
|
+
|
|
|
+
|
|
|
def box_line_optimized_parallel(pred):
    """Attach the best-matching line to each box prediction and drop the rest.

    ``pred[-1]['wires']`` supplies the detected lines and their scores; every
    earlier entry of ``pred`` supplies a ``'boxes'`` tensor. Each box set is
    handed to ``process_box``. Entries for which at least one line was matched
    get two new keys, ``'line'`` and ``'line_score'``, and are returned.

    Args:
        pred: Sequence of prediction dicts as described above. Entries that
            are returned are modified in place.

    Returns:
        List of the entries of ``pred[0:-1]`` that received a line; empty when
        there are no box entries or nothing matched.
    """
    wires = pred[-1]['wires']
    # Workers receive these tensors through pickling. Moving them to host
    # memory first avoids sending device tensors to spawned processes
    # (process_box calls .cpu() on them anyway).
    lines = wires['lines'].detach().cpu()
    scores = wires['score'][0].detach().cpu()

    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]
    if not boxes:
        # mp.Pool(processes=0) raises ValueError, so bail out early.
        return []

    if len(boxes) == 1:
        # A single box set gains nothing from a pool; skip the cost of
        # starting a worker process.
        results = [process_box(boxes[0], lines, scores)]
    else:
        num_processes = min(mp.cpu_count(), len(boxes))
        with mp.Pool(processes=num_processes) as pool:
            results = pool.starmap(
                process_box,
                [(box, lines, scores) for box in boxes],
            )

    filtered_pred = []
    for entry, (valid_lines, valid_scores) in zip(pred, results):
        if valid_lines:
            # Stack into one ndarray first: building a tensor directly from a
            # list of ndarrays takes a slow element-wise path.
            entry['line'] = torch.tensor(np.array(valid_lines))
            entry['line_score'] = torch.tensor(valid_scores)
            filtered_pred.append(entry)

    return filtered_pred
|
|
|
+
|
|
|
+
|
|
|
def predict(pt_path, model, img):
    """Run the detector on one image, match lines to boxes and show the result.

    Args:
        pt_path: Checkpoint path handed to ``load_best_model``.
        model: Network to load the checkpoint into and run.
        img: Either a path string (opened and converted to RGB) or an object
            that ``transforms.ToTensor`` accepts directly.

    Returns:
        None. Timing and shape information is printed; when at least one box
        received a line, the resized image tensor, the matched predictions
        and the start time of the matching step are passed to
        ``show_predict``.
    """
    net = load_best_model(model, pt_path, device)
    net.eval()

    source = Image.open(img).convert("RGB") if isinstance(img, str) else img

    # CHW tensor -> HWC float32 array, resize to 512x512, then back to CHW.
    hwc = transforms.ToTensor()(source).permute(1, 2, 0)
    resized = skimage.transform.resize(hwc.cpu().numpy().astype(np.float32), (512, 512))
    img_tensor = torch.tensor(resized).permute(2, 0, 1)

    with torch.no_grad():
        infer_started = time.time()
        predictions = net([img_tensor.to(device)])
        infer_elapsed = time.time() - infer_started
        print(f'Prediction used: {infer_elapsed:.4f} seconds')

    boxes = predictions[0]['boxes'].shape
    wires = predictions[-1]['wires']
    lines = wires['lines'].shape
    lines_scores = wires['score'].shape
    print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')

    match_started = time.time()
    matched = box_line_optimized_parallel(predictions)
    match_elapsed = time.time() - match_started
    print(f'Matched boxes and lines used: {match_elapsed:.4f} seconds')

    if matched:
        show_predict(img_tensor, matched, match_started)
    else:
        print("No valid predictions found. Skipping visualization.")
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
    # Script entry point: build the network, run one prediction and report
    # the total wall-clock time. The checkpoint and image locations can be
    # overridden on the command line; the defaults are the original
    # hard-coded paths.
    import argparse

    parser = argparse.ArgumentParser(
        description='Run line detection on a single image.',
    )
    parser.add_argument(
        '--weights',
        default=r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth',
        help='path of the model checkpoint to load',
    )
    parser.add_argument(
        '--image',
        default=r'C:\Users\m2337\Desktop\9.jpg',
        help='path of the image to run the prediction on',
    )
    args = parser.parse_args()

    t_start = time.time()
    print(f'Start to predict: {t_start}')
    model = linenet_resnet50_fpn().to(device)
    pt_path = args.weights
    img_path = args.image
    predict(pt_path, model, img_path)
    t_end = time.time()
    print(f'Total prediction time: {t_end - t_start:.4f} seconds')
|