Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
minor training improvements
  • Loading branch information
n-poulsen committed Nov 6, 2024
commit 751fd373dc24293eea93cdcc845fdd79a59efa7a
6 changes: 5 additions & 1 deletion deeplabcut/core/metrics/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""
from __future__ import annotations

from unittest.mock import Mock, patch

import numpy as np

try:
Expand All @@ -26,6 +28,8 @@
with_pycocotools = False


@patch("pycocotools.coco.print", Mock())
@patch("pycocotools.cocoeval.print", Mock())
def compute_bbox_metrics(
ground_truth: dict[str, dict],
detections: dict[str, dict],
Expand Down Expand Up @@ -150,6 +154,6 @@ def _get_metric(
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
mean_s = 100 * np.mean(s[s > -1])

return f"{metric_name}@{thresh}", mean_s
6 changes: 3 additions & 3 deletions deeplabcut/pose_estimation_pytorch/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def train(
device: str | None = "cpu",
gpus: list[int] | None = None,
logger_config: dict | None = None,
snapshot_path: str | None = None,
snapshot_path: str | Path | None = None,
transform: A.BaseCompose | None = None,
inference_transform: A.BaseCompose | None = None,
max_snapshots_to_keep: int | None = None,
Expand Down Expand Up @@ -197,8 +197,8 @@ def train_network(
trainingsetindex: int = 0,
modelprefix: str = "",
device: str | None = None,
snapshot_path: str | None = None,
detector_path: str | None = None,
snapshot_path: str | Path | None = None,
detector_path: str | Path | None = None,
batch_size: int | None = None,
epochs: int | None = None,
save_epochs: int | None = None,
Expand Down
6 changes: 4 additions & 2 deletions deeplabcut/pose_estimation_pytorch/runners/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,17 @@ def log_config(self, config: dict = None) -> None:
class CSVLogger(BaseLogger):
"""Logger saving stats and metrics to a CSV file"""

def __init__(self, train_folder: Path) -> None:
def __init__(self, train_folder: Path, log_filename: str) -> None:
"""Initialize the CSVLogger class.

Args:
train_folder: The path of the folder containing training files.
log_filename: The name of the file in which to store training stats.
"""
super().__init__()
self.train_folder = train_folder
self.log_file = train_folder / "learning_stats.csv"
self.log_filename = log_filename
self.log_file = train_folder / log_filename

self._steps: list[int] = []
self._metric_store: list[dict] = []
Expand Down
50 changes: 31 additions & 19 deletions deeplabcut/pose_estimation_pytorch/runners/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def __init__(
device: str = "cpu",
gpus: list[int] | None = None,
eval_interval: int = 1,
snapshot_path: Path | None = None,
snapshot_path: str | Path | None = None,
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
logger: BaseLogger | None = None,
log_filename: str = "learning_stats.csv",
):
"""
Args:
Expand All @@ -66,6 +67,7 @@ def __init__(
pretrained weights
scheduler: scheduler for adjusting the lr of the optimizer
logger: logger to monitor training (e.g WandB logger)
log_filename: name of the file in which to store training stats
"""
super().__init__(
model=model, device=device, gpus=gpus, snapshot_path=snapshot_path
Expand All @@ -75,7 +77,9 @@ def __init__(
self.scheduler = scheduler
self.snapshot_manager = snapshot_manager
self.history: dict[str, list] = dict(train_loss=[], eval_loss=[])
self.csv_logger = CSVLogger(train_folder=snapshot_manager.model_folder)
self.csv_logger = CSVLogger(
train_folder=snapshot_manager.model_folder, log_filename=log_filename,
)
self.logger = logger
self.starting_epoch = 0
self.current_epoch = 0
Expand Down Expand Up @@ -157,7 +161,7 @@ def fit(
self.logger.select_images_to_log(train_loader, valid_loader)

# continuing to train a model: either total epochs or extra epochs
if self.starting_epoch > epochs:
if self.starting_epoch > 0:
epochs = self.starting_epoch + epochs

for e in range(self.starting_epoch + 1, epochs + 1):
Expand All @@ -183,6 +187,13 @@ def fit(
self.snapshot_manager.update(e, self.state_dict(), last=(e == epochs))
logging.info(msg)

epoch_metrics = self._metadata.get("metrics")
if epoch_metrics is not None and len(epoch_metrics) > 0:
logging.info(f"Model performance:")
line_length = max([len(name) for name in epoch_metrics.keys()]) + 2
for name, score in epoch_metrics.items():
logging.info(f" {(name + ':').ljust(line_length)}{score:6.2f}")

def _epoch(
self,
loader: torch.utils.data.DataLoader,
Expand Down Expand Up @@ -215,31 +226,28 @@ def _epoch(
loss_metrics = defaultdict(list)
for i, batch in enumerate(loader):
losses_dict = self.step(batch, mode)
epoch_loss.append(losses_dict["total_loss"])
if "total_loss" in losses_dict:
epoch_loss.append(losses_dict["total_loss"])
if (i + 1) % display_iters == 0 and mode != "eval":
logging.info(
f"Number of iterations: {i + 1}, "
f"loss: {losses_dict['total_loss']:.5f}, "
f"lr: {self.optimizer.param_groups[0]['lr']}"
)

for key in losses_dict.keys():
loss_metrics[key].append(losses_dict[key])

if (i + 1) % display_iters == 0:
logging.info(
f"Number of iterations: {i + 1}, "
f"loss: {losses_dict['total_loss']:.5f}, "
f"lr: {self.optimizer.param_groups[0]['lr']}"
)

perf_metrics = None
if mode == "eval":
perf_metrics = self._compute_epoch_metrics()
self._metadata["metrics"] = perf_metrics
self._epoch_predictions = {}
self._epoch_ground_truth = {}
if perf_metrics is not None and len(perf_metrics) > 0:
logging.info(f"Epoch {self.current_epoch} performance:")
for name, score in perf_metrics.items():
logging.info(f"{name + ':': <20}{score:.3f}")

epoch_loss = np.mean(epoch_loss).item()
self.history[f"{mode}_loss"].append(epoch_loss)
if len(epoch_loss) > 0:
epoch_loss = np.mean(epoch_loss).item()
self.history[f"{mode}_loss"].append(epoch_loss)

metrics_to_log = {}
if perf_metrics:
Expand Down Expand Up @@ -411,7 +419,11 @@ def __init__(self, model: BaseDetector, optimizer: torch.optim.Optimizer, **kwar
optimizer: The optimizer to use to train the model.
**kwargs: TrainingRunner kwargs
"""
super().__init__(model, optimizer, **kwargs)
log_filename = "learning_stats_detector.csv"
if "log_filename" in kwargs:
log_filename = kwargs.pop("log_filename")

super().__init__(model, optimizer, log_filename=log_filename, **kwargs)
self._pycoco_warning_displayed = False
self._print_valid_loss = False

Expand Down Expand Up @@ -553,7 +565,7 @@ def build_training_runner(
model: nn.Module,
device: str,
gpus: list[int] | None = None,
snapshot_path: str | None = None,
snapshot_path: str | Path | None = None,
logger: BaseLogger | None = None,
) -> TrainingRunner:
"""
Expand Down