浏览代码

WireDataset

xue50 5 月之前
父节点
当前提交
e122a941f2
共有 61 个文件被更改,包括 126 次插入115 次删除
  1. 3 3
      models/wirenet/head.py
  2. 二进制
      models/wirenet/logs/events.out.tfevents.1733809710.Li.70964.0
  3. 二进制
      models/wirenet/logs/events.out.tfevents.1733809950.Li.136664.0
  4. 二进制
      models/wirenet/logs/events.out.tfevents.1733809982.Li.61568.0
  5. 二进制
      models/wirenet/logs/events.out.tfevents.1733810047.Li.97044.0
  6. 二进制
      models/wirenet/logs/events.out.tfevents.1733810063.Li.102940.0
  7. 二进制
      models/wirenet/logs/events.out.tfevents.1733810109.Li.100936.0
  8. 二进制
      models/wirenet/logs/events.out.tfevents.1733810129.Li.86956.0
  9. 二进制
      models/wirenet/logs/events.out.tfevents.1733813209.Li.103040.0
  10. 二进制
      models/wirenet/logs/events.out.tfevents.1733813254.Li.16020.0
  11. 二进制
      models/wirenet/logs/events.out.tfevents.1733813396.Li.35400.0
  12. 二进制
      models/wirenet/logs/events.out.tfevents.1733813545.Li.136468.0
  13. 二进制
      models/wirenet/logs/events.out.tfevents.1733813995.Li.4024.0
  14. 二进制
      models/wirenet/logs/events.out.tfevents.1733814102.Li.32552.0
  15. 二进制
      models/wirenet/logs/events.out.tfevents.1733814114.Li.76920.0
  16. 二进制
      models/wirenet/logs/events.out.tfevents.1733814164.Li.69088.0
  17. 二进制
      models/wirenet/logs/events.out.tfevents.1733814248.Li.25404.0
  18. 二进制
      models/wirenet/logs/events.out.tfevents.1733814278.Li.56252.0
  19. 二进制
      models/wirenet/logs/events.out.tfevents.1733814289.Li.19808.0
  20. 二进制
      models/wirenet/logs/events.out.tfevents.1733814338.Li.62100.0
  21. 二进制
      models/wirenet/logs/events.out.tfevents.1733814354.Li.110740.0
  22. 二进制
      models/wirenet/logs/events.out.tfevents.1733814411.Li.115852.0
  23. 二进制
      models/wirenet/logs/events.out.tfevents.1733814426.Li.53768.0
  24. 二进制
      models/wirenet/logs/events.out.tfevents.1733814454.Li.26444.0
  25. 二进制
      models/wirenet/logs/events.out.tfevents.1733814473.Li.116808.0
  26. 二进制
      models/wirenet/logs/events.out.tfevents.1733814601.Li.116640.0
  27. 二进制
      models/wirenet/logs/events.out.tfevents.1733814690.Li.43088.0
  28. 二进制
      models/wirenet/logs/events.out.tfevents.1733814808.Li.94012.0
  29. 二进制
      models/wirenet/logs/events.out.tfevents.1733814839.Li.66244.0
  30. 二进制
      models/wirenet/logs/events.out.tfevents.1733814851.Li.112612.0
  31. 二进制
      models/wirenet/logs/events.out.tfevents.1733814885.Li.90468.0
  32. 二进制
      models/wirenet/logs/events.out.tfevents.1733814902.Li.122648.0
  33. 二进制
      models/wirenet/logs/events.out.tfevents.1733815046.Li.61904.0
  34. 二进制
      models/wirenet/logs/events.out.tfevents.1733815059.Li.133128.0
  35. 二进制
      models/wirenet/logs/events.out.tfevents.1733815249.Li.20304.0
  36. 二进制
      models/wirenet/logs/events.out.tfevents.1733815303.Li.70972.0
  37. 二进制
      models/wirenet/logs/events.out.tfevents.1733815325.Li.2968.0
  38. 二进制
      models/wirenet/logs/events.out.tfevents.1733815358.Li.26132.0
  39. 二进制
      models/wirenet/logs/events.out.tfevents.1733815384.Li.76524.0
  40. 二进制
      models/wirenet/logs/events.out.tfevents.1733815400.Li.69232.0
  41. 二进制
      models/wirenet/logs/events.out.tfevents.1733815444.Li.75020.0
  42. 二进制
      models/wirenet/logs/events.out.tfevents.1733815624.Li.40676.0
  43. 二进制
      models/wirenet/logs/events.out.tfevents.1733815646.Li.50320.0
  44. 二进制
      models/wirenet/logs/events.out.tfevents.1733815663.Li.71412.0
  45. 二进制
      models/wirenet/logs/events.out.tfevents.1733815675.Li.69040.0
  46. 二进制
      models/wirenet/logs/events.out.tfevents.1733815695.Li.36932.0
  47. 二进制
      models/wirenet/logs/events.out.tfevents.1733815714.Li.86692.0
  48. 二进制
      models/wirenet/logs/events.out.tfevents.1733815777.Li.89856.0
  49. 二进制
      models/wirenet/logs/events.out.tfevents.1733815795.Li.12740.0
  50. 二进制
      models/wirenet/logs/events.out.tfevents.1733815876.Li.13780.0
  51. 二进制
      models/wirenet/logs/events.out.tfevents.1733815943.Li.46524.0
  52. 二进制
      models/wirenet/logs/events.out.tfevents.1733815957.Li.44540.0
  53. 二进制
      models/wirenet/logs/events.out.tfevents.1733815969.Li.101524.0
  54. 二进制
      models/wirenet/logs/events.out.tfevents.1733816053.Li.68472.0
  55. 二进制
      models/wirenet/logs/events.out.tfevents.1733816083.Li.73196.0
  56. 二进制
      models/wirenet/logs/events.out.tfevents.1733816092.Li.47628.0
  57. 二进制
      models/wirenet/logs/events.out.tfevents.1733816473.Li.113256.0
  58. 二进制
      models/wirenet/logs/events.out.tfevents.1733818550.Li.118296.0
  59. 二进制
      models/wirenet/logs/events.out.tfevents.1733818566.Li.79004.0
  60. 二进制
      models/wirenet/logs/events.out.tfevents.1733818796.Li.30344.0
  61. 123 112
      models/wirenet/wirepoint_rcnn.py

+ 3 - 3
models/wirenet/head.py

@@ -163,7 +163,6 @@ def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
 def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
     result = {}
     result["wires"] = {}
-    print(f"ps1:{ps}")
     p = torch.cat(ps)
     s = torch.sigmoid(input)
     b = s > 0.5
@@ -1129,7 +1128,7 @@ class RoIHeads(nn.Module):
             outputs = merge_features(outputs, wirepoint_proposals)
             wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
 
-            # print(f'outpust:{outputs.shape}')
+            print(f'outpust:{outputs.shape}')
 
             wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
             x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
@@ -1148,6 +1147,7 @@ class RoIHeads(nn.Module):
             else:
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
                 result.append(pred)
+                loss_wirepoint = {}
 
             # tmp = wirepoint_features[0][0]
             # plt.imshow(tmp.detach().numpy())
@@ -1231,7 +1231,7 @@ def merge_features(features, proposals):
 
     try:
         # 诊断输入(可选)
-        diagnose_input(features, proposals)
+        # diagnose_input(features, proposals)
 
         # 验证输入
         validate_inputs(features, proposals)

二进制
models/wirenet/logs/events.out.tfevents.1733809710.Li.70964.0


二进制
models/wirenet/logs/events.out.tfevents.1733809950.Li.136664.0


二进制
models/wirenet/logs/events.out.tfevents.1733809982.Li.61568.0


二进制
models/wirenet/logs/events.out.tfevents.1733810047.Li.97044.0


二进制
models/wirenet/logs/events.out.tfevents.1733810063.Li.102940.0


二进制
models/wirenet/logs/events.out.tfevents.1733810109.Li.100936.0


二进制
models/wirenet/logs/events.out.tfevents.1733810129.Li.86956.0


二进制
models/wirenet/logs/events.out.tfevents.1733813209.Li.103040.0


二进制
models/wirenet/logs/events.out.tfevents.1733813254.Li.16020.0


二进制
models/wirenet/logs/events.out.tfevents.1733813396.Li.35400.0


二进制
models/wirenet/logs/events.out.tfevents.1733813545.Li.136468.0


二进制
models/wirenet/logs/events.out.tfevents.1733813995.Li.4024.0


二进制
models/wirenet/logs/events.out.tfevents.1733814102.Li.32552.0


二进制
models/wirenet/logs/events.out.tfevents.1733814114.Li.76920.0


二进制
models/wirenet/logs/events.out.tfevents.1733814164.Li.69088.0


二进制
models/wirenet/logs/events.out.tfevents.1733814248.Li.25404.0


二进制
models/wirenet/logs/events.out.tfevents.1733814278.Li.56252.0


二进制
models/wirenet/logs/events.out.tfevents.1733814289.Li.19808.0


二进制
models/wirenet/logs/events.out.tfevents.1733814338.Li.62100.0


二进制
models/wirenet/logs/events.out.tfevents.1733814354.Li.110740.0


二进制
models/wirenet/logs/events.out.tfevents.1733814411.Li.115852.0


二进制
models/wirenet/logs/events.out.tfevents.1733814426.Li.53768.0


二进制
models/wirenet/logs/events.out.tfevents.1733814454.Li.26444.0


二进制
models/wirenet/logs/events.out.tfevents.1733814473.Li.116808.0


二进制
models/wirenet/logs/events.out.tfevents.1733814601.Li.116640.0


二进制
models/wirenet/logs/events.out.tfevents.1733814690.Li.43088.0


二进制
models/wirenet/logs/events.out.tfevents.1733814808.Li.94012.0


二进制
models/wirenet/logs/events.out.tfevents.1733814839.Li.66244.0


二进制
models/wirenet/logs/events.out.tfevents.1733814851.Li.112612.0


二进制
models/wirenet/logs/events.out.tfevents.1733814885.Li.90468.0


二进制
models/wirenet/logs/events.out.tfevents.1733814902.Li.122648.0


二进制
models/wirenet/logs/events.out.tfevents.1733815046.Li.61904.0


二进制
models/wirenet/logs/events.out.tfevents.1733815059.Li.133128.0


二进制
models/wirenet/logs/events.out.tfevents.1733815249.Li.20304.0


二进制
models/wirenet/logs/events.out.tfevents.1733815303.Li.70972.0


二进制
models/wirenet/logs/events.out.tfevents.1733815325.Li.2968.0


二进制
models/wirenet/logs/events.out.tfevents.1733815358.Li.26132.0


二进制
models/wirenet/logs/events.out.tfevents.1733815384.Li.76524.0


二进制
models/wirenet/logs/events.out.tfevents.1733815400.Li.69232.0


二进制
models/wirenet/logs/events.out.tfevents.1733815444.Li.75020.0


二进制
models/wirenet/logs/events.out.tfevents.1733815624.Li.40676.0


二进制
models/wirenet/logs/events.out.tfevents.1733815646.Li.50320.0


二进制
models/wirenet/logs/events.out.tfevents.1733815663.Li.71412.0


二进制
models/wirenet/logs/events.out.tfevents.1733815675.Li.69040.0


二进制
models/wirenet/logs/events.out.tfevents.1733815695.Li.36932.0


二进制
models/wirenet/logs/events.out.tfevents.1733815714.Li.86692.0


二进制
models/wirenet/logs/events.out.tfevents.1733815777.Li.89856.0


二进制
models/wirenet/logs/events.out.tfevents.1733815795.Li.12740.0


二进制
models/wirenet/logs/events.out.tfevents.1733815876.Li.13780.0


二进制
models/wirenet/logs/events.out.tfevents.1733815943.Li.46524.0


二进制
models/wirenet/logs/events.out.tfevents.1733815957.Li.44540.0


二进制
models/wirenet/logs/events.out.tfevents.1733815969.Li.101524.0


二进制
models/wirenet/logs/events.out.tfevents.1733816053.Li.68472.0


二进制
models/wirenet/logs/events.out.tfevents.1733816083.Li.73196.0


二进制
models/wirenet/logs/events.out.tfevents.1733816092.Li.47628.0


二进制
models/wirenet/logs/events.out.tfevents.1733816473.Li.113256.0


二进制
models/wirenet/logs/events.out.tfevents.1733818550.Li.118296.0


二进制
models/wirenet/logs/events.out.tfevents.1733818566.Li.79004.0


二进制
models/wirenet/logs/events.out.tfevents.1733818796.Li.30344.0


+ 123 - 112
models/wirenet/wirepoint_rcnn.py

@@ -1,6 +1,7 @@
 import os
 from typing import Optional, Any
 
+import cv2
 import numpy as np
 import torch
 from tensorboardX import SummaryWriter
@@ -27,9 +28,17 @@ from models.wirenet.wirepoint_dataset import WirePointDataset
 from tools import utils
 
 from torch.utils.tensorboard import SummaryWriter
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from skimage import io
+import os.path as osp
+
 
 FEATURE_DIM = 8
 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using device: {device}")
+
 
 def non_maximum_suppression(a):
     ap = F.max_pool2d(a, 3, stride=1, padding=1)
@@ -124,7 +133,7 @@ class WirepointRCNN(FasterRCNN):
 
         if wirepoint_head is None:
             keypoint_layers = tuple(512 for _ in range(8))
-            print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
+            # print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
             wirepoint_head = WirepointHead(out_channels, keypoint_layers)
 
         if wirepoint_predictor is None:
@@ -291,7 +300,7 @@ class WirepointPredictor(nn.Module):
         #     print(f'out:{out.shape}')
         # outputs=merge_features(outputs,100)
         batch, channel, row, col = inputs.shape
-        print(f'outputs:{inputs.shape}')
+        # print(f'outputs:{inputs.shape}')
         # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
 
         if targets is not None:
@@ -316,18 +325,18 @@ class WirepointPredictor(nn.Module):
         else:
             self.training = False
             t = {
-                "junc_coords": torch.zeros(1, 2),
-                "jtyp": torch.zeros(1, dtype=torch.uint8),
-                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+                "junc_coords": torch.zeros(1, 2).to(device),
+                "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
             }
             wires_targets = [t for b in range(inputs.size(0))]
 
             wires_meta = {
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
             }
 
         T = wires_meta.copy()
@@ -399,7 +408,6 @@ class WirepointPredictor(nn.Module):
         x, y = torch.cat(xs), torch.cat(ys)
         f = torch.cat(fs)
         x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-        print(f"pstest{ps}")
         x = torch.cat([x, f], 1)
         x = x.to(dtype=torch.float32)
         x = self.fc2(x).flatten()
@@ -443,6 +451,9 @@ class WirepointPredictor(nn.Module):
             xy_ = xy[..., None, :]
             del x, y, index
 
+            # print(f"xy_.is_cuda: {xy_.is_cuda}")
+            # print(f"junc.is_cuda: {junc.is_cuda}")
+
             # dist: [N_TYPE, K, N]
             dist = torch.sum((xy_ - junc) ** 2, -1)
             cost, match = torch.min(dist, -1)
@@ -555,6 +566,72 @@ def _loss(losses):
     return total_loss
 
 
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def _plot_samples(self, i, index, result, targets, prefix):
+    fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
+    img = io.imread(fn)
+    imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
+
+    def draw_vecl(lines, sline, juncs, junts, fn):
+        imshow(img)
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+        if not (juncs[0] == 0).all():
+            for i, j in enumerate(juncs):
+                if i > 0 and (i == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+            for i, j in enumerate(junts):
+                if i > 0 and (i == junts[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+        plt.savefig(fn), plt.close()
+
+    junc = targets[i]["junc"].cpu().numpy() * 4
+    jtyp = targets[i]["jtyp"].cpu().numpy()
+    juncs = junc[jtyp == 0]
+    junts = junc[jtyp == 1]
+    rjuncs = result["juncs"][i].cpu().numpy() * 4
+    rjunts = None
+    if "junts" in result:
+        rjunts = result["junts"][i].cpu().numpy() * 4
+
+    lpre = targets[i]["lpre"].cpu().numpy() * 4
+    vecl_target = targets[i]["lpre_label"].cpu().numpy()
+    vecl_result = result["lines"][i].cpu().numpy() * 4
+    score = result["score"][i].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
+    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+
+    img = cv2.imread(f"{prefix}_vecl_a.jpg")
+    img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
+    self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
+
+
 if __name__ == '__main__':
     cfg = 'wirenet.yaml'
     cfg = read_yaml(cfg)
@@ -562,15 +639,15 @@ if __name__ == '__main__':
     print(cfg['model']['n_dyn_negl'])
     # net = WirepointPredictor()
 
-    if torch.cuda.is_available():
-        device_name = "cuda"
-        torch.backends.cudnn.deterministic = True
-        torch.cuda.manual_seed(0)
-        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
-    else:
-        print("CUDA is not available")
-
-    device = torch.device(device_name)
+    # if torch.cuda.is_available():
+    #     device_name = "cuda"
+    #     torch.backends.cudnn.deterministic = True
+    #     torch.cuda.manual_seed(0)
+    #     print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    # else:
+    #     print("CUDA is not available")
+    #
+    # device = torch.device(device_name)
 
     dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
     train_sampler = torch.utils.data.RandomSampler(dataset_train)
@@ -607,17 +684,27 @@ if __name__ == '__main__':
             return data  # 对于非张量类型的数据不做任何改变
 
 
-    def writer_loss(writer, losses):
-        # 记录每个损失项到TensorBoard
-        for key, value in losses.items():
-            if isinstance(value, dict):  # 如果value本身也是一个字典(例如'loss_wirepoint')
-                for subkey, subvalue in value['losses'][0].items():
-                    writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch)
-            else:
-                writer.add_scalar(key, value.item(), epoch)
+    def writer_loss(writer, losses, epoch):
+        # ??????
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    # ?? wirepoint ??????
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            # ?? .item() ?????
+                            writer.add_scalar(f'loss_wirepoint/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    # ????????
+                    writer.add_scalar(key, value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
 
 
     for epoch in range(cfg['optim']['max_epoch']):
+        print(f"epoch:{epoch}")
         model.train()
 
         for imgs, targets in data_loader_train:
@@ -627,14 +714,18 @@ if __name__ == '__main__':
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            writer_loss(writer, losses)
+            writer_loss(writer, losses, epoch)
 
             model.eval()
             with torch.no_grad():
-                for imgs, targets in data_loader_val:
-                    print(111)
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                     pred = model(move_to_device(imgs, device))
-                    print(f"pred:{pred}")
+                    print(f"perd:{pred}")
+
+                # if batch_idx == 0:
+                #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
+                #     H = pred["wires"]
+                #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
 
 # imgs, targets = next(iter(data_loader))
 #
@@ -645,83 +736,3 @@ if __name__ == '__main__':
 # result, losses = model(imgs, targets)
 # print(f'result:{result}')
 # print(f'pred:{losses}')
-'''
-########### predict#############
-
-    img_path=r"I:\wirenet_dateset\images\train\00030078_2.png"
-    transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
-    img = read_image(img_path)
-    img = transforms(img)
-
-    img = torch.ones((2, 3, 512, 512))
-    # print(f'img shape:{img.shape}')
-    model.eval()
-    onnx_file_path = "./wirenet.onnx"
-
-    # 导出模型为ONNX格式
-    # torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'],
-    #                   output_names=['output'])
-    # torch.save(model,'./wirenet.pt')
-
-
-
-    # 5. 指定输出的 ONNX 文件名
-    # onnx_file_path = "./wirepoint_rcnn.onnx"
-
-    # 准备一个示例输入:Mask R-CNN 需要一个图像列表作为输入,每个图像张量的形状应为 [C, H, W]
-    img = [torch.ones((3, 800, 800))]  # 示例输入图像大小为 800x800,3个通道
-
-
-
-    # 指定输出的 ONNX 文件名
-    # onnx_file_path = "./mask_rcnn.onnx"
-
-
-
-    # model_scripted = torch.jit.script(model)
-    # torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"],
-    #                   output_names=["output"])
-    #
-    # print(f"Model has been converted to ONNX and saved to {onnx_file_path}")
-
-    pred=model(img)
-    #
-    print(f'pred:{pred}')
-
-
-
-################################################## end predict
-
-
-
-########## traing ###################################
-    # imgs, targets = next(iter(data_loader))
-
-    # model.train()
-    # pred = model(imgs, targets)
-
-    # class WrapperModule(torch.nn.Module):
-    #     def __init__(self, model):
-    #         super(WrapperModule, self).__init__()
-    #         self.model = model
-    #
-    #     def forward(self,img, targets):
-    #         # 在这里处理复杂的输入结构,将其转换为适合追踪的形式
-    #         return self.model(img,targets)
-
-    # torch.save(model.state_dict(),'./wire.pt')
-    # 包装原始模型
-    # wrapped_model = WrapperModule(model)
-    # # model_scripted = torch.jit.trace(wrapped_model,img)
-    # writer = SummaryWriter('./')
-    # writer.add_graph(wrapped_model, (imgs,targets))
-    # writer.close()
-
-
-    #
-    # print(f'pred:{pred}')
-########## end traing ###################################
-    # for imgs,targets in data_loader:
-    #     print(f'imgs:{imgs}')
-    #     print(f'targets:{targets}')
-'''