RenLiqiang 5 months ago
parent
commit
267dd8a572

+ 2 - 1
models/ins/trainer.py

@@ -80,7 +80,8 @@ def train(model, **kwargs):
     # default parameters
     default_params = {
         'dataset_path': '/path/to/dataset',
-        'num_classes': 10,
+        'num_classes': 2,
+        'num_keypoints': 2,
         'opt': 'adamw',
         'batch_size': 2,
         'epochs': 10,

+ 0 - 0
models/keypoint/__init__.py


+ 78 - 0
models/keypoint/kepointrcnn.py

@@ -0,0 +1,78 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.keypoint.trainer import train_cfg
+
+from tools import utils
+
+
+class KeypointRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=2, num_keypoints=2, transforms=None):
+        super(KeypointRCNNModel, self).__init__()
+        # default_weights is unused below: pretrained weights are not loaded
+        # because num_classes/num_keypoints differ from the COCO defaults.
+        default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+        self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None, num_classes=num_classes,
+                                                                              num_keypoints=num_keypoints,
+                                                                              progress=False)
+        if transforms is None:
+            self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+        # if num_classes != 0:
+        #     self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        # Note: this overrides nn.Module.train(mode), so model.eval(), which
+        # calls self.train(False), cannot be used on this wrapper directly.
+        parameters = read_yaml(cfg)
+        num_classes = parameters['num_classes']
+        num_keypoints = parameters['num_keypoints']
+        # print(f'num_classes:{num_classes}')
+        # self.set_num_classes(num_classes)
+        self.num_keypoints = num_keypoints
+        train_cfg(self.__model, cfg)
+
+    # def set_num_classes(self, num_classes):
+    #     in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+    #     self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+    #
+    #     # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+    #     in_channels = self.__model.roi_heads.keypoint_predictor.
+    #     hidden_layer = 256
+    #     self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
+    #                                                               num_classes=num_classes)
+    #     self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict, strict=strict)
+        # return super().load_state_dict(state_dict, strict)
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    keypoint_model = KeypointRCNNModel(num_keypoints=17)
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    keypoint_model.train(cfg='train.yaml')
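
For context, a minimal inference sketch for the class above (not part of the commit). The checkpoint path and image file are placeholders, and the base-class train method is called directly because KeypointRCNNModel.train(cfg) shadows nn.Module.train(mode):

import torch
from torchvision.io import read_image
from models.keypoint.kepointrcnn import KeypointRCNNModel

model = KeypointRCNNModel(num_classes=2, num_keypoints=2)
model.load_weight('train_results/weights/best.pt')  # placeholder checkpoint path
# model.eval() would dispatch to the overridden train(cfg), so switch the
# wrapped detector to eval mode via the base-class method instead:
torch.nn.Module.train(model, False)

img = read_image('sample.jpg')      # placeholder image, uint8 tensor [3, H, W]
batch = [model.transforms(img)]     # default KeypointRCNN preprocessing
with torch.no_grad():
    pred = model(batch)             # list of dicts: boxes, labels, scores, keypoints
print(pred[0]['keypoints'].shape)   # [N, num_keypoints, 3]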

+ 202 - 0
models/keypoint/keypoint_dataset.py

@@ -0,0 +1,202 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class KeypointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images", dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # dict
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer exist, take as many as there are
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates combined
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of the lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        target["labels"] = torch.stack(labels)
+        # print(f'labels:{target["labels"]}')
+        target["boxes"] = line_boxes(target)
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        keypoints = wire_labels["junc_coords"]
+        keypoints[:, 2] = 1  # mark every junction as visible
+        # keypoints[:, 0] = keypoints[:, 0] / shape[0]
+        # keypoints[:, 1] = keypoints[:, 1] / shape[1]
+        target["keypoints"] = keypoints
+        # call this function in the __getitem__ method
+        validate_keypoints(keypoints, shape[0], shape[1])
+        print(f'keypoints:{target["keypoints"].shape}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[1]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no definite ordering
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (juncs[i] == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            if fn is not None:
+                plt.savefig(fn)  # save before show(), which clears the figure
+            plt.show()
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+
+if __name__ == '__main__':
+    path=r"I:\wirenet_dateset"
+    dataset= KeypointDataset(dataset_path=path, dataset_type='train')
+    dataset.show(0)
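
A hedged sketch of wiring this dataset into a detection-style loader, mirroring what models/keypoint/trainer.py does below; the dataset path is the author's, and utils.collate_fn (as in the torchvision reference scripts) zips images and target dicts into parallel tuples:

import torch
from models.keypoint.keypoint_dataset import KeypointDataset
from tools import utils

dataset = KeypointDataset(dataset_path=r'I:\wirenet_dateset',
                          dataset_type='train', target_type='pixel')
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True,
                                     collate_fn=utils.collate_fn)
images, targets = next(iter(loader))  # tuple of image tensors, tuple of target dicts
print(targets[0]['boxes'].shape, targets[0]['keypoints'].shape)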

+ 32 - 0
models/keypoint/train.yaml

@@ -0,0 +1,32 @@
+
+
+dataset_path: I:/wirenet_dateset
+
+#train parameters
+num_classes: 2
+num_keypoints: 17
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: pixel
+enable_logs: True
+augmentation: False
+checkpoint: null   # YAML null; a bare 'None' would be read as the string "None"
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+
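
These keys are consumed by train_cfg in trainer.py below, which reads the file and expands it into keyword arguments; a minimal sketch of that flow, assuming read_yaml returns a plain dict:

from models.config.config_tool import read_yaml

params = read_yaml('models/keypoint/train.yaml')
print(params['num_keypoints'])  # 17
# train_cfg(model, cfg) forwards these as train(model, **params), so every
# key must match an entry of default_params in trainer.py or train() raises ValueError.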

+ 212 - 0
models/keypoint/trainer.py

@@ -0,0 +1,212 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.keypoint.keypoint_dataset import KeypointDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        # print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            print(f'loss_dict:{loss_dict}')
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+def train(model, **kwargs):
+    # default parameters
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints': 2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint': None
+    }
+    # update defaults with the supplied keyword arguments
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # select the device
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_path, 'weights')
+    tb_path = os.path.join(train_result_path, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = KeypointDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_path, 'img_results')
+    os.makedirs(wts_path, exist_ok=True)
+    os.makedirs(img_results_path, exist_ok=True)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+    }
+    # update defaults with the supplied keyword arguments
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar('loss_classifier', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar('loss_box_reg', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    # Keypoint R-CNN reports loss_keypoint, not loss_mask
+    writer.add_scalar('loss_keypoint', metric_logger.meters['loss_keypoint'].global_avg, epoch)
+    writer.add_scalar('loss_objectness', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar('loss_rpn_box_reg', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar('train loss', metric_logger.meters['loss'].global_avg, epoch)
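
For reference, a standalone sketch of the first-epoch warmup built in train_one_epoch above: LinearLR ramps the learning rate from lr/1000 up to lr over the first 1000 iterations (dummy parameter and optimizer purely for illustration):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=0.005)
sched = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0 / 1000, total_iters=1000)

lrs = []
for _ in range(1001):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()    # no gradients here; the step only satisfies the scheduler ordering
    sched.step()
print(lrs[0], lrs[500], lrs[1000])  # ~5e-6, ~2.5e-3, 5e-3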

+ 4 - 0
models/wirenet/head.py

@@ -992,6 +992,7 @@ class RoIHeads(nn.Module):
             loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
         else:
+            print('result append boxes!!!')
             boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
             num_images = len(boxes)
             for i in range(num_images):
@@ -1047,6 +1048,7 @@ class RoIHeads(nn.Module):
         # keep none checks in if conditional so torchscript will conditionally
         # compile each branch
         if self.has_keypoint():
+
             keypoint_proposals = [p["boxes"] for p in result]
             if self.training:
                 # during training, only focus on positive boxes
@@ -1101,6 +1103,7 @@ class RoIHeads(nn.Module):
             losses.update(loss_keypoint)
 
         if self.has_wirepoint():
+            # print(f'result:{result}')
             wirepoint_proposals = [p["boxes"] for p in result]
             if self.training:
                 # during training, only focus on positive boxes
@@ -1200,6 +1203,7 @@ class RoIHeads(nn.Module):
 #     return merged_features
 
 def merge_features(features, proposals):
+    print(f'features:{features.shape}')
     def diagnose_input(features, proposals):
         """诊断输入数据"""
         print("Input Diagnostics:")

+ 1 - 1
models/wirenet/wirenet.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: D:/python/PycharmProjects/data
+  datadir: I:/wirenet_dateset
   resume_from:
   num_workers: 4
   tensorboard_port: 0

+ 15 - 15
models/wirenet/wirepoint_rcnn.py

@@ -811,26 +811,26 @@ if __name__ == '__main__':
         print(f"epoch:{epoch}")
         model.train()
 
-        # for imgs, targets in data_loader_train:
-        # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-        # loss = _loss(losses)
-        # print(loss)
+        for imgs, targets in data_loader_train:
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            loss = _loss(losses)
+            print(loss)
         # optimizer.zero_grad()
         # loss.backward()
         # optimizer.step()
         # writer_loss(writer, losses, epoch)
 
-        model.eval()
-        with torch.no_grad():
-            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                pred = model(move_to_device(imgs, device))
-                # print(f"pred:{pred}")
-
-                if batch_idx == 0:
-                    result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
-                    print(imgs[0].shape)  # [3,512,512]
-                    # imshow(imgs[0].permute(1, 2, 0))  # change to (512, 512, 3)
-                    _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
+        # model.eval()
+        # with torch.no_grad():
+        #     for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+        #         pred = model(move_to_device(imgs, device))
+        #         # print(f"pred:{pred}")
+        #
+        #         if batch_idx == 0:
+        #             result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
+        #             print(imgs[0].shape)  # [3,512,512]
+        #             # imshow(imgs[0].permute(1, 2, 0))  # change to (512, 512, 3)
+        #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
                     # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
 
 # imgs, targets = next(iter(data_loader))