|
@@ -102,27 +102,29 @@ def c(x):
|
|
|
|
|
|
|
|
|
def show_line(img, pred, epoch, writer):
|
|
|
- im = img.permute(1, 2, 0)
|
|
|
+ im = img.permute(1, 2, 0) # [512, 512, 3]
|
|
|
writer.add_image("ori", im, epoch, dataformats="HWC")
|
|
|
|
|
|
boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
|
|
|
colors="yellow", width=1)
|
|
|
writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
+ print(f'box:{pred["boxes"][:5,:]}')
|
|
|
+ print(f'line:{pred["keypoints"][:5,:]}')
|
|
|
|
|
|
PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
|
|
|
# H = pred[1]['wires']
|
|
|
lines = pred["keypoints"].detach().cpu().numpy()
|
|
|
scores = pred["keypoints_scores"].detach().cpu().numpy()
|
|
|
- for i in range(1, len(lines)):
|
|
|
- if (lines[i] == lines[0]).all():
|
|
|
- lines = lines[:i]
|
|
|
- scores = scores[:i]
|
|
|
- break
|
|
|
+ # for i in range(1, len(lines)):
|
|
|
+ # if (lines[i] == lines[0]).all():
|
|
|
+ # lines = lines[:i]
|
|
|
+ # scores = scores[:i]
|
|
|
+ # break
|
|
|
|
|
|
# postprocess lines to remove overlapped lines
|
|
|
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
|
|
|
- print(f'nscores:{nscores}')
|
|
|
+ # print(f'nscores:{nscores}')
|
|
|
|
|
|
for i, t in enumerate([0.5]):
|
|
|
plt.gca().set_axis_off()
|
|
@@ -184,6 +186,10 @@ def evaluate(model, data_loader, epoch, writer, device):
|
|
|
if batch_idx == 0:
|
|
|
show_line(images[0], outputs[0], epoch, writer)
|
|
|
|
|
|
+ # print(f'outputs:{outputs}')
|
|
|
+ # print(f'outputs[0]:{outputs[0]}')
|
|
|
+
|
|
|
+
|
|
|
# outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
|
|
|
# model_time = time.time() - model_time
|
|
|
#
|
|
@@ -278,7 +284,6 @@ def train(model, **kwargs):
|
|
|
dataset_type='val')
|
|
|
|
|
|
train_sampler = torch.utils.data.RandomSampler(dataset)
|
|
|
- # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
|
test_sampler = torch.utils.data.RandomSampler(dataset_test)
|
|
|
train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
|
|
|
train_collate_fn = utils.collate_fn
|
|
@@ -289,6 +294,7 @@ def train(model, **kwargs):
|
|
|
dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
|
|
|
)
|
|
|
|
|
|
+
|
|
|
img_results_path = os.path.join(train_result_ptath, 'img_results')
|
|
|
if os.path.exists(train_result_ptath):
|
|
|
pass
|
|
@@ -310,7 +316,7 @@ def train(model, **kwargs):
|
|
|
if os.path.exists(f'{wts_path}/last.pt'):
|
|
|
os.remove(f'{wts_path}/last.pt')
|
|
|
torch.save(model.state_dict(), f'{wts_path}/last.pt')
|
|
|
- write_metric_logs(epoch, metric_logger, writer)
|
|
|
+ # write_metric_logs(epoch, metric_logger, writer)
|
|
|
if epoch == 0:
|
|
|
best_loss = losses;
|
|
|
if best_loss >= losses:
|
|
@@ -320,6 +326,9 @@ def train(model, **kwargs):
|
|
|
torch.save(model.state_dict(), f'{wts_path}/best.pt')
|
|
|
|
|
|
evaluate(model, data_loader_test, epoch, writer, device=device)
|
|
|
+ avg_train_loss = total_train_loss / len(data_loader)
|
|
|
+
|
|
|
+ writer.add_scalar('Loss/train', avg_train_loss, epoch)
|
|
|
|
|
|
|
|
|
def get_transform(is_train, **kwargs):
|