# multitask_learner.py

from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lcnn.config import M


class MultitaskHead(nn.Module):
    def __init__(self, input_channels, num_class):
        super(MultitaskHead, self).__init__()

        # One small conv branch per output map; each branch reduces to the
        # task's channel count.
        m = int(input_channels / 4)
        heads = []
        for output_channels in sum(M.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(M.head_size, []))

    def forward(self, x):
        # Concatenate the per-task predictions along the channel dimension.
        return torch.cat([head(x) for head in self.heads], dim=1)
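
# Illustrative note (an assumption about the config, not from this file): with
# the stock LCNN wireframe setting M.head_size = [[2], [1], [2]] -- 2
# junction-map, 1 line-map, and 2 junction-offset channels -- num_class = 5, so:
#
#   head = MultitaskHead(input_channels=256, num_class=5)
#   head(torch.zeros(1, 256, 128, 128)).shape  # -> torch.Size([1, 5, 128, 128])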


class MultitaskLearner(nn.Module):
    def __init__(self, backbone):
        super(MultitaskLearner, self).__init__()
        self.backbone = backbone
        head_size = M.head_size
        self.num_class = sum(sum(head_size, []))
        self.head_off = np.cumsum([sum(h) for h in head_size])
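
    # Note: with head_size = [[2], [1], [2]] (the stock LCNN setting assumed
    # in the note above), head_off = cumsum([2, 1, 2]) = [2, 3, 5]; forward()
    # uses these cumulative offsets as channel boundaries when slicing each
    # stack's output into jmap / lmap / joff.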

    def forward(self, input_dict):
        image = input_dict["image"]
        target_b = input_dict["target_b"]
        # During training, `aaa` carries the backbone's auxiliary loss; during
        # validation it carries the predicted boxes.
        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])
        result = {"feature": feature}
        batch, channel, row, col = outputs[0].shape

        T = input_dict["target"].copy()
        n_jtyp = T["junc_map"].shape[1]

        # Switch the targets to channel-first (CNHW) layout so they line up
        # with the transposed outputs below.
        for task in ["junc_map"]:
            T[task] = T[task].permute(1, 0, 2, 3)
        for task in ["junc_offset"]:
            T[task] = T[task].permute(1, 2, 0, 3, 4)

        offset = self.head_off
        loss_weight = M.loss_weight
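
        # The per-stack loop below transposes each (N, C, H, W) output to
        # channel-first so the head_off boundaries slice out contiguous
        # per-task channel blocks.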
        losses = []
        for stack, output in enumerate(outputs):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            # Predictions are reported from the first stack only.
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                if input_dict["mode"] == "testing":
                    return result

            L = OrderedDict()
            L["jmap"] = sum(
                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
            )
            L["lmap"] = (
                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
                .mean(2)
                .mean(1)
            )
            L["joff"] = sum(
                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
                for i in range(n_jtyp)
                for j in range(2)
            )
            for loss_name in L:
                L[loss_name].mul_(loss_weight[loss_name])
            losses.append(L)
        result["losses"] = losses
        result["aaa"] = aaa
        return result


def l2loss(input, target):
    return ((target - input) ** 2).mean(2).mean(1)


def cross_entropy_loss(logits, positive):
    # Two-class cross entropy over channel 0/1, averaged to one value per image.
    nlogp = -F.log_softmax(logits, dim=0)
    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)


def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
    logp = torch.sigmoid(logits) + offset
    loss = torch.abs(logp - target)
    if mask is not None:
        # Re-weight by the per-image mask density so masked losses keep a
        # comparable scale across images.
        w = mask.mean(2, True).mean(1, True)
        w[w == 0] = 1
        loss = loss * (mask / w)
    return loss.mean(2).mean(1)
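

if __name__ == "__main__":
    # Minimal smoke test (an addition, not part of the original file):
    # exercise the loss helpers on random tensors with the per-stack shapes
    # used above. Purely illustrative.
    batch, row, col = 2, 128, 128

    # Junction map: two-class logits per pixel, one loss value per image.
    logits = torch.randn(2, batch, row, col)
    positive = torch.randint(0, 2, (batch, row, col)).float()
    print(cross_entropy_loss(logits, positive).shape)  # torch.Size([2])

    # Junction offset: sigmoid output shifted into [-0.5, 0.5), masked L1.
    off_logits = torch.randn(batch, row, col)
    off_target = torch.rand(batch, row, col) - 0.5
    print(sigmoid_l1_loss(off_logits, off_target, -0.5, positive).shape)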