
Add WirePredictor

RenLiqiang 5 months ago
parent
commit
93e69bfeeb

+ 6 - 0
models/keypoint/kepointrcnn.py

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 import torchvision
 from torch import nn
+from torch.nn.modules.module import T
 from torchvision.io import read_image
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
@@ -29,6 +30,7 @@ class KeypointRCNNModel(nn.Module):
         self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None,num_classes=num_classes,
                                                                               num_keypoints=num_keypoints,
                                                                               progress=False)
+        
         if transforms is None:
             self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
         # if num_classes != 0:
@@ -69,6 +71,10 @@ class KeypointRCNNModel(nn.Module):
         self.__model.load_state_dict(state_dict)
         # return super().load_state_dict(state_dict, strict)
 
+    def eval(self: T) -> T:
+        self.__model.eval()
+        return self  # return self to keep nn.Module's chaining contract
+
 
 if __name__ == '__main__':
     # ins_model = MaskRCNNModel(num_classes=5)

+ 4 - 4
models/keypoint/keypoint_dataset.py

@@ -38,8 +38,8 @@ class KeypointDataset(BaseDataset):
         self.data_path = dataset_path
         print(f'data_path:{dataset_path}')
         self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
         self.imgs = os.listdir(self.img_path)
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type
@@ -198,6 +198,6 @@ class KeypointDataset(BaseDataset):
 
 
 if __name__ == '__main__':
-    path=r"D:\python\PycharmProjects\data"
+    path=r"I:\datasets\wirenet_1000"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
-    dataset.show(0)
+    dataset.show(7)

+ 77 - 0
models/keypoint/test_predict.py

@@ -0,0 +1,77 @@
+import time
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.io import decode_image, read_image
+import torchvision.transforms.functional as F
+from torchvision.utils import draw_keypoints, draw_bounding_boxes
+
+from models.keypoint.kepointrcnn import KeypointRCNNModel
+
+
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+
+# img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
+img_path = r"I:\datasets\wirenet_1000\images\train\00031644_0.png"
+img_int = read_image(img_path)
+
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+
+transforms = weights.transforms()
+
+print(f'transforms:{transforms}')
+img = transforms(img_int)
+
+# person_float = transforms(img)  # would double-apply the transforms; unused
+
+model = KeypointRCNNModel(num_keypoints=2)
+
+print('loading pretrained weights...')
+model.load_weight('./train_results/20241226_171710/weights/best.pt')
+print('weights loaded!')
+
+# model.to(device)
+model.eval()
+# model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
+# model = model.eval()
+t1 = time.time()
+# img = torch.ones((3, 3, 512, 512))
+
+print(f't1:{t1}')
+outputs = model([img])
+t2 = time.time()
+print(f'time:{t2 - t1}')
+# print(f'outputs:{outputs}')
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+boxes = outputs[0]['boxes']
+print(f'kpts:{kpts}')
+print(f'scores:{scores}')
+
+detect_threshold = 0.001
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+# print(f'keypoints:{keypoints}')
+
+
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
+res_box=draw_bounding_boxes(img_int,boxes)
+show(res_box)
+plt.show()

+ 1 - 1
models/keypoint/train.yaml

@@ -1,6 +1,6 @@
 
 
-dataset_path: I:/wirenet_dateset
+dataset_path: I:\datasets\wirenet_1000
 
 #train parameters
 num_classes: 2

+ 21 - 11
models/keypoint/trainer.py

@@ -88,7 +88,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, wr
         metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
 
-    return metric_logger
+    return metric_logger, total_train_loss
 
 
 cmap = plt.get_cmap("jet")
@@ -102,27 +102,29 @@ def c(x):
 
 
 def show_line(img, pred, epoch, writer):
-    im = img.permute(1, 2, 0)
+    im = img.permute(1, 2, 0)   # [512, 512, 3]
     writer.add_image("ori", im, epoch, dataformats="HWC")
 
     boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                       colors="yellow", width=1)
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    print(f'box:{pred["boxes"][:5,:]}')
+    print(f'line:{pred["keypoints"][:5,:]}')
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
     # H = pred[1]['wires']
     lines = pred["keypoints"].detach().cpu().numpy()
     scores = pred["keypoints_scores"].detach().cpu().numpy()
-    for i in range(1, len(lines)):
-        if (lines[i] == lines[0]).all():
-            lines = lines[:i]
-            scores = scores[:i]
-            break
+    # for i in range(1, len(lines)):
+    #     if (lines[i] == lines[0]).all():
+    #         lines = lines[:i]
+    #         scores = scores[:i]
+    #         break
 
     # postprocess lines to remove overlapped lines
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
-    print(f'nscores:{nscores}')
+    # print(f'nscores:{nscores}')
 
     for i, t in enumerate([0.5]):
         plt.gca().set_axis_off()
@@ -184,6 +186,10 @@ def evaluate(model, data_loader, epoch, writer, device):
         if batch_idx == 0:
             show_line(images[0], outputs[0], epoch, writer)
 
+        # print(f'outputs:{outputs}')
+        # print(f'outputs[0]:{outputs[0]}')
+
+
     #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
     #     model_time = time.time() - model_time
     #
@@ -278,7 +284,7 @@ def train(model, **kwargs):
                                    dataset_type='val')
 
     train_sampler = torch.utils.data.RandomSampler(dataset)
-    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    test_sampler = torch.utils.data.RandomSampler(dataset_test)
     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
     train_collate_fn = utils.collate_fn
     data_loader = torch.utils.data.DataLoader(
@@ -288,6 +294,7 @@ def train(model, **kwargs):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
     )
 
+
     img_results_path = os.path.join(train_result_ptath, 'img_results')
     if os.path.exists(train_result_ptath):
         pass
@@ -303,13 +310,13 @@ def train(model, **kwargs):
     total_train_loss = 0.0
 
     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
+        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):
             os.remove(f'{wts_path}/last.pt')
         torch.save(model.state_dict(), f'{wts_path}/last.pt')
-        write_metric_logs(epoch, metric_logger, writer)
+        # write_metric_logs(epoch, metric_logger, writer)
         if epoch == 0:
             best_loss = losses;
         if best_loss >= losses:
@@ -319,6 +326,9 @@ def train(model, **kwargs):
             torch.save(model.state_dict(), f'{wts_path}/best.pt')
 
         evaluate(model, data_loader_test, epoch, writer, device=device)
+        avg_train_loss = total_train_loss / ((epoch + 1) * len(data_loader))  # total_train_loss accumulates across epochs
+
+        writer.add_scalar('Loss/train', avg_train_loss, epoch)
 
 
 def get_transform(is_train, **kwargs):
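
Aside: the call site in train() now unpacks metric_logger and total_train_loss, but the accumulation itself sits outside these hunks. A minimal self-contained sketch of the assumed contract (toy model and data; train_one_epoch_sketch is an illustrative name, not the repo's function):

import torch
import torch.nn.functional as F

def train_one_epoch_sketch(model, optimizer, data_loader, total_train_loss):
    # assumed pattern: add each batch's scalar loss to the running sum and
    # return it so the caller can average and log it per epoch
    for x, y in data_loader:
        loss = F.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    return total_train_loss

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]
total = train_one_epoch_sketch(model, opt, data, 0.0)
print('Loss/train =', total / len(data))  # mirrors avg_train_loss in the hunk above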

+ 70 - 0
models/wirenet2/WirePredictor.py

@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class WirePredictor(nn.Module):
+    def __init__(self, in_channels=4, out_channels=1, init_features=32):
+        super(WirePredictor, self).__init__()
+
+        features = init_features
+        self.encoder1 = self._block(in_channels, features, name="enc1")
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.encoder2 = self._block(features, features * 2, name="enc2")
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.bottleneck = self._block(features * 2, features * 4, name="bottleneck")
+
+        self.upconv2 = nn.ConvTranspose2d(
+            features * 4, features * 2, kernel_size=2, stride=2
+        )
+        self.decoder2 = self._block((features * 2) * 2, features * 2, name="dec2")
+        self.upconv1 = nn.ConvTranspose2d(
+            features * 2, features, kernel_size=2, stride=2
+        )
+        self.decoder1 = self._block(features * 2, features, name="dec1")
+
+        # Output for line segment mask
+        self.conv_mask = nn.Conv2d(
+            in_channels=features, out_channels=out_channels, kernel_size=1
+        )
+
+        # Output for normal vectors (2 channels for x and y components)
+        self.conv_normals = nn.Conv2d(
+            in_channels=features, out_channels=2, kernel_size=1
+        )
+
+    def forward(self, x):
+        enc1 = self.encoder1(x)
+        enc2 = self.encoder2(self.pool1(enc1))
+
+        bottleneck = self.bottleneck(self.pool2(enc2))
+
+        dec2 = self.upconv2(bottleneck)
+        dec2 = torch.cat((dec2, enc2), dim=1)
+        dec2 = self.decoder2(dec2)
+        dec1 = self.upconv1(dec2)
+        dec1 = torch.cat((dec1, enc1), dim=1)
+        dec1 = self.decoder1(dec1)
+
+        mask = torch.sigmoid(self.conv_mask(dec1))
+        normals = torch.tanh(self.conv_normals(dec1))  # Normalize to [-1, 1]
+
+        return mask, normals
+
+    def _block(self, in_channels, features, name):
+        return nn.Sequential(
+            nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(num_features=features),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(num_features=features),
+            nn.ReLU(inplace=True),
+        )
+
+# Quick smoke test of the model
+if __name__ == "__main__":
+    model = WirePredictor()
+    x = torch.randn((1, 4, 128, 128))  # 128x128 input including normal-vector channels
+    with torch.no_grad():
+        output_mask, output_normals = model(x)
+        print(output_mask.shape, output_normals.shape)  # expect (1, 1, 128, 128) and (1, 2, 128, 128)
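
For context, a minimal sketch of how the two heads might be trained (not part of this commit; the BCE-for-mask and mask-weighted-MSE-for-normals losses are assumptions):

import torch
import torch.nn.functional as F

from models.wirenet2.WirePredictor import WirePredictor

def wire_loss(pred_mask, pred_normals, gt_mask, gt_normals):
    # BCE on the sigmoid mask output; MSE on the normal field, weighted by the
    # ground-truth mask so it only counts on wire pixels (assumed convention)
    bce = F.binary_cross_entropy(pred_mask, gt_mask)
    mse = ((pred_normals - gt_normals) ** 2 * gt_mask).sum() / gt_mask.sum().clamp(min=1)
    return bce + mse

model = WirePredictor()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(2, 4, 128, 128)                          # image + normal-vector channels
gt_mask = torch.randint(0, 2, (2, 1, 128, 128)).float()  # dummy line-segment mask
gt_normals = torch.randn(2, 2, 128, 128).tanh()          # dummy normals in [-1, 1]
pred_mask, pred_normals = model(x)
loss = wire_loss(pred_mask, pred_normals, gt_mask, gt_normals)
opt.zero_grad(); loss.backward(); opt.step()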