# train.py (451 B)

  1. def train_epoch(model):
  2. pass
  3. def _loss(losses):
  4. total_loss = 0
  5. for i in losses.keys():
  6. if i != "loss_wirepoint":
  7. total_loss += losses[i]
  8. else:
  9. loss_labels = losses[i]["losses"]
  10. loss_labels_k = list(loss_labels[0].keys())
  11. for j, name in enumerate(loss_labels_k):
  12. loss = loss_labels[0][name].mean()
  13. print(f"{name}:{loss}")
  14. total_loss += loss
  15. return total_loss