Compressed Buffers
Compressed rollout and replay buffers for Stable-Baselines3.
- class CompressedRolloutBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'rle', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None)
Bases:
RolloutBuffer, BaseCompressedBuffer
RolloutBuffer, but compressed!
Create a compressed rollout buffer.
- Parameters:
buffer_size – Number of steps collected per environment before rollout ends.
observation_space – Gymnasium observation space.
action_space – Gymnasium action space.
device – Torch device used when sampling batches.
gae_lambda – GAE lambda for advantage estimation.
gamma – Discount factor for returns.
n_envs – Number of parallel environments.
dtypes – Element and run-length dtypes for compression.
normalize_images – Divide image observations by 255 when sampling.
compression_method – Registered compression method name.
compression_kwargs – Keyword arguments for compression.
decompression_kwargs – Keyword arguments for decompression.
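The `gae_lambda` and `gamma` parameters drive the advantage computation that the compressed buffer inherits from `RolloutBuffer` (Stable-Baselines3's `compute_returns_and_advantage`). As a rough, standalone illustration, the NumPy sketch below implements generalized advantage estimation over arrays shaped `(buffer_size, n_envs)`. The function name `compute_gae` and its argument layout are hypothetical, and compression plays no part in this step.

```python
import numpy as np


def compute_gae(rewards, values, episode_starts, last_values, dones,
                gamma=0.99, gae_lambda=1.0):
    """Hypothetical helper: GAE over arrays shaped (buffer_size, n_envs).

    Returns (advantages, returns) with returns = advantages + values.
    """
    buffer_size = rewards.shape[0]
    advantages = np.zeros_like(rewards, dtype=np.float32)
    last_gae_lam = np.zeros(rewards.shape[1], dtype=np.float32)
    for step in reversed(range(buffer_size)):
        if step == buffer_size - 1:
            # Bootstrap from the value of the state after the last collected step.
            next_non_terminal = 1.0 - dones.astype(np.float32)
            next_values = last_values
        else:
            # An episode start at step + 1 means the transition at `step` ended an episode.
            next_non_terminal = 1.0 - episode_starts[step + 1].astype(np.float32)
            next_values = values[step + 1]
        # One-step TD error, then the exponentially weighted (gamma * lambda) sum.
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    return advantages, advantages + values
```

With `gae_lambda=1` and zero value estimates, the advantages reduce to discounted returns, which makes the recursion easy to check by hand.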
- add(obs: ndarray, action: ndarray, reward: ndarray, episode_start: ndarray, value: Tensor, log_prob: Tensor) → None
Add a rollout step with a compressed observation.
- Parameters:
obs – Observation batch from the environment.
action – Action batch.
reward – Reward batch.
episode_start – Whether each environment started a new episode.
value – Value estimate for the current state under the policy.
log_prob – Log probability of the action under the policy.
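The default `compression_method="rle"` stores each observation as runs of repeated values, which is why `dtypes` names both an element dtype and a run-length dtype. The sketch below shows plain run-length encoding and decoding of one flattened observation in NumPy. `rle_encode` and `rle_decode` are hypothetical names, and the package's own (optionally Numba-accelerated) backend likely uses a different byte layout.

```python
import numpy as np


def rle_encode(obs, elem_dtype=np.uint8, runs_dtype=np.uint16):
    """Hypothetical RLE: return (values, run_lengths) for the flattened obs.

    runs_dtype must be able to hold the longest possible run, i.e. obs.size.
    """
    flat = np.asarray(obs, dtype=elem_dtype).ravel()
    if flat.size == 0:
        return flat.copy(), np.zeros(0, dtype=runs_dtype)
    # A run starts at index 0 and at every position where the value changes.
    starts = np.flatnonzero(np.concatenate(([True], flat[1:] != flat[:-1])))
    run_lengths = np.diff(np.append(starts, flat.size)).astype(runs_dtype)
    return flat[starts], run_lengths


def rle_decode(values, run_lengths, obs_shape):
    """Invert rle_encode and restore the original observation shape."""
    return np.repeat(values, run_lengths).reshape(obs_shape)
```

Image observations with large flat regions collapse to a few runs; noisy observations can grow instead, because every value change costs an extra value/length pair.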
- class CompressedDictRolloutBuffer(buffer_size: int, observation_space: Dict, action_space: Space, device: device | str = 'auto', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'rle', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None)
Bases:
CompressedRolloutBuffer
DictRolloutBuffer, but compressed!
Create a compressed rollout buffer for dictionary observations.
- Parameters:
buffer_size – Number of steps collected per environment before rollout ends.
observation_space – Gymnasium Dict observation space.
action_space – Gymnasium action space.
device – Torch device used when sampling batches.
gae_lambda – GAE lambda for advantage estimation.
gamma – Discount factor for returns.
n_envs – Number of parallel environments.
dtypes – Element and run-length dtypes for compression.
normalize_images – Divide image observations by 255 when sampling.
compression_method – Registered compression method name.
compression_kwargs – Keyword arguments for compression.
decompression_kwargs – Keyword arguments for decompression.
- add(obs: dict[str, ndarray], action: ndarray, reward: ndarray, episode_start: ndarray, value: Tensor, log_prob: Tensor) → None
Add a dict rollout step with compressed observations per key.
- Parameters:
obs – Observation dict from the environment.
action – Action batch.
reward – Reward batch.
episode_start – Whether each environment started a new episode.
value – Value estimate for the current state under the policy.
log_prob – Log probability of the action under the policy.
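For dictionary observations, each key is compressed on its own, so keys with different dtypes and shapes never share a stream. A minimal sketch of that per-key pattern follows, using stdlib `zlib` as a stand-in backend rather than the package's default RLE; the helper names and the `(blobs, dtype, shape)` tuple layout are assumptions for illustration.

```python
import zlib

import numpy as np


def compress_dict_obs(obs, level=6):
    """Hypothetical: compress each key of a dict observation batch separately.

    Returns {key: (list of bytes, one per env, dtype, per-env shape)}.
    """
    packed = {}
    for key, batch in obs.items():
        batch = np.ascontiguousarray(batch)
        # One compressed blob per environment, so envs can be decoded independently.
        blobs = [zlib.compress(env_obs.tobytes(), level) for env_obs in batch]
        packed[key] = (blobs, batch.dtype, batch.shape[1:])
    return packed


def decompress_dict_obs(packed):
    """Rebuild {key: array of shape (n_envs, *obs_shape)} from compress_dict_obs output."""
    return {
        key: np.stack([np.frombuffer(zlib.decompress(b), dtype=dtype).reshape(shape)
                       for b in blobs])
        for key, (blobs, dtype, shape) in packed.items()
    }
```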
- get(batch_size: int | None = None) → Generator[DictRolloutBufferSamples, None, None]
Yield shuffled dict rollout minibatches after the buffer is full.
- Parameters:
batch_size – Minibatch size. When None, the full flattened buffer is used.
- Yields:
Batches of dict rollout samples with decompressed observations.
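`get()` shuffles the flattened buffer (`buffer_size * n_envs` samples) and yields minibatches, or a single batch when `batch_size` is `None`. The index logic can be sketched as below; `iter_minibatches` is a hypothetical helper, and in a compressed buffer only the observations at each yielded index set would need decompressing.

```python
import numpy as np


def iter_minibatches(n_samples, batch_size=None, rng=None):
    """Hypothetical: yield shuffled index batches over a flattened rollout.

    batch_size=None returns the whole buffer as one batch.
    """
    rng = np.random.default_rng() if rng is None else rng
    indices = rng.permutation(n_samples)
    if batch_size is None:
        batch_size = n_samples
    for start in range(0, n_samples, batch_size):
        # The last batch may be smaller when n_samples is not a multiple of batch_size.
        yield indices[start:start + batch_size]
```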
- class CompressedReplayBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'rle', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, output_dtype: Literal['raw', 'float'] = 'raw')
Bases:
ReplayBuffer, BaseCompressedBuffer
ReplayBuffer, but compressed!
Create a compressed replay buffer for vector or image observations.
- Parameters:
buffer_size – Maximum number of transitions per environment.
observation_space – Gymnasium observation space.
action_space – Gymnasium action space.
device – Torch device used when sampling batches.
n_envs – Number of parallel environments.
optimize_memory_usage – Reuse observation slots for next observations.
handle_timeout_termination – Store timeout flags read from info["TimeLimit.truncated"].
dtypes – Element and run-length dtypes for compression.
normalize_images – Divide image observations by 255 when sampling.
compression_method – Registered compression method name.
compression_kwargs – Keyword arguments for compression.
decompression_kwargs – Keyword arguments for decompression.
output_dtype – Sample dtype for observations ("raw" keeps the storage dtype).
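`output_dtype` and `normalize_images` both act at sampling time, after decompression. The sketch below is one plausible reading of those two options. It assumes `"float"` means `float32` and that normalization divides `uint8` pixels by 255; `postprocess_obs` is a hypothetical name, not part of the package.

```python
import numpy as np


def postprocess_obs(obs, output_dtype="raw", normalize_images=False):
    """Hypothetical sample-time conversion of decompressed observations.

    Assumption: "float" means float32, and normalization divides by 255.
    """
    if output_dtype not in ("raw", "float"):
        raise ValueError(f"output_dtype must be 'raw' or 'float', got {output_dtype!r}")
    if normalize_images:
        # Scale uint8 pixels into [0, 1]; the result is float32 either way.
        return obs.astype(np.float32) / 255.0
    if output_dtype == "float":
        return obs.astype(np.float32)
    return obs  # "raw": keep the storage dtype, e.g. uint8
```

Under this reading, `"raw"` keeps observations as `uint8` in memory, and only the sampled batch pays for a float copy.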
- add(obs: ndarray, next_obs: ndarray, action: ndarray, reward: ndarray, done: ndarray, infos: list[dict[str, Any]]) → None
Add a transition, compressing observations before storage.
- Parameters:
obs – Current observation batch.
next_obs – Next observation batch.
action – Action batch.
reward – Reward batch.
done – Episode termination flags.
infos – Per-environment info dicts from the vectorized environment.
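With `handle_timeout_termination=True`, Stable-Baselines3's `ReplayBuffer` records `info["TimeLimit.truncated"]` for each environment and, when sampling, masks `done` with it so that time-limit truncations still bootstrap from the next state. The helper names below are hypothetical; the masking mirrors SB3's `dones * (1 - timeouts)`.

```python
import numpy as np


def timeout_flags(infos):
    """Per-env truncation flags, read from the vectorized env's info dicts."""
    return np.array([info.get("TimeLimit.truncated", False) for info in infos], dtype=np.float32)


def effective_dones(dones, timeouts):
    """Treat time-limit truncations as non-terminal so targets still bootstrap."""
    return dones.astype(np.float32) * (1.0 - timeouts)


def td_target(rewards, next_q, dones, timeouts, gamma=0.99):
    """One-step TD target that only cuts bootstrapping on true terminations."""
    return rewards + gamma * (1.0 - effective_dones(dones, timeouts)) * next_q
```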
- class CompressedDictReplayBuffer(buffer_size: int, observation_space: Dict, action_space: Space, device: device | str = 'auto', n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'rle', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, output_dtype: Literal['raw', 'float'] = 'raw')
Bases:
CompressedReplayBuffer
DictReplayBuffer, but compressed!
Create a compressed replay buffer for dictionary observations.
- Parameters:
buffer_size – Maximum number of transitions per environment.
observation_space – Gymnasium Dict observation space.
action_space – Gymnasium action space.
device – Torch device used when sampling batches.
n_envs – Number of parallel environments.
optimize_memory_usage – Must be False for dict observations.
handle_timeout_termination – Store timeout flags read from info["TimeLimit.truncated"].
dtypes – Element and run-length dtypes for compression.
normalize_images – Divide image observations by 255 when sampling.
compression_method – Registered compression method name.
compression_kwargs – Keyword arguments for compression.
decompression_kwargs – Keyword arguments for decompression.
output_dtype – Sample dtype for observations ("raw" keeps the storage dtype).
- add(obs: dict[str, ndarray], next_obs: dict[str, ndarray], action: ndarray, reward: ndarray, done: ndarray, infos: list[dict[str, Any]]) → None
Add a dict observation transition, compressing each key separately.
- Parameters:
obs – Current observation dict.
next_obs – Next observation dict.
action – Action batch.
reward – Reward batch.
done – Episode termination flags.
infos – Per-environment info dicts from the vectorized environment.
- class CompressedArray(shape: int | tuple | Any, dtype: integer | floating, obs_shape: int | tuple | Any, buffer: Any | None = None, offset: Any = 0, strides: Any | None = None, order: Literal[None, 'K', 'A', 'C', 'F'] = None, dtypes: dict | None = None, compression_method: str = 'rle', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, **kwargs)
Bases:
ndarray, BaseCompressedBuffer
Experimental Compressed Array Class.
Initialize compression settings for this array view.
- Parameters:
shape – Storage shape for compressed byte objects.
dtype – Element dtype of reconstructed observations.
obs_shape – Original observation shape before flattening.
buffer – Optional underlying buffer passed to np.ndarray.
offset – Byte offset into buffer.
strides – Stride tuple passed to np.ndarray.
order – Memory layout order passed to np.ndarray.
dtypes – Element and run-length dtypes; inferred when omitted.
compression_method – Registered compression method name.
compression_kwargs – Keyword arguments for compression.
decompression_kwargs – Keyword arguments for decompression.
**kwargs – Additional arguments forwarded to the ndarray base.
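`CompressedArray` is documented as an experimental `np.ndarray` subclass that holds compressed bytes per slot and reconstructs observations of `obs_shape`. The toy class below shows the general subclassing pattern (`__new__` plus `__array_finalize__`, so extra attributes survive slicing) on an object-dtype array with stdlib `zlib`. `ToyCompressedArray`, `store`, and `load` are hypothetical and not the package's API.

```python
import zlib

import numpy as np


class ToyCompressedArray(np.ndarray):
    """Hypothetical object-dtype array holding one zlib blob per slot."""

    def __new__(cls, shape, obs_shape, elem_dtype=np.uint8):
        obj = super().__new__(cls, shape, dtype=object)  # every slot starts as None
        obj.obs_shape = tuple(obs_shape)
        obj.elem_dtype = np.dtype(elem_dtype)
        return obj

    def __array_finalize__(self, obj):
        # Runs for views and slices as well, so copy the metadata across.
        if obj is None:
            return
        self.obs_shape = getattr(obj, "obs_shape", None)
        self.elem_dtype = getattr(obj, "elem_dtype", None)

    def store(self, index, obs):
        """Compress one observation into the slot at `index`."""
        raw = np.ascontiguousarray(obs, dtype=self.elem_dtype).tobytes()
        self[index] = zlib.compress(raw)

    def load(self, index):
        """Decompress the slot at `index` back to an array of obs_shape."""
        raw = zlib.decompress(self[index])
        return np.frombuffer(raw, dtype=self.elem_dtype).reshape(self.obs_shape)
```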
- class DummyCls(*args, **kwargs)
Bases:
object
Placeholder type used when optional compression backends are unavailable.
Accept arbitrary arguments and perform no initialization.
- find_smallest_dtype(max_val: int, signed: bool = False, fallback: dtype = <class 'numpy.float32'>) → dtype
Find the smallest dtype for runs_type.
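`find_smallest_dtype` picks a compact dtype for run lengths. A sketch of the idea follows, assuming it scans unsigned (or signed) integer types from narrowest to widest and returns `fallback` when none fits; `smallest_int_dtype` is a hypothetical stand-in, not the package function.

```python
import numpy as np


def smallest_int_dtype(max_val, signed=False, fallback=np.float32):
    """Hypothetical: narrowest integer dtype whose range covers max_val."""
    if signed:
        candidates = (np.int8, np.int16, np.int32, np.int64)
    else:
        candidates = (np.uint8, np.uint16, np.uint32, np.uint64)
    for dtype in candidates:
        if max_val <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    # Nothing integral is wide enough: fall back (float32 by default, as documented).
    return np.dtype(fallback)
```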
- init_jit(*args, **kwargs)
Raise when Numba is not installed.
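`DummyCls` and the `init_jit` stub follow the usual optional-dependency pattern: probe for the backend once, export placeholders when it is missing, and raise only when the missing feature is actually used. A sketch of that pattern, assuming the probe uses `importlib.util.find_spec`; the `HAS_NUMBA` flag and the error message are hypothetical.

```python
import importlib.util

# Probe without importing the (heavy) backend.
HAS_NUMBA = importlib.util.find_spec("numba") is not None


class DummyCls:
    """Stand-in used where an optional backend is missing."""

    def __init__(self, *args, **kwargs):
        pass  # accept anything, set nothing up


if HAS_NUMBA:
    def init_jit(*args, **kwargs):
        # Hypothetical: a real implementation would compile or warm up Numba kernels here.
        return None
else:
    def init_jit(*args, **kwargs):
        raise ImportError("init_jit() needs Numba; install it with `pip install numba`.")
```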
- find_buffer_dtypes(obs_shape: int | tuple, elem_dtype: integer | floating = <class 'numpy.uint8'>, compression_method: str = 'rle') → dict[str, Any]
Find the best data types to use for CompressedBuffer based on obs shape and compression method.
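For RLE, the longest possible run is the whole flattened observation, so a run-length dtype only has to hold `prod(obs_shape)`. The sketch below applies that rule with NumPy's `np.min_scalar_type`. The function name and the `"elem_type"` key are hypothetical (only `runs_type` appears in these docs), and the real `find_buffer_dtypes` may weigh other factors.

```python
import math

import numpy as np


def sketch_buffer_dtypes(obs_shape, elem_dtype=np.uint8, compression_method="rle"):
    """Hypothetical dtype choice for a compressed observation buffer."""
    shape = (obs_shape,) if isinstance(obs_shape, int) else tuple(obs_shape)
    n_elems = math.prod(shape)
    dtypes = {"elem_type": np.dtype(elem_dtype)}
    if compression_method == "rle":
        # A single run can span at most every element of one observation.
        dtypes["runs_type"] = np.min_scalar_type(n_elems)
    return dtypes
```

For an 84x84 grayscale frame (7056 elements) this yields `uint16` run lengths; a 3x210x160 RGB frame needs `uint32`.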
Core classes
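| CompressedReplayBuffer | ReplayBuffer, but compressed! |
| CompressedDictReplayBuffer | DictReplayBuffer, but compressed! |
| CompressedRolloutBuffer | RolloutBuffer, but compressed! |
| CompressedDictRolloutBuffer | DictRolloutBuffer, but compressed! |
| CompressedArray | Experimental Compressed Array Class. |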
Helpers
| find_buffer_dtypes | Find the best data types to use for CompressedBuffer based on obs shape and compression method. |
| init_jit | Raise when Numba is not installed. |
| find_smallest_dtype | Find smallest dtype for runs_type. |
|  | Return whether the igzip backend is available. |
|  | Return whether the Numba RLE backend is available. |
Implementation modules
Base classes and helpers for compressed observation storage.
- init_jit and find_buffer_dtypes (documented above).
- class BaseCompressedBuffer(compression_method: str | None = None, compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, flatten_config: dict | None = None)
Bases:
object
Base Compressed Buffer Class.
Configure compression and decompression callables.
- Parameters:
compression_method – Registered method name (for example "rle" or "gzip"). When None, compression is not configured.
compression_kwargs – Keyword arguments passed to the compressor.
decompression_kwargs – Keyword arguments passed to the decompressor.
flatten_config – Shape and dtype used when reconstructing flattened observations.
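`BaseCompressedBuffer` resolves a registered method name into compress and decompress callables and binds the keyword arguments to them. A minimal sketch of that lookup follows, assuming a plain dict registry with stdlib `gzip` and `zlib` entries. `COMPRESSION_REGISTRY` and `ToyCompressedBuffer` are hypothetical, and the package's registry also covers `"rle"` and other backends.

```python
import gzip
import zlib
from functools import partial

# Hypothetical registry: method name -> (compress, decompress).
COMPRESSION_REGISTRY = {
    "gzip": (gzip.compress, gzip.decompress),
    "zlib": (zlib.compress, zlib.decompress),
}


class ToyCompressedBuffer:
    """Bind registered callables to their keyword arguments once, up front."""

    def __init__(self, compression_method=None, compression_kwargs=None,
                 decompression_kwargs=None):
        self.compress = self.decompress = None
        if compression_method is None:
            return  # compression left unconfigured
        try:
            compress, decompress = COMPRESSION_REGISTRY[compression_method]
        except KeyError:
            raise ValueError(f"unknown compression method {compression_method!r}") from None
        self.compress = partial(compress, **(compression_kwargs or {}))
        self.decompress = partial(decompress, **(decompression_kwargs or {}))
```

For example, `ToyCompressedBuffer("gzip", {"compresslevel": 1})` trades compression ratio for speed on every stored observation.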
- DummyCls (documented above).
Replay buffers that store compressed observations.
- CompressedReplayBuffer and CompressedDictReplayBuffer (documented above).
On-policy rollout buffers that store compressed observations.
- CompressedRolloutBuffer and CompressedDictRolloutBuffer (documented above).
NumPy ndarray subclass that stores compressed observation bytes.
- CompressedArray (documented above).
NumPy helpers for dtype selection and array reshaping.
- find_optimal_shape(arr_len: int, dtype: dtype = <class 'numpy.uint8'>) → tuple[int, int, int]
Find a way to slice longer 1D arrays.
- find_smallest_dtype (documented above).
Compression backends and availability probes.
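The availability probes can be written without importing the heavy backends. A sketch using `importlib.util.find_spec` follows, assuming igzip is provided by the `python-isal` package (import name `isal`); all three function names are hypothetical.

```python
import importlib.util


def has_module(name):
    """True if a top-level module is importable, checked without importing it."""
    return importlib.util.find_spec(name) is not None


def igzip_available():
    # Assumption: igzip ships with python-isal, whose import name is "isal".
    return has_module("isal")


def numba_rle_available():
    # Assumption: the Numba RLE backend only needs the numba package itself.
    return has_module("numba")
```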