GPU Buffers
Warning
Advanced / experimental. sb3_extra_buffers.gpu_buffers targets users who
want observation tensors to live on a Torch device (CPU, CUDA, or MPS) with optional
compression. The API is less stable than sb3_extra_buffers.compressed, behaviour
varies by device and codec, and not every CPU compression backend has a GPU-native
implementation yet.
Caution
You are responsible for heap sizing. Compressed observations are stored in a single
packed byte heap (RawBuffer) with
per-cell start_idx / lengths. If the heap runs out of space, compression raises
an error. By default, buffers set heap capacity via
estimate_total_heap_bytes() (per-cell
heuristic × number of cells). For large or high-entropy observations, set
heap_bytes explicitly on GpuReplayBuffer
or GpuRolloutBuffer and validate with
your observation distribution. The deprecated max_slot_bytes argument is interpreted
as per-cell capacity and multiplied by the cell count.
Overview
GPU buffers mirror the CPU compressed replay/rollout API but keep flattened observations on
buffer_device end-to-end (add → store → sample/get). Non-observation fields (actions,
rewards, dones, etc.) remain NumPy arrays like the CPU buffers.
Public entry points
- GpuReplayBuffer: off-policy (DQN, SAC, …)
- GpuRolloutBuffer: on-policy (PPO, …)
- find_gpu_buffer_dtypes(): Torch elem_type / runs_type helpers
- COMPRESSION_METHOD_MAP: registered codecs
- has_zstd(): optional Zstd backend probe
Compression methods
Example scripts (Pong, PongNoFrameskip-v4):
- examples/example_train_gpu_replay.py / examples/example_watch_gpu_replay.py
- examples/example_train_gpu_rollout.py / examples/example_watch_gpu_rollout.py
GPU-backed raw storage, compression, and SB3 buffers (experimental).
- class SlotMetadata(byte_start: int, pos_runs: int, pos_elem: int, run_length: int, payload_bytes: int)
Bases: object
Compression metadata for one observation cell in the raw heap.
- class BaseGpuBuffer(compression_method: str | None = None, compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, flatten_config: dict | None = None)
Bases: object
Base GPU buffer class wiring compression callables.
Configure compression and decompression callables.
Metadata
Shared types for GPU buffer compression.
Raw storage
Untyped GPU/CPU byte storage for packed observation heaps.
- write_at(buffer: RawBuffer, byte_start: int, rel_byte_offset: int, tensor: Tensor) None
Write tensor bytes at byte_start + rel_byte_offset.
- read_at(buffer: RawBuffer, byte_start: int, rel_byte_offset: int, elem_count: int, dtype: dtype) Tensor
Read elem_count elements at byte_start + rel_byte_offset.
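The addressing used by write_at and read_at (absolute offset = byte_start + rel_byte_offset) can be illustrated without Torch. The sketch below is only a stand-in: it uses a bytearray and the struct module instead of RawBuffer and tensors, and the names toy_write_at, toy_read_at, and the format-character argument are hypothetical, not part of the library.

```python
import struct


def toy_write_at(heap: bytearray, byte_start: int, rel_byte_offset: int,
                 values: list, fmt: str) -> int:
    """Pack values at byte_start + rel_byte_offset; return bytes written.

    fmt is a struct format character, e.g. "B" (uint8) or "H" (uint16).
    """
    offset = byte_start + rel_byte_offset
    packed = struct.pack(f"<{len(values)}{fmt}", *values)
    if offset < 0 or offset + len(packed) > len(heap):
        raise ValueError("write exceeds heap capacity")
    heap[offset:offset + len(packed)] = packed
    return len(packed)


def toy_read_at(heap: bytearray, byte_start: int, rel_byte_offset: int,
                elem_count: int, fmt: str) -> list:
    """Unpack elem_count values starting at byte_start + rel_byte_offset."""
    offset = byte_start + rel_byte_offset
    return list(struct.unpack_from(f"<{elem_count}{fmt}", heap, offset))


heap = bytearray(32)
cell_start = 8                                          # where this cell begins
n = toy_write_at(heap, cell_start, 0, [7, 0, 255], "B")  # element values
toy_write_at(heap, cell_start, n, [3, 500, 2], "H")      # run lengths follow
elems = toy_read_at(heap, cell_start, 0, 3, "B")
runs = toy_read_at(heap, cell_start, n, 3, "H")
```

Keeping offsets relative to the cell start means a cell can be moved inside the heap by changing only its start position.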
- class RawBuffer(size: int, device: str | device = 'cpu')
Bases: object
Linear byte storage backed by torch.UntypedStorage.
Allocate size bytes on device.
- Parameters:
size – Total storage size in bytes.
device – Torch device for the underlying storage.
- write_bytes(malloc: tuple[int, int], tensor: Tensor)
Copy tensor bytes into the region described by malloc.
Bases: object
Shared byte heap with per-cell start_idx and lengths.
Allocate heap storage and per-cell index arrays.
Pack all cell payloads into a contiguous prefix of the heap.
Compact until at least needed bytes are free at data_end.
Observation stores
Device-resident observation storage backends.
- class DenseObservationStore(buffer_size: int, n_envs: int, flat_len: int, elem_type: dtype, device: str | device)
Bases: object
Store flattened observations in a dense Torch tensor.
Allocate dense observation storage on device.
- flatten()
Flatten using the same ordering as SB3 swap_and_flatten.
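For reference, SB3's swap_and_flatten turns an array shaped (n_steps, n_envs, ...) into (n_steps * n_envs, ...) by swapping the first two axes before reshaping, so flattened rows are grouped by environment rather than by step. A pure-Python sketch of that index mapping follows; the helper name flat_index is hypothetical.

```python
def flat_index(step: int, env: int, n_steps: int) -> int:
    """Row occupied by (step, env) after a swap_and_flatten-style reshape."""
    return env * n_steps + step


n_steps, n_envs = 3, 2

# Label each cell by (step, env) as it sits in a [n_steps][n_envs] buffer.
cells = [[(s, e) for e in range(n_envs)] for s in range(n_steps)]

# Swap the first two axes, then flatten: all steps of env 0, then env 1, ...
flattened = [cells[s][e] for e in range(n_envs) for s in range(n_steps)]
```

Any store that flattens its own metadata has to follow the same env-major order, otherwise observations would no longer line up with the flattened actions and rewards.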
- class RawObservationStore(buffer_size: int, n_envs: int, flat_len: int, elem_type: dtype, device: str | device, field_offset: int, compress: Callable[[...], SlotMetadata], decompress: Callable[[...], Tensor], compression_method: str = 'rle', runs_type: dtype = torch.uint16, shared_heap: SharedRawHeap | None = None, n_fields: int = 1)
Bases: object
Store compressed observations in a packed raw byte heap.
Create heap-backed observation storage.
- flatten()
Flatten metadata to match rollout swap_and_flatten ordering.
- create_observation_store(compression_method: str, buffer_size: int, n_envs: int, flat_len: int, elem_type: dtype, device: str | device, field_offset: int = 0, compress: Callable[[...], SlotMetadata] | None = None, decompress: Callable[[...], Tensor] | None = None, runs_type: dtype = torch.uint16, shared_heap: SharedRawHeap | None = None, n_fields: int = 1)
Create the observation store backend for compression_method.
Base helpers
Base classes and helpers for GPU observation storage.
Replay buffer
Replay buffers with device-resident observations.
- class GpuReplayBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'none', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, output_dtype: Literal['raw', 'float'] = 'raw', buffer_device: device | str | None = None, heap_bytes: int | None = None, max_slot_bytes: int | None = None)
Bases: ReplayBuffer, BaseGpuBuffer
Replay buffer with observations stored on a Torch device.
Create a replay buffer with device-resident observations.
Rollout buffer
Rollout buffers with device-resident observations.
- class GpuRolloutBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'none', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, buffer_device: device | str | None = None, heap_bytes: int | None = None, max_slot_bytes: int | None = None)
Bases: RolloutBuffer, BaseGpuBuffer
Rollout buffer with observations stored on a Torch device.
Create a rollout buffer with device-resident observations.
- add(obs: ndarray, action: ndarray, reward: ndarray, episode_start: ndarray, value: Tensor, log_prob: Tensor) None
Add a rollout step with a device-resident observation.
Utilities
Torch dtype helpers for GPU buffer compression.
- find_smallest_dtype(max_val: int, signed: bool = False, fallback: dtype = torch.float32) dtype
Find the smallest Torch dtype that can represent max_val, for use as runs_type.
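The selection idea can be sketched in plain Python. This is an illustration under assumptions, not the library's implementation: it returns a bit width instead of a Torch dtype, and it returns None where the real function would presumably use its fallback argument. The name smallest_int_bits is hypothetical.

```python
def smallest_int_bits(max_val: int, signed: bool = False):
    """Return the smallest integer width (8, 16, 32, 64 bits) holding max_val.

    Returns None when no listed width is large enough.
    """
    for bits in (8, 16, 32, 64):
        # Largest representable value: 2^(bits-1) - 1 if signed, else 2^bits - 1.
        limit = 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1
        if 0 <= max_val <= limit:
            return bits
    return None


# A run can never be longer than the flattened observation, so the flat length
# is a natural upper bound to pass in (an assumption about typical usage).
flat_len = 84 * 84
bits_for_runs = smallest_int_bits(flat_len)
```

For an 84 × 84 frame the flat length is 7056, which fits in 16 bits and is consistent with the torch.uint16 default for runs_type shown in the signatures above.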
- estimate_max_slot_bytes(flat_len: int, elem_type: dtype, runs_type: dtype, compression_method: str, overalloc_factor: float = 1.5, compresslevel: int | None = None) int
Estimate per-cell byte capacity for packed heap storage.
Uses the larger of the MsPacman benchmark ratio and the legacy analytic bound so incompressible or noisy frames do not exceed the reserved span.
- Parameters:
flat_len – Flattened observation length.
elem_type – Observation element dtype on the buffer device.
runs_type – Run-length dtype used by the legacy analytic bound.
compression_method – Codec name or shorthand (e.g. zstd3, zstd-5).
overalloc_factor – Headroom multiplier applied to the benchmark ratio.
compresslevel – If set, overrides any level parsed from compression_method.
- estimate_total_heap_bytes(n_cells: int, flat_len: int, elem_type: dtype, runs_type: dtype, compression_method: str, overalloc_factor: float = 1.5, compresslevel: int | None = None) int
Estimate total heap capacity for n_cells packed observations.
Compression methods
GPU compression method registry.
- GpuCompressionMethods
alias of GpuCompressionMethod
Usage
Replay buffer with RLE on CUDA:
import torch as th
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env

from sb3_extra_buffers.gpu_buffers import GpuReplayBuffer, find_gpu_buffer_dtypes
from sb3_extra_buffers.gpu_buffers.utils import numpy_dtype_to_torch

# Assumed setup for illustration; any image-observation environment works.
env = make_atari_env("PongNoFrameskip-v4", n_envs=1)
device = "cuda" if th.cuda.is_available() else "cpu"

# Pick Torch element / run-length dtypes for this observation space.
dtypes = find_gpu_buffer_dtypes(
    env.observation_space.shape,
    elem_dtype=numpy_dtype_to_torch(env.observation_space.dtype),
    compression_method="rle",
)

model = DQN(
    "CnnPolicy",
    env,
    replay_buffer_class=GpuReplayBuffer,
    replay_buffer_kwargs=dict(
        dtypes=dtypes,
        compression_method="rle",
        buffer_device=device,
        # heap_bytes=...,  # override if estimate_total_heap_bytes is too small
    ),
    device=device,
)
Rollout buffer with optional Zstd (when installed):
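from stable_baselines3 import PPO

from sb3_extra_buffers.gpu_buffers import (
    GpuRolloutBuffer,
    find_gpu_buffer_dtypes,
    has_zstd,
)

# env and device are defined as in the replay buffer example above.
obs_shape = env.observation_space.shape

# Fall back to RLE when the optional Zstd backend is not installed.
compression_method = "zstd" if has_zstd() else "rle"
dtypes = find_gpu_buffer_dtypes(obs_shape, compression_method=compression_method)

model = PPO(
    "CnnPolicy",
    env,
    rollout_buffer_class=GpuRolloutBuffer,
    rollout_buffer_kwargs=dict(
        dtypes=dtypes,
        compression_method=compression_method,
        buffer_device=device,
    ),
    device=device,
)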
Heap layout and compaction
Compressed observations share one RawBuffer
backed by SharedRawHeap. Each cell has a
start_idx and lengths entry; codec metadata (pos_runs, pos_elem, run_length)
is stored relative to that cell’s byte offset.
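To make the metadata fields concrete, here is a toy run-length encoder in plain Python. The reading of the fields is an assumption made for illustration (pos_elem and pos_runs as byte offsets of the element and run arrays within the cell, run_length as the number of runs); the library's actual layout may differ, and the helpers rle_encode, rle_decode, and toy_metadata are hypothetical.

```python
def rle_encode(flat):
    """Collapse consecutive equal values into (elements, run lengths)."""
    elems, runs = [], []
    for value in flat:
        if elems and elems[-1] == value:
            runs[-1] += 1            # extend the current run
        else:
            elems.append(value)      # start a new run
            runs.append(1)
    return elems, runs


def rle_decode(elems, runs):
    """Expand (elements, run lengths) back to the flat sequence."""
    out = []
    for value, count in zip(elems, runs):
        out.extend([value] * count)
    return out


def toy_metadata(byte_start, elems, runs, elem_size=1, run_size=2):
    """Assumed cell layout: element bytes first, run-length bytes after."""
    elem_bytes = len(elems) * elem_size
    run_bytes = len(runs) * run_size
    return {
        "byte_start": byte_start,    # absolute position of the cell in the heap
        "pos_elem": 0,               # relative to byte_start
        "pos_runs": elem_bytes,      # relative to byte_start
        "run_length": len(runs),
        "payload_bytes": elem_bytes + run_bytes,
    }


flat = [0, 0, 0, 0, 5, 5, 0, 0, 0]
elems, runs = rle_encode(flat)
meta = toy_metadata(byte_start=100, elems=elems, runs=runs)
```

The sketch does not split runs that exceed the run-length dtype's maximum, which a real encoder with a fixed-width runs_type would have to handle.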
Compaction packs live payloads and resets data_end to roughly sum(lengths):
- GpuReplayBuffer compacts when the write cursor wraps after the buffer is full.
- GpuRolloutBuffer compacts when the rollout buffer becomes full (before get()).
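Compaction itself can be sketched with a toy heap in plain Python. This is a simplified model, not the library's SharedRawHeap: it uses a bytearray, compacts on demand when the tail is too small instead of at the trigger points listed above, and raises MemoryError as an assumed stand-in for the library's heap capacity error. The class name ToyHeap is hypothetical.

```python
class ToyHeap:
    """Toy byte heap with per-cell start_idx / lengths and a data_end cursor."""

    def __init__(self, capacity: int, n_cells: int) -> None:
        self.data = bytearray(capacity)
        self.start_idx = [0] * n_cells
        self.lengths = [0] * n_cells
        self.data_end = 0  # first free byte after the last appended payload

    def compact(self) -> None:
        """Pack live payloads into a contiguous prefix of the heap."""
        cursor = 0
        live = [c for c in range(len(self.lengths)) if self.lengths[c] > 0]
        # Address order guarantees cursor <= src, so no live data is clobbered.
        for cell in sorted(live, key=lambda c: self.start_idx[c]):
            src, n = self.start_idx[cell], self.lengths[cell]
            self.data[cursor:cursor + n] = self.data[src:src + n]
            self.start_idx[cell] = cursor
            cursor += n
        self.data_end = cursor  # equals sum(self.lengths)

    def write(self, cell: int, payload: bytes) -> None:
        """Append payload for cell, compacting first if the tail is too small."""
        self.lengths[cell] = 0  # the cell's previous payload becomes dead space
        if self.data_end + len(payload) > len(self.data):
            self.compact()
        if self.data_end + len(payload) > len(self.data):
            raise MemoryError("heap capacity exceeded; a larger heap is needed")
        start = self.data_end
        self.data[start:start + len(payload)] = payload
        self.start_idx[cell] = start
        self.lengths[cell] = len(payload)
        self.data_end = start + len(payload)

    def read(self, cell: int) -> bytes:
        start = self.start_idx[cell]
        return bytes(self.data[start:start + self.lengths[cell]])


heap = ToyHeap(capacity=10, n_cells=3)
heap.write(0, b"aaaa")
heap.write(1, b"bbb")
heap.write(0, b"cc")     # overwrite: 4 dead bytes remain at the front
heap.write(2, b"dddd")   # tail too small, so compaction runs first
```

After the last write the three live payloads occupy 9 of the 10 bytes; without compaction the overwritten 4 bytes would stay unusable and the write would fail.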
Default heap capacity is n_cells * estimate_max_slot_bytes(...) where n_cells depends
on buffer geometry (replay vs rollout, n_envs, shared next-obs storage). Per-cell size uses
MsPacman Save Mem % ratios from the README benchmark, fitted per codec family and level
(see scripts/fit_heap_heuristic.py), scaled by flat_len * elem_size * overalloc_factor
(default 1.5). All documented levels are accepted; unmeasured levels use polynomial
extrapolation clamped to each family’s benchmark min/max ratio.
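The default sizing arithmetic can be written out as follows. This is a rough sketch under stated assumptions: ratio stands for an assumed compressed-size / raw-size fraction (the fitted per-codec constants are not reproduced here), the legacy analytic bound is omitted, and n_cells is given directly instead of being derived from buffer geometry. The helper names are hypothetical.

```python
import math


def toy_slot_bytes(flat_len: int, elem_size: int, ratio: float,
                   overalloc_factor: float = 1.5) -> int:
    """Per-cell bytes: raw size scaled by an assumed ratio and headroom."""
    return math.ceil(flat_len * elem_size * ratio * overalloc_factor)


def toy_total_heap_bytes(n_cells: int, flat_len: int, elem_size: int,
                         ratio: float, overalloc_factor: float = 1.5) -> int:
    """Total heap bytes: per-cell estimate times the number of cells."""
    return n_cells * toy_slot_bytes(flat_len, elem_size, ratio, overalloc_factor)


flat_len = 84 * 84 * 4          # four stacked 84x84 uint8 frames
raw_bytes = flat_len * 1        # 28224 bytes per observation, uncompressed
slot = toy_slot_bytes(flat_len, elem_size=1, ratio=0.10)   # ratio is made up
total = toy_total_heap_bytes(100_000, flat_len, elem_size=1, ratio=0.10)
```

With these illustrative numbers each cell reserves 4234 bytes instead of 28224, so the heap for 100,000 cells is about 423 MB rather than about 2.8 GB. The estimate only holds if real observations compress at least as well as the assumed ratio.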
Supported level ranges (invalid levels raise ValueError):
- gzip: 0–9
- igzip: 0–3
- zstd: 1–22 or -100 to -1 (bare zstd uses level -3)
- lz4-frame: negatives and 0–16
- lz4-block: negatives, 0, and 1–12
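The ranges above translate into a small validation routine. The sketch below mirrors the listed ranges only; the function name check_level and the table layout are hypothetical and are not taken from the library.

```python
# Upper bounds per family; a lower bound of None means "any negative allowed".
LEVEL_RANGES = {
    "gzip": (0, 9),
    "igzip": (0, 3),
    "lz4-frame": (None, 16),
    "lz4-block": (None, 12),
}


def check_level(family: str, level: int) -> int:
    """Return level if it is valid for family, else raise ValueError."""
    if family == "zstd":
        # zstd has two disjoint ranges; level 0 is not listed as valid.
        ok = 1 <= level <= 22 or -100 <= level <= -1
    elif family in LEVEL_RANGES:
        low, high = LEVEL_RANGES[family]
        ok = (low is None or level >= low) and level <= high
    else:
        raise ValueError(f"unknown codec family: {family!r}")
    if not ok:
        raise ValueError(f"level {level} is out of range for {family}")
    return level


default_zstd = check_level("zstd", -3)   # the level used by bare "zstd"
```

Checking the level before allocating the heap is useful because the heap size estimate depends on the codec level.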
Re-fit after updating the benchmark table:
uv pip install scipy
uv run scripts/fit_heap_heuristic.py
Copy the printed constants into sb3_extra_buffers.gpu_buffers.size_estimation.
After compaction, data_end reflects actual compressed usage; peak allocation remains heap_bytes.
If compression fails with a heap capacity error, increase heap_bytes. When in doubt,
log compressed sizes from your environment and add headroom.
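One way to turn logged sizes into a heap_bytes value is sketched below. It assumes you have already collected per-observation compressed sizes in bytes from your own environment (how to obtain them is not shown here), and it sizes every cell for the worst sample plus headroom, which is deliberately conservative. The helper name heap_bytes_from_samples is hypothetical.

```python
import math


def heap_bytes_from_samples(sample_sizes, n_cells: int,
                            headroom: float = 1.5) -> int:
    """Size the heap so every cell can hold the largest observed payload."""
    if not sample_sizes:
        raise ValueError("need at least one measured compressed size")
    worst = max(sample_sizes)
    return math.ceil(worst * headroom) * n_cells


# Illustrative numbers only: compressed sizes in bytes logged from a few steps.
logged = [1200, 950, 3100, 1800]
heap_bytes = heap_bytes_from_samples(logged, n_cells=10_000, headroom=1.25)
```

The result can then be passed as heap_bytes to GpuReplayBuffer or GpuRolloutBuffer. Sizing from the mean instead of the maximum saves memory but relies on large and small payloads averaging out across the heap.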