|
@@ -105,18 +105,23 @@ class WirePointDataset(BaseDataset):
|
|
}
|
|
}
|
|
|
|
|
|
labels = []
|
|
labels = []
|
|
- if self.target_type == 'polygon':
|
|
|
|
- labels, masks = read_masks_from_txt_wire(lbl_path, shape)
|
|
|
|
- elif self.target_type == 'pixel':
|
|
|
|
- labels = read_masks_from_pixels_wire(lbl_path, shape)
|
|
|
|
|
|
+ #
|
|
|
|
+ # if self.target_type == 'polygon':
|
|
|
|
+ # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
|
|
|
|
+ # elif self.target_type == 'pixel':
|
|
|
|
+ # labels = read_masks_from_pixels_wire(lbl_path, shape)
|
|
|
|
|
|
# print(torch.stack(masks).shape) # [线段数, 512, 512]
|
|
# print(torch.stack(masks).shape) # [线段数, 512, 512]
|
|
target = {}
|
|
target = {}
|
|
- target["labels"] = torch.stack(labels)
|
|
|
|
|
|
+ # target["labels"] = torch.stack(labels)
|
|
|
|
+
|
|
|
|
+
|
|
target["image_id"] = torch.tensor(item)
|
|
target["image_id"] = torch.tensor(item)
|
|
# return wire_labels, target
|
|
# return wire_labels, target
|
|
target["wires"] = wire_labels
|
|
target["wires"] = wire_labels
|
|
target["boxes"] = line_boxes(target)
|
|
target["boxes"] = line_boxes(target)
|
|
|
|
+ target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
|
|
|
|
+ # print(f'target["labels"]:{ target["labels"]}')
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
return target
|
|
return target
|
|
|
|
|