Training Utilities

Helpers for environment setup, evaluation, and buffer warm-up.

Atari environments

Atari environment factory helpers.

make_env(env_id: str, n_envs: int, vec_env_cls: VecEnv = SubprocVecEnv, framestack: int = 4, seed: int | None = None, **kwargs) → VecEnvWrapper

Create a vectorized Atari environment with optional frame stacking.

Parameters:
  • env_id – Gymnasium environment identifier.

  • n_envs – Number of parallel environments.

  • vec_env_cls – Vectorized environment class; defaults to SubprocVecEnv, which runs each environment in its own subprocess.

  • framestack – Number of frames to stack; values below 2 disable stacking.

  • seed – Optional random seed passed to the environment factory.

  • **kwargs – Additional keyword arguments forwarded to make_atari_env.

Returns:

A vectorized environment, transposed for channel-first image policies.
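The behaviour described above could be assembled from Stable-Baselines3 building blocks roughly as follows. This is a minimal sketch, not the module's actual code: the name make_env_sketch is hypothetical, the default class is resolved inside the function (an assumption made so the sketch can be loaded without the Atari extras installed), and the use of VecFrameStack and VecTransposeImage is inferred from the parameter and return descriptions.

```python
def make_env_sketch(env_id, n_envs, vec_env_cls=None, framestack=4, seed=None, **kwargs):
    """Hypothetical re-creation of the documented behaviour, not the module's code."""
    # Deferred imports: calling this function requires stable-baselines3
    # and an Atari-capable Gymnasium installation.
    from stable_baselines3.common.env_util import make_atari_env
    from stable_baselines3.common.vec_env import (
        SubprocVecEnv,
        VecFrameStack,
        VecTransposeImage,
    )

    env = make_atari_env(
        env_id,
        n_envs=n_envs,
        seed=seed,
        vec_env_cls=vec_env_cls or SubprocVecEnv,  # assumed default handling
        **kwargs,
    )
    if framestack >= 2:
        # Values below 2 disable stacking, as documented for `framestack`.
        env = VecFrameStack(env, n_stack=framestack)
    # Move image observations from (H, W, C) to channel-first (C, H, W).
    return VecTransposeImage(env)
```

When subprocess workers are used, the call is typically placed under an `if __name__ == "__main__":` guard so that worker processes can be started safely on platforms that spawn rather than fork.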

Evaluation

Roll out a policy for evaluation and optional buffer filling.

process_outcome(infos: list[dict]) → tuple[ndarray[float], ndarray[bool]]

Extract episode returns and done flags from vectorized env info dicts.

Parameters:
  • infos – Info dicts returned by VecEnv.step.

Returns:

Episode return estimates (nan when unavailable) and a boolean done mask.
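One way such an extraction might work is sketched below. It assumes the environments are wrapped in Stable-Baselines3's Monitor, which writes an "episode" entry (with the return under "r") into the info dict when an episode ends. The name process_outcome_sketch is hypothetical, the done mask is derived here from the presence of that entry (an assumption), and plain lists stand in for the ndarray return values to keep the sketch dependency-free.

```python
import math


def process_outcome_sketch(infos):
    """Return (episode_returns, done_mask) for one VecEnv step.

    Envs whose episode is still running get nan as their return.
    """
    returns, dones = [], []
    for info in infos:
        episode = info.get("episode")  # present only when an episode ended
        if episode is None:
            returns.append(math.nan)
            dones.append(False)
        else:
            returns.append(float(episode["r"]))
            dones.append(True)
    return returns, dones


# Three parallel envs: only the first one finished an episode on this step.
infos = [{"episode": {"r": 21.0, "l": 812}}, {}, {"lives": 3}]
returns, dones = process_outcome_sketch(infos)
```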

eval_model(n_eps: int, eval_env: VecEnv, model: BaseAlgorithm, close_env: bool = True, buffer: BaseBuffer | BaseRecordBuffer | Iterable[BaseBuffer | BaseRecordBuffer] | None = None) → tuple[list[int | float | integer | floating], list]

Run evaluation episodes and optionally fill replay buffers.

Parameters:
  • n_eps – Number of completed episodes to collect.

  • eval_env – Vectorized evaluation environment.

  • model – Policy used for action selection.

  • close_env – Whether to close eval_env when finished.

  • buffer – Optional replay buffer(s) to fill with transitions.

Returns:

Episode returns and per-buffer time spent in add calls.
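The loop below is a rough sketch of how such an evaluation could be structured, not the actual implementation. The names eval_model_sketch, StubEnv, StubModel, and StubBuffer are hypothetical; the stubs only mimic the reset/step/predict/add call shapes used by Stable-Baselines3 so that the example runs without any dependencies. The sketch also stores the observation returned by step as next_obs, ignoring the terminal-observation handling a real buffer would need.

```python
import time


def eval_model_sketch(n_eps, eval_env, model, close_env=True, buffers=()):
    """Collect n_eps completed episodes; return (episode_returns, add_seconds)."""
    episode_returns = []
    add_seconds = [0.0 for _ in buffers]  # one accumulated timer per buffer
    obs = eval_env.reset()
    while len(episode_returns) < n_eps:
        actions, _ = model.predict(obs, deterministic=True)
        next_obs, rewards, dones, infos = eval_env.step(actions)
        for i, buf in enumerate(buffers):
            start = time.perf_counter()
            buf.add(obs, next_obs, actions, rewards, dones, infos)
            add_seconds[i] += time.perf_counter() - start
        for info in infos:
            # Assumes a Monitor-style "episode" entry marks a finished episode.
            if "episode" in info and len(episode_returns) < n_eps:
                episode_returns.append(info["episode"]["r"])
        obs = next_obs
    if close_env:
        eval_env.close()
    return episode_returns, add_seconds


class StubEnv:
    """Hypothetical stand-in: 2 envs, every episode lasts 3 steps, reward 1 per step."""

    def __init__(self, n_envs=2, ep_len=3):
        self.n_envs, self.ep_len = n_envs, ep_len
        self.steps = [0] * n_envs
        self.closed = False

    def reset(self):
        self.steps = [0] * self.n_envs
        return list(self.steps)

    def step(self, actions):
        dones, infos = [], []
        for i in range(self.n_envs):
            self.steps[i] += 1
            done = self.steps[i] == self.ep_len
            infos.append({"episode": {"r": float(self.ep_len)}} if done else {})
            dones.append(done)
            if done:
                self.steps[i] = 0  # auto-reset, as vectorized envs do
        return list(self.steps), [1.0] * self.n_envs, dones, infos

    def close(self):
        self.closed = True


class StubModel:
    def predict(self, obs, deterministic=True):
        return [0] * len(obs), None  # (actions, recurrent state)


class StubBuffer:
    def __init__(self):
        self.size = 0

    def add(self, *transition):
        self.size += 1


env, buf = StubEnv(), StubBuffer()
returns, add_seconds = eval_model_sketch(3, env, StubModel(), buffers=[buf])
```

Because several environments can finish on the same step, the sketch caps the result at n_eps episodes; whether the real function truncates or keeps the surplus is not stated in the source.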

Buffer warm-up

Fill replay buffers before training.

warm_up(buffer: BaseBuffer | BaseRecordBuffer | Iterable[BaseBuffer | BaseRecordBuffer], n_envs: int, warmup_env: VecEnv, warmup_model: BaseAlgorithm, warmup_episodes: int | None = None, mean_ep_len: int | float | None = None) → tuple[list[int | float | integer | floating], list[int | float | integer | floating]]

Fill the given buffer(s) with transitions collected by running warmup_model in warmup_env.
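The source does not document how warmup_episodes and mean_ep_len interact. One plausible reading, assumed here purely for illustration, is that an explicit episode count takes precedence and that mean_ep_len is otherwise used to estimate how many episodes are needed to fill the buffer. The helper estimate_warmup_episodes and its buffer_capacity argument are hypothetical and not part of the documented API.

```python
import math


def estimate_warmup_episodes(buffer_capacity, mean_ep_len, warmup_episodes=None):
    """Hypothetical helper: choose how many episodes to roll out during warm-up."""
    if warmup_episodes is not None:
        return warmup_episodes  # an explicit count always wins (assumption)
    if mean_ep_len is None or mean_ep_len <= 0:
        raise ValueError("need warmup_episodes or a positive mean_ep_len")
    # Enough whole episodes that the expected transition count covers capacity.
    return math.ceil(buffer_capacity / mean_ep_len)


# A 10,000-transition buffer with episodes averaging 800 steps.
n_episodes = estimate_warmup_episodes(10_000, mean_ep_len=800)
```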