|
@@ -26,7 +26,6 @@ from models.wirenet.head import RoIHeads
|
|
|
from models.wirenet.wirepoint_dataset import WirePointDataset
|
|
|
from tools import utils
|
|
|
|
|
|
-
|
|
|
FEATURE_DIM = 8
|
|
|
|
|
|
|
|
@@ -119,7 +118,7 @@ class WirepointRCNN(FasterRCNN):
|
|
|
|
|
|
if wirepoint_roi_pool is None:
|
|
|
wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
|
|
|
- sampling_ratio=2,)
|
|
|
+ sampling_ratio=2, )
|
|
|
|
|
|
if wirepoint_head is None:
|
|
|
keypoint_layers = tuple(512 for _ in range(8))
|
|
@@ -283,7 +282,7 @@ class WirepointPredictor(nn.Module):
|
|
|
)
|
|
|
self.loss = nn.BCEWithLogitsLoss(reduction="none")
|
|
|
|
|
|
- def forward(self, inputs,features, targets=None):
|
|
|
+ def forward(self, inputs, features, targets=None):
|
|
|
|
|
|
# outputs, features = input
|
|
|
# for out in outputs:
|
|
@@ -315,25 +314,24 @@ class WirepointPredictor(nn.Module):
|
|
|
else:
|
|
|
self.training = False
|
|
|
t = {
|
|
|
- "junc_coords": torch.zeros(1, 2),
|
|
|
- "jtyp": torch.zeros(1, dtype=torch.uint8),
|
|
|
- "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
- "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
- }
|
|
|
- wires_targets=[t for b in range(inputs.size(0))]
|
|
|
-
|
|
|
- wires_meta={
|
|
|
+ "junc_coords": torch.zeros(1, 2),
|
|
|
+ "jtyp": torch.zeros(1, dtype=torch.uint8),
|
|
|
+ "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
+ "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
"junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
"junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
}
|
|
|
+ wires_targets = [t for b in range(inputs.size(0))]
|
|
|
|
|
|
+ wires_meta = {
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
+ }
|
|
|
|
|
|
T = wires_meta.copy()
|
|
|
n_jtyp = T["junc_map"].shape[1]
|
|
|
offset = self.head_off
|
|
|
- result={}
|
|
|
+ result = {}
|
|
|
for stack, output in enumerate([inputs]):
|
|
|
output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
|
|
|
print(f"Stack {stack} output shape: {output.shape}") # 打印每层的输出形状
|
|
@@ -396,8 +394,8 @@ class WirepointPredictor(nn.Module):
|
|
|
+ x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
+ x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
)
|
|
|
- .reshape(n_channel, -1, self.n_pts0)
|
|
|
- .permute(1, 0, 2)
|
|
|
+ .reshape(n_channel, -1, self.n_pts0)
|
|
|
+ .permute(1, 0, 2)
|
|
|
)
|
|
|
xp = self.pooling(xp)
|
|
|
print(f'xp.shape:{xp.shape}')
|
|
@@ -419,13 +417,11 @@ class WirepointPredictor(nn.Module):
|
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
|
return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
|
|
|
|
-
|
|
|
# if mode != "training":
|
|
|
# self.inference(x, idx, jcs, n_batch, ps)
|
|
|
|
|
|
# return result
|
|
|
|
|
|
-
|
|
|
####deprecated
|
|
|
# def inference(self,input, idx, jcs, n_batch, ps):
|
|
|
# if not self.training:
|
|
@@ -565,7 +561,6 @@ class WirepointPredictor(nn.Module):
|
|
|
return line, label.float(), feat, jcs
|
|
|
|
|
|
|
|
|
-
|
|
|
def wirepointrcnn_resnet50_fpn(
|
|
|
*,
|
|
|
weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
|
|
@@ -596,6 +591,21 @@ def wirepointrcnn_resnet50_fpn(
|
|
|
return model
|
|
|
|
|
|
|
|
|
+def _loss(losses):
|
|
|
+ total_loss = 0
|
|
|
+ for i in losses.keys():
|
|
|
+ if i != "loss_wirepoint":
|
|
|
+ total_loss += losses[i]
|
|
|
+ else:
|
|
|
+ loss_labels = losses[i]["losses"]
|
|
|
+ loss_labels_k = list(loss_labels[0].keys())
|
|
|
+ for j, name in enumerate(loss_labels_k):
|
|
|
+ loss = loss_labels[0][name].mean()
|
|
|
+ total_loss += loss
|
|
|
+
|
|
|
+ return total_loss
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
cfg = 'wirenet.yaml'
|
|
|
cfg = read_yaml(cfg)
|
|
@@ -603,30 +613,73 @@ if __name__ == '__main__':
|
|
|
print(cfg['model']['n_dyn_negl'])
|
|
|
# net = WirepointPredictor()
|
|
|
|
|
|
+
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ device_name = "cuda"
|
|
|
+ torch.backends.cudnn.deterministic = True
|
|
|
+ torch.cuda.manual_seed(0)
|
|
|
+ print("Let's use", torch.cuda.device_count(), "GPU(s)!")
|
|
|
+ else:
|
|
|
+ print("CUDA is not available")
|
|
|
+
|
|
|
+ device = torch.device(device_name)
|
|
|
+
|
|
|
+
|
|
|
dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
|
|
|
train_sampler = torch.utils.data.RandomSampler(dataset)
|
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
|
- train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
|
|
|
+ train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
|
data_loader = torch.utils.data.DataLoader(
|
|
|
- dataset, batch_sampler=train_batch_sampler, num_workers=10, collate_fn=train_collate_fn
|
|
|
+ dataset, batch_sampler=train_batch_sampler, num_workers=4, collate_fn=train_collate_fn
|
|
|
)
|
|
|
- model = wirepointrcnn_resnet50_fpn()
|
|
|
+ model = wirepointrcnn_resnet50_fpn().to(device)
|
|
|
|
|
|
- for i in cfg['optim']['max_epoch']:
|
|
|
- model.train()
|
|
|
- imgs, targets = next(iter(data_loader))
|
|
|
- pred = model(imgs, targets)
|
|
|
+ optimizer = torch.optim.SGD(model.parameters(), lr=cfg['optim']['lr'])
|
|
|
|
|
|
- # imgs, targets = next(iter(data_loader))
|
|
|
- #
|
|
|
- # model.train()
|
|
|
- # pred = model(imgs, targets)
|
|
|
- # print(f'pred:{pred}')
|
|
|
|
|
|
- # result, losses = model(imgs, targets)
|
|
|
- # print(f'result:{result}')
|
|
|
- # print(f'pred:{losses}')
|
|
|
+ def move_to_device(data, device):
|
|
|
+ if isinstance(data, (list, tuple)):
|
|
|
+ return type(data)(move_to_device(item, device) for item in data)
|
|
|
+ elif isinstance(data, dict):
|
|
|
+ return {key: move_to_device(value, device) for key, value in data.items()}
|
|
|
+ elif isinstance(data, torch.Tensor):
|
|
|
+ return data.to(device)
|
|
|
+ else:
|
|
|
+ return data # 对于非张量类型的数据不做任何改变
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ for i in range(cfg['optim']['max_epoch']):
|
|
|
+ model.train()
|
|
|
+ # imgs, targets = next(iter(data_loader))
|
|
|
+ # loss = model(imgs, targets)
|
|
|
+ # print(loss)
|
|
|
+ # losses = _loss(loss)
|
|
|
+ # optimizer.zero_grad()
|
|
|
+ # loss.backward()
|
|
|
+ # optimizer.step()
|
|
|
+
|
|
|
+ for imgs, targets in data_loader:
|
|
|
+ losses = model(move_to_device(imgs, device), move_to_device(targets, device))
|
|
|
+ print(losses)
|
|
|
+ loss = _loss(losses)
|
|
|
+ print(loss)
|
|
|
+ # 优化器优化模型
|
|
|
+ optimizer.zero_grad()
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+
|
|
|
+
|
|
|
+# imgs, targets = next(iter(data_loader))
|
|
|
+#
|
|
|
+# model.train()
|
|
|
+# pred = model(imgs, targets)
|
|
|
+# print(f'pred:{pred}')
|
|
|
+
|
|
|
+# result, losses = model(imgs, targets)
|
|
|
+# print(f'result:{result}')
|
|
|
+# print(f'pred:{losses}')
|
|
|
'''
|
|
|
########### predict#############
|
|
|
|