|
@@ -1,5 +1,4 @@
|
|
|
# ??roi_head??????????????
|
|
# ??roi_head??????????????
|
|
|
-import imageio
|
|
|
|
|
from torch.utils.data.dataset import T_co
|
|
from torch.utils.data.dataset import T_co
|
|
|
|
|
|
|
|
from models.base.base_dataset import BaseDataset
|
|
from models.base.base_dataset import BaseDataset
|
|
@@ -73,24 +72,10 @@ class WirePointDataset(BaseDataset):
|
|
|
|
|
|
|
|
def __getitem__(self, index) -> T_co:
|
|
def __getitem__(self, index) -> T_co:
|
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
img_path = os.path.join(self.img_path, self.imgs[index])
|
|
|
- lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
|
|
|
|
|
- img=imageio.v3.imread(img_path).reshape(512,512,1)
|
|
|
|
|
- img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
|
|
|
|
|
- # 将原始通道复制到第一个通道
|
|
|
|
|
- img_3channel[:, :, 2] = img[:, :, 0]
|
|
|
|
|
-
|
|
|
|
|
- # print(f'dataset img shape:{img.shape}')
|
|
|
|
|
- # img = PIL.Image.open(img_path).convert('RGB')
|
|
|
|
|
- w, h = img.shape[:2]
|
|
|
|
|
-
|
|
|
|
|
- img=torch.from_numpy(img_3channel).permute(2, 0, 1)
|
|
|
|
|
-
|
|
|
|
|
- img=self.zscore_normalize_depth(img)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
- # img=img.transpose(2,0,1)
|
|
|
|
|
- # print(f'dataset img shape2:{img.shape}')
|
|
|
|
|
|
|
+ lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
|
|
|
|
|
|
|
+ img = PIL.Image.open(img_path).convert('RGB')
|
|
|
|
|
+ w, h = img.size
|
|
|
|
|
|
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
@@ -105,15 +90,6 @@ class WirePointDataset(BaseDataset):
|
|
|
def __len__(self):
|
|
def __len__(self):
|
|
|
return len(self.imgs)
|
|
return len(self.imgs)
|
|
|
|
|
|
|
|
- def zscore_normalize_depth(self,img):
|
|
|
|
|
- depth = img[2]
|
|
|
|
|
- mean = depth.mean()
|
|
|
|
|
- std = depth.std()
|
|
|
|
|
- depth_normalized = (depth - mean) / (std + 1e-8)
|
|
|
|
|
- img_normalized = img.clone()
|
|
|
|
|
- img_normalized[2] = depth_normalized
|
|
|
|
|
- return img_normalized
|
|
|
|
|
-
|
|
|
|
|
def read_target(self, item, lbl_path, shape, extra=None):
|
|
def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
# print(f'lbl_path:{lbl_path}')
|
|
# print(f'lbl_path:{lbl_path}')
|
|
|
with open(lbl_path, 'r') as file:
|
|
with open(lbl_path, 'r') as file:
|