def train_epoch(model):
    """Stub that accepts a model and does nothing.

    Args:
        model: Accepted but not used by the current body.

    Returns:
        None.
    """
    return None


def _loss(losses):
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            loss_labels = losses[i]["losses"]

    loss_labels_k = list(loss_labels[0].keys())
    for j, name in enumerate(loss_labels_k):
        loss = loss_labels[0][name].mean()
        print(f"{name}:{loss}")
        total_loss += loss

    return total_loss