1234567891011121314151617181920 |
def train_epoch(model):
    """Placeholder for running one training epoch on *model*.

    Not implemented yet: *model* is not used, nothing is modified, and the
    function returns ``None``.
    """
    # TODO: implement the per-epoch training loop.
    return None
def _loss(losses, *, verbose=True):
    """Sum every loss term in *losses* into a single total.

    Entries whose key is not ``"loss_wirepoint"`` are added to the total
    unchanged. The ``"loss_wirepoint"`` entry is indexed as
    ``value["losses"][0]``. That gives a mapping of component name to value,
    and each component's ``.mean()`` is added to the total.

    Args:
        losses: Mapping of loss name to loss value. The values are presumably
            scalar tensors, or anything else that supports ``+`` with the
            running total. TODO: confirm against the caller.
        verbose: If true (the default, which matches the original
            behaviour), print each ``loss_wirepoint`` component as
            ``"name:value"``.

    Returns:
        The accumulated total, or ``0`` when *losses* is empty.
    """
    total_loss = 0
    for key, value in losses.items():
        if key != "loss_wirepoint":
            total_loss += value
            continue
        # Only the first element of the "losses" sequence is summed; any later
        # elements are ignored. NOTE(review): confirm this is intended (e.g.
        # for multi-stack outputs).
        components = value["losses"][0]
        for name, component in components.items():
            loss = component.mean()
            if verbose:
                print(f"{name}:{loss}")
            total_loss += loss
    return total_loss
|