|
@@ -10,7 +10,7 @@ import os
|
|
|
import random
|
|
import random
|
|
|
import cv2
|
|
import cv2
|
|
|
import PIL
|
|
import PIL
|
|
|
-
|
|
|
|
|
|
|
+import imageio
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
|
import matplotlib as mpl
|
|
import matplotlib as mpl
|
|
|
from torchvision.utils import draw_bounding_boxes
|
|
from torchvision.utils import draw_bounding_boxes
|
|
@@ -33,10 +33,11 @@ def validate_keypoints(keypoints, image_width, image_height):
|
|
|
|
|
|
|
|
|
|
|
|
|
class LineDataset(BaseDataset):
|
|
class LineDataset(BaseDataset):
|
|
|
- def __init__(self, dataset_path, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
|
|
|
|
|
|
|
+ def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
|
|
|
super().__init__(dataset_path)
|
|
super().__init__(dataset_path)
|
|
|
|
|
|
|
|
self.data_path = dataset_path
|
|
self.data_path = dataset_path
|
|
|
|
|
+ self.data_type = data_type
|
|
|
print(f'data_path:{dataset_path}')
|
|
print(f'data_path:{dataset_path}')
|
|
|
self.transforms = transforms
|
|
self.transforms = transforms
|
|
|
self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
|
|
self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
|
|
@@ -49,11 +50,18 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
|
|
def __getitem__(self, index) -> T_co:
|
|
def __getitem__(self, index) -> T_co:
|
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
|
- lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
|
|
-
|
|
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
|
|
- w, h = img.size
|
|
|
|
|
-
|
|
|
|
|
|
|
+ if self.data_type == 'tiff':
|
|
|
|
|
+ lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
|
|
|
|
|
+ img = imageio.v3.imread(img_path).reshape(512, 512, 1)
|
|
|
|
|
+ img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
|
|
|
|
|
+ img_3channel[:, :, 2] = img[:, :, 0]
|
|
|
|
|
+
|
|
|
|
|
+ w, h = img.shape[:2]
|
|
|
|
|
+ img = torch.from_numpy(img_3channel).permute(2, 0, 1)
|
|
|
|
|
+ else:
|
|
|
|
|
+ lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
|
|
+ img = PIL.Image.open(img_path).convert('RGB')
|
|
|
|
|
+ w, h = img.size
|
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
if self.transforms:
|
|
if self.transforms:
|