# multitask_learner.py

from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lcnn.config import M


class MultitaskHead(nn.Module):
    def __init__(self, input_channels, num_class):
        super(MultitaskHead, self).__init__()
        # print("input dimension:", input_channels)
        m = int(input_channels / 4)
        heads = []
        # One small conv branch per output map; M.head_size is a nested list
        # (e.g. [[2], [1], [2]]) flattened into per-branch channel counts.
        for output_channels in sum(M.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(M.head_size, []))

    def forward(self, x):
        # Concatenate every branch's output along the channel dimension.
        return torch.cat([head(x) for head in self.heads], dim=1)
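

# Usage sketch (illustrative only): exercises MultitaskHead with a
# hypothetical head_size of [[2], [1], [2]] assigned directly on the config
# object M, and checks that the concatenated output carries 2 + 1 + 2 = 5
# channels.
def _demo_multitask_head():
    M.head_size = [[2], [1], [2]]  # hypothetical config values
    head = MultitaskHead(input_channels=256, num_class=5)
    x = torch.randn(1, 256, 128, 128)
    assert head(x).shape == (1, 5, 128, 128)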


class MultitaskLearner(nn.Module):
    def __init__(self, backbone):
        super(MultitaskLearner, self).__init__()
        self.backbone = backbone
        head_size = M.head_size
        self.num_class = sum(sum(head_size, []))
        # Cumulative channel offsets used in forward() to slice the
        # concatenated head output back into per-task maps.
        self.head_off = np.cumsum([sum(h) for h in head_size])
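    # Worked example, assuming head_size == [[2], [1], [2]]:
    #   [sum(h) for h in head_size] == [2, 1, 2]
    #   head_off == cumsum([2, 1, 2]) == [2, 3, 5]
    # so channels [0:2] hold jmap, [2:3] lmap, and [3:5] joff below.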

    def forward(self, input_dict):
        image = input_dict["image"]
        target_b = input_dict["target_b"]
        # During training `aaa` is the backbone's auxiliary loss; during
        # validation/testing it holds the predicted boxes.
        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])
        result = {"feature": feature, "aaa": aaa}
        batch, channel, row, col = outputs[0].shape

        T = input_dict["target"].copy()
        n_jtyp = T["junc_map"].shape[1]

        # switch to CNHW
        for task in ["junc_map"]:
            T[task] = T[task].permute(1, 0, 2, 3)
        for task in ["junc_offset"]:
            T[task] = T[task].permute(1, 2, 0, 3, 4)

        offset = self.head_off
        loss_weight = M.loss_weight
        losses = []
        for stack, output in enumerate(outputs):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            # Slice the concatenated channels back into per-task maps.
            jmap = output[0 : offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0] : offset[1]].squeeze(0)
            joff = output[offset[1] : offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    # softmax over the two classes, keep the foreground score
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    # offsets are regressed into (-0.5, 0.5)
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                if input_dict["mode"] == "testing":
                    return result

            L = OrderedDict()
            L["jmap"] = sum(
                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
            )
            L["lmap"] = (
                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
                .mean(2)
                .mean(1)
            )
            L["joff"] = sum(
                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
                for i in range(n_jtyp)
                for j in range(2)
            )
            for loss_name in L:
                L[loss_name].mul_(loss_weight[loss_name])
            losses.append(L)
        result["losses"] = losses
        return result
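

# Usage sketch (illustrative only): one forward pass in training mode with a
# stub backbone. The stub's (image, target_b, mode) signature and its
# (outputs, feature, aaa) return mirror the call in forward() above; every
# tensor size and loss weight here is hypothetical and stands in for the real
# stacked-hourglass backbone and config.
def _demo_training_step():
    M.head_size = [[2], [1], [2]]                             # hypothetical
    M.loss_weight = {"jmap": 8.0, "lmap": 0.5, "joff": 0.25}  # hypothetical
    n, h, w, n_jtyp = 2, 128, 128, 1

    class StubBackbone(nn.Module):
        def forward(self, image, target_b, mode):
            outputs = [torch.randn(n, 5, h, w)]  # one hourglass stack
            feature = torch.randn(n, 256, h, w)
            aaa = torch.tensor(0.0)              # auxiliary loss in training
            return outputs, feature, aaa

    model = MultitaskLearner(StubBackbone())
    input_dict = {
        "image": torch.randn(n, 3, 4 * h, 4 * w),
        "target_b": None,
        "mode": "training",
        "target": {
            "junc_map": torch.randint(0, 2, (n, n_jtyp, h, w)).float(),
            "junc_offset": torch.rand(n, n_jtyp, 2, h, w) - 0.5,
            "line_map": torch.randint(0, 2, (n, h, w)).float(),
        },
    }
    result = model(input_dict)
    # A real training loop would backpropagate this scalar.
    total = sum(l.sum() for l in result["losses"][0].values()) + result["aaa"]
    return total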


def l2loss(input, target):
    # Per-sample mean squared error (unused in this file).
    return ((target - input) ** 2).mean(2).mean(1)


def cross_entropy_loss(logits, positive):
    # logits: [2, N, H, W] background/foreground scores; positive: {0, 1} map.
    nlogp = -F.log_softmax(logits, dim=0)
    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
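

# Illustrative check (hypothetical values): a pixel whose foreground logit
# dominates contributes almost no loss when its label is 1.
def _demo_cross_entropy_loss():
    logits = torch.zeros(2, 1, 1, 1)
    logits[1] = 5.0                    # strongly foreground
    positive = torch.ones(1, 1, 1)     # label: foreground
    loss = cross_entropy_loss(logits, positive)
    assert loss.item() < 0.01          # -log softmax(5 vs 0) is about 0.0067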


def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
    logp = torch.sigmoid(logits) + offset
    loss = torch.abs(logp - target)
    if mask is not None:
        # Normalize by the per-sample mask density so sparse junction maps
        # still produce a comparably scaled loss.
        w = mask.mean(2, True).mean(1, True)
        w[w == 0] = 1
        loss = loss * (mask / w)
    return loss.mean(2).mean(1)
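

# Illustrative check (hypothetical values): with a junction mask, only the
# masked pixel drives the loss, and the mask-density normalization makes the
# per-sample mean equal to the raw L1 error at that pixel.
def _demo_sigmoid_l1_loss():
    logits = torch.zeros(1, 4, 4)         # sigmoid(0) - 0.5 == 0.0
    target = torch.full((1, 4, 4), 0.25)  # ground-truth offsets
    mask = torch.zeros(1, 4, 4)
    mask[0, 1, 2] = 1.0                   # a single junction pixel
    loss = sigmoid_l1_loss(logits, target, offset=-0.5, mask=mask)
    # |0.0 - 0.25| at the masked pixel, re-weighted by mask / mean(mask):
    assert torch.allclose(loss, torch.tensor([0.25]))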