deepdrivemd.selection.latest.select_model

Functions

get_model_path([stage_idx, task_idx, api, ...])

Get the current best model.

latest_checkpoint(api[, checkpoint_dir, ...])

Select latest PyTorch model checkpoint.

latest_model_checkpoint(cfg)

Select the latest model checkpoint and write path to JSON.

deepdrivemd.selection.latest.select_model.get_model_path(stage_idx: int = -1, task_idx: int = 0, api: Optional[deepdrivemd.data.api.DeepDriveMD_API] = None, experiment_dir: Optional[Union[str, pathlib.Path]] = None) → Optional[Tuple[pathlib.Path, pathlib.Path]]

Get the current best model.

Other stages should import this function to retrieve the path to the best model.

Parameters
  • api (DeepDriveMD_API, optional) – DeepDriveMD API used to access the machine learning model path.

  • experiment_dir (Union[str, Path], optional) – Experiment directory to initialize DeepDriveMD_API.

Returns

  • None – If model selection has not run before.

  • model_config (Path) – First element of the returned tuple: path to the most recent model YAML configuration file (hyperparameters) selected by the model selection stage.

  • model_checkpoint (Path) – Second element of the returned tuple: path to the most recent model weights selected by the model selection stage.

Raises

ValueError – If both api and experiment_dir are None.
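The contract above (a ValueError when neither source is given, None before model selection has run, otherwise a pair of paths) can be sketched without DeepDriveMD installed. This is a minimal sketch, not the library's implementation: the hypothetical get_model_path_sketch takes the path of a JSON file in place of api, and both the JSON keys model_config and model_checkpoint and the model_selection/selection.json location under experiment_dir are assumptions made for illustration.

```python
import json
from pathlib import Path
from typing import Optional, Tuple, Union

PathLike = Union[str, Path]


def get_model_path_sketch(
    selection_json: Optional[PathLike] = None,
    experiment_dir: Optional[PathLike] = None,
) -> Optional[Tuple[Path, Path]]:
    """Hypothetical stand-in for get_model_path (not the DeepDriveMD code)."""
    if selection_json is None and experiment_dir is None:
        raise ValueError("Both selection_json and experiment_dir are None")
    if selection_json is None:
        # ASSUMPTION: illustrative location of the model selection output.
        selection_json = Path(experiment_dir) / "model_selection" / "selection.json"
    selection_json = Path(selection_json)
    if not selection_json.exists():
        # Model selection has not run yet, so there is nothing to return.
        return None
    # ASSUMPTION: the file holds {"model_config": ..., "model_checkpoint": ...}.
    data = json.loads(selection_json.read_text())
    return Path(data["model_config"]), Path(data["model_checkpoint"])
```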

deepdrivemd.selection.latest.select_model.latest_checkpoint(api: deepdrivemd.data.api.DeepDriveMD_API, checkpoint_dir: str = 'checkpoint', checkpoint_suffix: str = '.pt') → pathlib.Path

Select latest PyTorch model checkpoint.

Assuming the model writes checkpoint files named XXX_<epoch-index>_YYY_ZZZ…<checkpoint_suffix> into a directory named checkpoint_dir, return the path to the checkpoint from the latest training epoch.

Parameters
  • api (DeepDriveMD_API) – DeepDriveMD API used to access the machine learning model path.

  • checkpoint_dir (str, default="checkpoint") – Name of the checkpoint directory inside the model path. If checkpoint files are stored in the top-level model directory, set checkpoint_dir="".

  • checkpoint_suffix (str, default=".pt") – File extension of the checkpoint files (.pt, .h5, etc.).

Returns

Path – Path to the latest model checkpoint file.
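The epoch-based selection described above can be sketched as follows. This is a hypothetical stand-in (latest_checkpoint_sketch), not the DeepDriveMD code: it takes the model directory directly instead of a DeepDriveMD_API object, and it assumes the epoch index is the second underscore-separated field of each file name, as in the XXX_<epoch-index>_YYY_ZZZ… pattern.

```python
from pathlib import Path


def latest_checkpoint_sketch(
    model_dir: Path,
    checkpoint_dir: str = "checkpoint",
    checkpoint_suffix: str = ".pt",
) -> Path:
    """Hypothetical stand-in for latest_checkpoint (not the DeepDriveMD code)."""
    # Joining an empty string is a no-op, so checkpoint_dir="" searches model_dir itself.
    search_dir = Path(model_dir) / checkpoint_dir
    checkpoints = list(search_dir.glob(f"*{checkpoint_suffix}"))
    if not checkpoints:
        raise FileNotFoundError(f"No *{checkpoint_suffix} files in {search_dir}")

    def epoch_index(path: Path) -> int:
        # ASSUMPTION: names look like XXX_<epoch-index>_YYY...<suffix>,
        # so the epoch index is the second "_"-separated field.
        return int(path.name.split("_")[1])

    # Compare epoch indices as integers so epoch 10 ranks above epoch 9.
    return max(checkpoints, key=epoch_index)
```

Parsing the epoch index as an integer matters here: comparing the names as strings would rank model_9_… above model_10_….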

deepdrivemd.selection.latest.select_model.latest_model_checkpoint(cfg: deepdrivemd.selection.latest.config.LatestCheckpointConfig) → None

Select the latest model checkpoint and write path to JSON.

Find the latest model checkpoint written by the machine learning stage and write the path into a JSON file to be consumed by the agent stage.

Parameters

cfg (LatestCheckpointConfig) – Pydantic YAML configuration for the model selection task.
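The final step, writing the selected paths to a JSON file for the agent stage, might look like the sketch below. The helper write_selection_json, its arguments, and the JSON keys are hypothetical; the documented function takes a LatestCheckpointConfig instead. Writing to a temporary file and renaming it with os.replace is a design choice of this sketch: on POSIX, when both files are on the same filesystem, the rename is atomic, so a reader polling for the file never sees a half-written one.

```python
import json
import os
from pathlib import Path
from typing import Union

PathLike = Union[str, Path]


def write_selection_json(
    model_config: PathLike, model_checkpoint: PathLike, output_json: PathLike
) -> None:
    """Hypothetical helper: record the selected model for downstream stages."""
    output_json = Path(output_json)
    payload = {
        # Absolute paths, so readers in other working directories can resolve them.
        "model_config": str(Path(model_config).resolve()),
        "model_checkpoint": str(Path(model_checkpoint).resolve()),
    }
    output_json.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file, then rename it over the target in one step.
    tmp = output_json.with_name(output_json.name + ".tmp")
    tmp.write_text(json.dumps(payload, indent=2))
    os.replace(tmp, output_json)
```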