xue50 3 miesięcy temu
rodzic
commit
56e3377bbe

+ 2 - 1
.gitignore

@@ -2,4 +2,5 @@
 *.pt
 runs
 __pycache__
-train_results
+train_results
+/logs

+ 0 - 1
lcnn/models/fasterrcnn_resnet50.py

@@ -111,7 +111,6 @@ class Fasterrcnn_resnet50(nn.Module):
             box_all = self.model(x, target1)
             return outputs, feature_, box_all
 
-
 def fasterrcnn_resnet50(**kwargs):
     model = Fasterrcnn_resnet50(
         num_classes=kwargs.get("num_classes", 5),

+ 1 - 0
lcnn/models/line_vectorizer.py

@@ -124,6 +124,7 @@ class LineVectorizer(nn.Module):
             result["preds"]["lines"] = torch.cat(lines)
             result["preds"]["score"] = torch.cat(score)
             result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+            # print(result)
             result["box"] = result['aaa']
             del result['aaa']
             if len(jcs[i]) > 1:

+ 9 - 5
lcnn/models/multitask_learner.py

@@ -40,14 +40,17 @@ class MultitaskLearner(nn.Module):
     def forward(self, input_dict):
         image = input_dict["image"]
         target_b = input_dict["target_b"]
+        # if input_dict["mode"] == "training":
+        #     outputs, feature, aaa = self.backbone(image, input_dict["mode"], target_b)  # aaa is the loss when training, the boxes when validating
+        #
+        # else:  # Inference mode
+        #     outputs, feature, aaa = self.backbone(image, input_dict["mode"])  # aaa is the loss when training, the boxes when validating
+
         outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # aaa is the loss when training, the boxes when validating
 
         result = {"feature": feature}
+        result["aaa"] = aaa
         batch, channel, row, col = outputs[0].shape
-        # print(f"batch:{batch}")
-        # print(f"channel:{channel}")
-        # print(f"row:{row}")
-        # print(f"col:{col}")
 
         T = input_dict["target"].copy()
         n_jtyp = T["junc_map"].shape[1]
@@ -74,6 +77,7 @@ class MultitaskLearner(nn.Module):
                     "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                 }
                 if input_dict["mode"] == "testing":
+                    # result["aaa"] = aaa
                     return result
 
             L = OrderedDict()
@@ -94,7 +98,7 @@ class MultitaskLearner(nn.Module):
                 L[loss_name].mul_(loss_weight[loss_name])
             losses.append(L)
         result["losses"] = losses
-        result["aaa"] = aaa
+        # result["aaa"] = aaa
         return result
 
 

+ 173 - 0
predict.py

@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+"""Process an image with the trained neural network
+Usage:
+    predict.py [options] <yaml-config> <checkpoint> <images>...
+    predict.py (-h | --help )
+
+Arguments:
+   <yaml-config>                 Path to the yaml hyper-parameter file
+   <checkpoint>                  Path to the checkpoint
+   <images>                      Path to images
+
+Options:
+   -h --help                     Show this screen.
+   -d --devices <devices>        Comma separated GPU devices [default: 0]
+"""
+
+# Run from a terminal:   python ./predict.py -d 0 config/wireframe.yaml <path-to-pretrained-pth> <path-to-image>
+
+import os
+import os.path as osp
+import pprint
+import random
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import skimage.io
+import skimage.transform
+import torch
+import yaml
+from docopt import docopt
+
+import lcnn
+from lcnn.config import C, M
+from lcnn.models.line_vectorizer import LineVectorizer
+from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
+from lcnn.postprocess import postprocess
+from lcnn.utils import recursive_to
+from torchvision.utils import draw_bounding_boxes
+
+# Keyword arguments forwarded to plt.scatter when drawing line endpoints.
+PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+cmap = plt.get_cmap("jet")
+# get_cmap is the Matplotlib function for obtaining a colormap object: it takes the colormap's name and returns the corresponding colormap.
+# Scores are normalised over [0.9, 1.0] before being mapped to a colour.
+norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)   # colour bar for the colormap
+sm.set_array([])
+
+
+def c(x):
+    """Map ``x`` to an RGBA colour using the module-level ``sm`` mappable."""
+    rgba = sm.to_rgba(x)
+    return rgba
+
+
+def main():
+    """Run the trained network on each input image and plot the detected lines.
+
+    Command-line arguments are parsed by docopt from the module docstring:
+    ``<yaml-config>`` selects the hyper-parameter file, ``<checkpoint>`` the
+    saved weights, ``<images>`` the image files to process and ``--devices``
+    the value exported as ``CUDA_VISIBLE_DEVICES``.
+
+    For every image and each score threshold in ``(0.5, 0.95)`` a figure is
+    saved as ``<image path without extension>-<threshold>.svg`` and then
+    shown on screen.
+    """
+    args = docopt(__doc__)
+
+    # Honour the command line rather than hard-coded values; the previous
+    # hard-coded values remain as fallbacks when an argument is empty.
+    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
+    devices = args["--devices"] or "0"
+
+    # Load the hyper-parameters into the global config objects.
+    C.update(C.from_yaml(filename=config_file))
+    M.update(C.model)
+
+    random.seed(0)
+    np.random.seed(0)
+    torch.manual_seed(0)
+
+    # Device selection: the environment variable is set before CUDA is queried.
+    device_name = "cpu"
+    os.environ["CUDA_VISIBLE_DEVICES"] = devices
+
+    if torch.cuda.is_available():
+        device_name = "cuda"
+        torch.backends.cudnn.deterministic = True
+        torch.cuda.manual_seed(0)
+        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    else:
+        print("CUDA is not available")
+
+    device = torch.device(device_name)
+
+    # NOTE(review): torch.load unpickles the file; only load checkpoints
+    # from a trusted source.
+    checkpoint = torch.load(args["<checkpoint>"], map_location=device)
+
+    # Build the backbone and wrap it with the multitask and vectorizer heads.
+    model = lcnn.models.fasterrcnn_resnet50(
+        num_classes=sum(sum(M.head_size, [])),
+    )
+    model = MultitaskLearner(model)
+    model = LineVectorizer(model)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    model = model.to(device)
+    model.eval()
+
+    for imname in args["<images>"]:
+        print(f"Processing {imname}")
+        im = skimage.io.imread(imname)
+        if im.ndim == 2:
+            # Single-channel input: replicate it into three channels.
+            im = np.repeat(im[:, :, None], 3, 2)
+        im = im[:, :, :3]  # keep only the first three channels
+        im_resized = skimage.transform.resize(im, (512, 512)) * 255
+        image = (im_resized - M.image.mean) / M.image.stddev
+        # Channels first, plus a leading batch axis of size 1.
+        image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
+        with torch.no_grad():
+            # "meta" and "target" hold zero-filled placeholders; presumably
+            # the model expects these keys even in testing mode -- TODO confirm
+            # against the model code.
+            input_dict = {
+                "image": image.to(device),
+                "meta": [
+                    {
+                        "junc_coords": torch.zeros(1, 2).to(device),
+                        "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
+                        "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                        "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                    }
+                ],
+                "target": {
+                    "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+                    "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+                },
+                "target_b": None,
+                "mode": "testing",
+            }
+            result = model(input_dict)
+            H = result["preds"]
+
+            # NOTE(review): boxed_image is computed but not used afterwards,
+            # and it is drawn on the normalised tensor rather than the
+            # original pixels -- confirm the intended use.
+            boxed_image = draw_bounding_boxes((image[0] * 255).to(torch.uint8), result["box"][0]["boxes"],
+                                          colors="yellow", width=1)
+
+        # Rescale line coordinates from the 128-unit grid to the image size.
+        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+        scores = H["score"][0].cpu().numpy()
+        # Cut both arrays at the first later row that equals row 0;
+        # presumably such rows are padding -- TODO confirm.
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        # Derive output names from the path without its extension, so that a
+        # non-PNG input is never overwritten by savefig.
+        stem = osp.splitext(imname)[0]
+
+        for t in [0.5, 0.95]:
+            plt.gca().set_axis_off()
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            for (a, b), s in zip(nlines, nscores):
+                if s < t:
+                    continue
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+                plt.scatter(a[1], a[0], **PLTOPTS)
+                plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(im)
+            plt.savefig(f"{stem}-{t:.02f}.svg", bbox_inches="tight")
+            plt.show()
+            plt.close()
+
+
+# Run the prediction pipeline only when this file is executed as a script.
+if __name__ == "__main__":
+    main()