@@ -1,6 +1,7 @@
 import torch
 from torch.utils.tensorboard import SummaryWriter

+from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
 from models.line_detect.dataset_LD import WirePointDataset
@@ -21,13 +22,23 @@ def _loss(losses):
         total_loss += loss

     return total_loss
-
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+def move_to_device(data, device):
+    if isinstance(data, (list, tuple)):
+        return type(data)(move_to_device(item, device) for item in data)
+    elif isinstance(data, dict):
+        return {key: move_to_device(value, device) for key, value in data.items()}
+    elif isinstance(data, torch.Tensor):
+        return data.to(device)
+    else:
+        return data  # leave non-tensor data unchanged

 class Trainer(BaseTrainer):
     def __init__(self, model=None,
                  dataset=None,
                  device='cuda',
                  **kwargs):
+
         super().__init__(model,dataset,device,**kwargs)

     def move_to_device(self, data, device):
@@ -54,15 +65,15 @@ class Trainer(BaseTrainer):
         except Exception as e:
             print(f"TensorBoard logging error: {e}")

-    def train_cfg(self, model, cfg):
+    def train_cfg(self, model: BaseModel, cfg):
         # cfg = r'./config/wireframe.yaml'
         cfg = read_yaml(cfg)
         print(f'cfg:{cfg}')
-        print(cfg['model']['n_dyn_negl'])
+        # print(cfg['n_dyn_negl'])
         self.train(model, **cfg)

-    def train(self, model, **cfg):
-        dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+    def train(self, model, **kwargs):
+        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
         train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
@@ -71,7 +82,7 @@ class Trainer(BaseTrainer):
             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
         )

-        dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
         val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
@@ -82,15 +93,19 @@ class Trainer(BaseTrainer):

         # model = linenet_resnet50_fpn().to(self.device)

-        optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
-        writer = SummaryWriter(cfg['io']['logdir'])
+        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
+        writer = SummaryWriter(kwargs['io']['logdir'])

-        for epoch in range(cfg['optim']['max_epoch']):
+        for epoch in range(kwargs['optim']['max_epoch']):
             print(f"epoch:{epoch}")
             model.train()

             for imgs, targets in data_loader_train:
-                losses = model(self.move_to_device(imgs, self.device), self.move_to_device(targets, self.device))
+                imgs = move_to_device(imgs, device)
+                targets = move_to_device(targets, device)
+                print(f'imgs:{len(imgs)}')
+                print(f'targets:{len(targets)}')
+                losses = model(imgs, targets)
                 # print(losses)
                 loss = _loss(losses)
                 optimizer.zero_grad()
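For context only, outside the patch itself: a minimal, self-contained sketch of how the new module-level move_to_device helper behaves on a nested batch like the ones data_loader_train yields. The tensor shapes and target keys below are invented for illustration, not taken from the dataset.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def move_to_device(data, device):
    # same recursive helper as added in the diff above
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # leave non-tensor data unchanged

# invented example batch: a tuple of image tensors plus a list of target dicts,
# roughly the structure a detection-style collate function would produce
imgs = (torch.rand(3, 64, 64), torch.rand(3, 64, 64))
targets = [{'boxes': torch.rand(4, 4), 'labels': torch.tensor([1, 2, 3, 4])}]

imgs = move_to_device(imgs, device)
targets = move_to_device(targets, device)
print(imgs[0].device, targets[0]['boxes'].device)  # both report the selected device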