|
|
@@ -1,40 +1,234 @@
|
|
|
-# ??roi_head??????????????
|
|
|
-from torch.utils.data.dataset import T_co
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+import torchvision
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+import tools.transforms as reference_transforms
|
|
|
+from collections import defaultdict
|
|
|
|
|
|
-from models.base.base_dataset import BaseDataset
|
|
|
+from tools import presets
|
|
|
|
|
|
-import glob
|
|
|
import json
|
|
|
-import math
|
|
|
-import os
|
|
|
-import random
|
|
|
-import cv2
|
|
|
-import PIL
|
|
|
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-import matplotlib as mpl
|
|
|
-from torchvision.utils import draw_bounding_boxes
|
|
|
|
|
|
-import numpy as np
|
|
|
-import numpy.linalg as LA
|
|
|
-import torch
|
|
|
-from skimage import io
|
|
|
-from torch.utils.data import Dataset
|
|
|
-from torch.utils.data.dataloader import default_collate
|
|
|
+def get_modules(use_v2):
|
|
|
+ # We need a protected import to avoid the V2 warning in case just V1 is used
|
|
|
+ if use_v2:
|
|
|
+ import torchvision.transforms.v2
|
|
|
+ import torchvision.tv_tensors
|
|
|
+
|
|
|
+ return torchvision.transforms.v2, torchvision.tv_tensors
|
|
|
+ else:
|
|
|
+ return reference_transforms, None
|
|
|
+
|
|
|
+
|
|
|
+class Augmentation:
|
|
|
+ # Note: this transform assumes that the input to forward() are always PIL
|
|
|
+ # images, regardless of the backend parameter.
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ *,
|
|
|
+ data_augmentation,
|
|
|
+ hflip_prob=0.5,
|
|
|
+ mean=(123.0, 117.0, 104.0),
|
|
|
+ backend="pil",
|
|
|
+ use_v2=False,
|
|
|
+ ):
|
|
|
+
|
|
|
+ T, tv_tensors = get_modules(use_v2)
|
|
|
+
|
|
|
+ transforms = []
|
|
|
+ backend = backend.lower()
|
|
|
+ if backend == "tv_tensor":
|
|
|
+ transforms.append(T.ToImage())
|
|
|
+ elif backend == "tensor":
|
|
|
+ transforms.append(T.PILToTensor())
|
|
|
+ elif backend != "pil":
|
|
|
+ raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
|
|
|
+
|
|
|
+ if data_augmentation == "hflip":
|
|
|
+ transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
|
|
|
+ elif data_augmentation == "lsj":
|
|
|
+ transforms += [
|
|
|
+ T.ScaleJitter(target_size=(1024, 1024), antialias=True),
|
|
|
+ # TODO: FixedSizeCrop below doesn't work on tensors!
|
|
|
+ reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
|
|
|
+ T.RandomHorizontalFlip(p=hflip_prob),
|
|
|
+ ]
|
|
|
+ elif data_augmentation == "multiscale":
|
|
|
+ transforms += [
|
|
|
+ T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
|
|
|
+ T.RandomHorizontalFlip(p=hflip_prob),
|
|
|
+ ]
|
|
|
+ elif data_augmentation == "ssd":
|
|
|
+ fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
|
|
|
+ transforms += [
|
|
|
+ T.RandomPhotometricDistort(),
|
|
|
+ T.RandomZoomOut(fill=fill),
|
|
|
+ T.RandomIoUCrop(),
|
|
|
+ T.RandomHorizontalFlip(p=hflip_prob),
|
|
|
+ ]
|
|
|
+ elif data_augmentation == "ssdlite":
|
|
|
+ transforms += [
|
|
|
+ T.RandomIoUCrop(),
|
|
|
+ T.RandomHorizontalFlip(p=hflip_prob),
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
|
|
|
+
|
|
|
+ if backend == "pil":
|
|
|
+ # Note: we could just convert to pure tensors even in v2.
|
|
|
+ transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
|
|
|
+
|
|
|
+ transforms += [T.ToDtype(torch.float, scale=True)]
|
|
|
+
|
|
|
+ if use_v2:
|
|
|
+ transforms += [
|
|
|
+ T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
|
|
|
+ T.SanitizeBoundingBoxes(),
|
|
|
+ T.ToPureTensor(),
|
|
|
+ ]
|
|
|
+
|
|
|
+ self.transforms = T.Compose(transforms)
|
|
|
+
|
|
|
+ def __call__(self, img, target):
|
|
|
+ return self.transforms(img, target)
|
|
|
+
|
|
|
+
|
|
|
+def read_polygon_points(lbl_path, shape):
|
|
|
+ """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
|
|
|
+ polygon_points = []
|
|
|
+ w, h = shape[:2]
|
|
|
+ with open(lbl_path, 'r') as f:
|
|
|
+ lines = f.readlines()
|
|
|
+
|
|
|
+ for line in lines:
|
|
|
+ parts = line.strip().split()
|
|
|
+ class_id = int(parts[0])
|
|
|
+ points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2) # 读取点坐标
|
|
|
+ points[:, 0] *= h
|
|
|
+ points[:, 1] *= w
|
|
|
+
|
|
|
+ polygon_points.append((class_id, points))
|
|
|
+
|
|
|
+ return polygon_points
|
|
|
+
|
|
|
+
|
|
|
+def read_masks_from_pixels(lbl_path, shape):
|
|
|
+ """读取纯像素点格式的文件,不是轮廓像素点"""
|
|
|
+ h, w = shape
|
|
|
+ masks = []
|
|
|
+ labels = []
|
|
|
+
|
|
|
+ with open(lbl_path, 'r') as reader:
|
|
|
+ lines = reader.readlines()
|
|
|
+ mask_points = []
|
|
|
+ for line in lines:
|
|
|
+ mask = torch.zeros((h, w), dtype=torch.uint8)
|
|
|
+ parts = line.strip().split()
|
|
|
+ # print(f'parts:{parts}')
|
|
|
+ cls = torch.tensor(int(parts[0]), dtype=torch.int64)
|
|
|
+ labels.append(cls)
|
|
|
+ x_array = parts[1::2]
|
|
|
+ y_array = parts[2::2]
|
|
|
+
|
|
|
+ for x, y in zip(x_array, y_array):
|
|
|
+ x = float(x)
|
|
|
+ y = float(y)
|
|
|
+ mask_points.append((int(y * h), int(x * w)))
|
|
|
+
|
|
|
+ for p in mask_points:
|
|
|
+ mask[p] = 1
|
|
|
+ masks.append(mask)
|
|
|
+ reader.close()
|
|
|
+ return labels, masks
|
|
|
+
|
|
|
+
|
|
|
+def create_masks_from_polygons(polygons, image_shape):
|
|
|
+ """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
|
|
|
+ colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
|
|
|
+ masks = []
|
|
|
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
+ for polygon_data, col in zip(polygons, colors):
|
|
|
+ mask = np.zeros(image_shape[:2], dtype=np.uint8)
|
|
|
+ # 将多边形顶点转换为 NumPy 数组
|
|
|
+ _, polygon = polygon_data
|
|
|
+ pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
|
|
|
|
|
|
-from tools.presets import DetectionPresetTrain
|
|
|
+ # 使用 OpenCV 的 fillPoly 函数填充多边形
|
|
|
+ # print(f'color:{col[:3]}')
|
|
|
+ cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
|
|
|
+ mask = torch.from_numpy(mask)
|
|
|
+ mask[mask != 0] = 1
|
|
|
+ masks.append(mask)
|
|
|
|
|
|
-def line_boxes1(target):
|
|
|
+ return masks
|
|
|
+
|
|
|
+
|
|
|
+def read_masks_from_txt(label_path, shape):
|
|
|
+ polygon_points = read_polygon_points(label_path, shape)
|
|
|
+ masks = create_masks_from_polygons(polygon_points, shape)
|
|
|
+ labels = [torch.tensor(item[0]) for item in polygon_points]
|
|
|
+
|
|
|
+ return labels, masks
|
|
|
+
|
|
|
+
|
|
|
+def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Compute the bounding boxes around the provided masks.
|
|
|
+
|
|
|
+ Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
|
|
|
+ ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ masks (Tensor[N, H, W]): masks to transform where N is the number of masks
|
|
|
+ and (H, W) are the spatial dimensions.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tensor[N, 4]: bounding boxes
|
|
|
+ """
|
|
|
+ # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
|
|
+ # _log_api_usage_once(masks_to_boxes)
|
|
|
+ if masks.numel() == 0:
|
|
|
+ return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
|
|
|
+
|
|
|
+ n = masks.shape[0]
|
|
|
+
|
|
|
+ bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
|
|
|
+
|
|
|
+ for index, mask in enumerate(masks):
|
|
|
+ y, x = torch.where(mask != 0)
|
|
|
+ bounding_boxes[index, 0] = torch.min(x)
|
|
|
+ bounding_boxes[index, 1] = torch.min(y)
|
|
|
+ bounding_boxes[index, 2] = torch.max(x)
|
|
|
+ bounding_boxes[index, 3] = torch.max(y)
|
|
|
+ # debug to pixel datasets
|
|
|
+
|
|
|
+ if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
|
|
|
+ bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
|
|
|
+ bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
|
|
|
+
|
|
|
+ if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
|
|
|
+ bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
|
|
|
+ bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
|
|
|
+
|
|
|
+ return bounding_boxes
|
|
|
+
|
|
|
+
|
|
|
+def line_boxes(target):
|
|
|
boxs = []
|
|
|
- lines = target.cpu().numpy() * 4
|
|
|
+ lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
+ vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
+ lpre = lpre[vecl_target == 1]
|
|
|
+
|
|
|
+ lines = lpre
|
|
|
+ sline = np.ones(lpre.shape[0])
|
|
|
|
|
|
if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b)) in enumerate(lines):
|
|
|
+ for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
if i > 0 and (lines[i] == lines[0]).all():
|
|
|
break
|
|
|
+ # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
|
|
|
if a[1] > b[1]:
|
|
|
ymax = a[1] + 10
|
|
|
@@ -50,172 +244,71 @@ def line_boxes1(target):
|
|
|
xmax = b[0] + 10
|
|
|
boxs.append([ymin, xmin, ymax, xmax])
|
|
|
|
|
|
- # if boxs == []:
|
|
|
- # print(target)
|
|
|
-
|
|
|
return torch.tensor(boxs)
|
|
|
|
|
|
|
|
|
-class WirePointDataset(BaseDataset):
|
|
|
- def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
|
|
|
- super().__init__(dataset_path)
|
|
|
-
|
|
|
- self.data_path = dataset_path
|
|
|
- print(f'data_path:{dataset_path}')
|
|
|
- self.transforms = transforms
|
|
|
- self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
|
|
|
- self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
|
|
|
- self.imgs = os.listdir(self.img_path)
|
|
|
- self.lbls = os.listdir(self.lbl_path)
|
|
|
- self.target_type = target_type
|
|
|
- # self.default_transform = DefaultTransform()
|
|
|
-
|
|
|
- def __getitem__(self, index) -> T_co:
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[index])
|
|
|
- lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
-
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- w, h = img.size
|
|
|
-
|
|
|
- # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- if self.transforms:
|
|
|
- img, target = self.transforms(img, target)
|
|
|
- else:
|
|
|
- img = self.default_transform(img)
|
|
|
-
|
|
|
- # print(f'img:{img.shape}')
|
|
|
- return img, target
|
|
|
-
|
|
|
- def __len__(self):
|
|
|
- return len(self.imgs)
|
|
|
-
|
|
|
- def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
- # print(f'lbl_path:{lbl_path}')
|
|
|
- with open(lbl_path, 'r') as file:
|
|
|
- lable_all = json.load(file)
|
|
|
-
|
|
|
- n_stc_posl = 300
|
|
|
- n_stc_negl = 40
|
|
|
- use_cood = 0
|
|
|
- use_slop = 0
|
|
|
-
|
|
|
- wire = lable_all["wires"][0] # ??
|
|
|
- line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # ?????????
|
|
|
- line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
|
|
|
- npos, nneg = len(line_pos_coords), len(line_neg_coords)
|
|
|
- lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # ??????????
|
|
|
- for i in range(len(lpre)):
|
|
|
- if random.random() > 0.5:
|
|
|
- lpre[i] = lpre[i, ::-1]
|
|
|
- ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
|
|
|
- ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
|
|
|
- feat = [
|
|
|
- lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
|
|
|
- ldir * use_slop,
|
|
|
- lpre[:, :, 2],
|
|
|
- ]
|
|
|
- feat = np.concatenate(feat, 1)
|
|
|
-
|
|
|
- wire_labels = {
|
|
|
- "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
|
|
|
- "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
|
|
|
- "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
|
|
|
- # ???????????
|
|
|
- "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
|
|
|
- # ??????????
|
|
|
- "lpre": torch.tensor(lpre)[:, :, :2],
|
|
|
- "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # ?????? 1?0
|
|
|
- "lpre_feat": torch.from_numpy(feat),
|
|
|
- "junc_map": torch.tensor(wire['junc_map']["content"]),
|
|
|
- "junc_offset": torch.tensor(wire['junc_offset']["content"]),
|
|
|
- "line_map": torch.tensor(wire['line_map']["content"]),
|
|
|
- }
|
|
|
-
|
|
|
- labels = []
|
|
|
- #
|
|
|
- # if self.target_type == 'polygon':
|
|
|
- # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
|
|
|
- # elif self.target_type == 'pixel':
|
|
|
- # labels = read_masks_from_pixels_wire(lbl_path, shape)
|
|
|
-
|
|
|
- # print(torch.stack(masks).shape) # [???, 512, 512]
|
|
|
- target = {}
|
|
|
- # target["labels"] = torch.stack(labels)
|
|
|
-
|
|
|
-
|
|
|
- target["image_id"] = torch.tensor(item)
|
|
|
- # return wire_labels, target
|
|
|
- target["wires"] = wire_labels
|
|
|
- # target["boxes"] = line_boxes(target)
|
|
|
- target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
|
|
|
- target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
|
|
|
- # print(f'target["labels"]:{ target["labels"]}')
|
|
|
- # print(f'boxes:{target["boxes"].shape}')
|
|
|
- if target["boxes"].numel() == 0:
|
|
|
- print("Tensor is empty")
|
|
|
- print(f'path:{lbl_path}')
|
|
|
- return target
|
|
|
-
|
|
|
- def show(self, idx):
|
|
|
- image, target = self.__getitem__(idx)
|
|
|
-
|
|
|
- cmap = plt.get_cmap("jet")
|
|
|
- norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
- sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
- sm.set_array([])
|
|
|
-
|
|
|
- def imshow(im):
|
|
|
- plt.close()
|
|
|
- plt.tight_layout()
|
|
|
- plt.imshow(im)
|
|
|
- plt.colorbar(sm, fraction=0.046)
|
|
|
- plt.xlim([0, im.shape[0]])
|
|
|
- plt.ylim([im.shape[0], 0])
|
|
|
-
|
|
|
- def draw_vecl(lines, sline, juncs, junts, fn=None):
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- imshow(io.imread(img_path))
|
|
|
- if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
- if i > 0 and (lines[i] == lines[0]).all():
|
|
|
- break
|
|
|
- plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]?????
|
|
|
- if not (juncs[0] == 0).all():
|
|
|
- for i, j in enumerate(juncs):
|
|
|
- if i > 0 and (i == juncs[0]).all():
|
|
|
- break
|
|
|
- plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
|
|
|
-
|
|
|
-
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
- colors="yellow", width=1)
|
|
|
- plt.imshow(boxed_image.permute(1, 2, 0).numpy())
|
|
|
- plt.show()
|
|
|
-
|
|
|
- plt.show()
|
|
|
- if fn != None:
|
|
|
- plt.savefig(fn)
|
|
|
-
|
|
|
- junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
- jtyp = target['wires']['jtyp'].cpu().numpy()
|
|
|
- juncs = junc[jtyp == 0]
|
|
|
- junts = junc[jtyp == 1]
|
|
|
-
|
|
|
- lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
- vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
- lpre = lpre[vecl_target == 1]
|
|
|
-
|
|
|
- # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
- draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
-
|
|
|
-
|
|
|
- def show_img(self, img_path):
|
|
|
- pass
|
|
|
-
|
|
|
-
|
|
|
-# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
|
|
|
-# for i in dataset_train:
|
|
|
-# a = 1
|
|
|
+def read_polygon_points_wire(lbl_path, shape):
|
|
|
+ """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
|
|
|
+ polygon_points = []
|
|
|
+ w, h = shape[:2]
|
|
|
+ with open(lbl_path, 'r') as f:
|
|
|
+ lines = json.load(f)
|
|
|
+
|
|
|
+ for line in lines["segmentations"]:
|
|
|
+ parts = line["data"]
|
|
|
+ class_id = int(line["cls_id"])
|
|
|
+ points = np.array(parts, dtype=np.float32).reshape(-1, 2) # 读取点坐标
|
|
|
+ points[:, 0] *= h
|
|
|
+ points[:, 1] *= w
|
|
|
+
|
|
|
+ polygon_points.append((class_id, points))
|
|
|
+
|
|
|
+ return polygon_points
|
|
|
+
|
|
|
+
|
|
|
+def read_masks_from_txt_wire(label_path, shape):
|
|
|
+ polygon_points = read_polygon_points_wire(label_path, shape)
|
|
|
+ masks = create_masks_from_polygons(polygon_points, shape)
|
|
|
+
|
|
|
+ labels = [torch.tensor(item[0]) for item in polygon_points]
|
|
|
+
|
|
|
+ return labels, masks
|
|
|
+
|
|
|
+
|
|
|
+def read_masks_from_pixels_wire(lbl_path, shape):
|
|
|
+ """读取纯像素点格式的文件,不是轮廓像素点"""
|
|
|
+ h, w = shape
|
|
|
+ masks = []
|
|
|
+ labels = []
|
|
|
+
|
|
|
+ with open(lbl_path, 'r') as reader:
|
|
|
+ lines = json.load(reader)
|
|
|
+ mask_points = []
|
|
|
+ for line in lines["segmentations"]:
|
|
|
+ # mask = torch.zeros((h, w), dtype=torch.uint8)
|
|
|
+ # parts = line["data"]
|
|
|
+ # print(f'parts:{parts}')
|
|
|
+ cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
|
|
|
+ labels.append(cls)
|
|
|
+ # x_array = parts[0::2]
|
|
|
+ # y_array = parts[1::2]
|
|
|
+ #
|
|
|
+ # for x, y in zip(x_array, y_array):
|
|
|
+ # x = float(x)
|
|
|
+ # y = float(y)
|
|
|
+ # mask_points.append((int(y * h), int(x * w)))
|
|
|
+
|
|
|
+ # for p in mask_points:
|
|
|
+ # mask[p] = 1
|
|
|
+ # masks.append(mask)
|
|
|
+ reader.close()
|
|
|
+ return labels
|
|
|
+
|
|
|
+
|
|
|
+def adjacency_matrix(n, link): # 邻接矩阵
|
|
|
+ mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
|
|
|
+ link = torch.tensor(link)
|
|
|
+ if len(link) > 0:
|
|
|
+ mat[link[:, 0], link[:, 1]] = 1
|
|
|
+ mat[link[:, 1], link[:, 0]] = 1
|
|
|
+ return mat
|