GPU Buffers

Warning

Advanced / experimental. sb3_extra_buffers.gpu_buffers targets users who want observation tensors to live on a Torch device (CPU, CUDA, or MPS) with optional compression. The API is less stable than sb3_extra_buffers.compressed, behaviour varies by device and codec, and not every CPU compression backend has a GPU-native implementation yet.

Caution

You are responsible for heap sizing. Compressed observations are stored in a single packed byte heap (RawBuffer) with per-cell start_idx / lengths. If the heap runs out of space, compression raises an error. By default, buffers set heap capacity via estimate_total_heap_bytes() (per-cell heuristic × number of cells). For large or high-entropy observations, set heap_bytes explicitly on GpuReplayBuffer or GpuRolloutBuffer and validate with your observation distribution. The deprecated max_slot_bytes argument is interpreted as per-cell capacity and multiplied by the cell count.

Overview

GPU buffers mirror the CPU compressed replay/rollout API but keep flattened observations on buffer_device end-to-end (add → store → sample/get). Non-observation fields (actions, rewards, dones, etc.) remain NumPy arrays like the CPU buffers.

Public entry points

Compression methods

Example scripts (Pong, PongNoFrameskip-v4):

  • examples/example_train_gpu_replay.py / examples/example_watch_gpu_replay.py

  • examples/example_train_gpu_rollout.py / examples/example_watch_gpu_rollout.py

GPU-backed raw storage, compression, and SB3 buffers (experimental).

class GpuReplayBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'none', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, output_dtype: Literal['raw', 'float'] = 'raw', buffer_device: device | str | None = None, heap_bytes: int | None = None, max_slot_bytes: int | None = None)

Bases: ReplayBuffer, BaseGpuBuffer

Replay buffer with observations stored on a Torch device.

Create a replay buffer with device-resident observations.

actions: ndarray
rewards: ndarray
dones: ndarray
timeouts: ndarray

add(obs: ndarray, next_obs: ndarray, action: ndarray, reward: ndarray, done: ndarray, infos: list[dict[str, Any]]) → None

Add a transition with device-resident observations.

reconstruct_obs(idx: int, env_idx: int) → Tensor

Return the flattened observation at (idx, env_idx).

reconstruct_nextobs(idx: int, env_idx: int) → Tensor

Return the flattened next observation at (idx, env_idx).

class GpuRolloutBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'none', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, buffer_device: device | str | None = None, heap_bytes: int | None = None, max_slot_bytes: int | None = None)

Bases: RolloutBuffer, BaseGpuBuffer

Rollout buffer with observations stored on a Torch device.

Create a rollout buffer with device-resident observations.

actions: ndarray
rewards: ndarray
advantages: ndarray
returns: ndarray
episode_starts: ndarray
log_probs: ndarray
values: ndarray

reset() → None

Clear rollout storage and reset the write position.

add(obs: ndarray, action: ndarray, reward: ndarray, episode_start: ndarray, value: Tensor, log_prob: Tensor) → None

Add a rollout step with a device-resident observation.

get(batch_size: int | None = None) → Generator[RolloutBufferSamples, None, None]

Yield shuffled rollout minibatches after the buffer is full.

reconstruct_obs(idx: int) → Tensor

Return the flattened observation at flattened index idx.
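The shuffled-minibatch pattern that get() describes can be sketched in plain Python. This is a minimal illustration of the general technique (shuffle the flattened indices once, then yield consecutive slices), not the library's implementation; iter_minibatch_indices is a hypothetical helper, and each yielded index is the kind of flattened index that reconstruct_obs(idx) accepts.

```python
import random
from typing import Iterator, List, Optional


def iter_minibatch_indices(
    n_samples: int,
    batch_size: Optional[int] = None,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield shuffled index batches that cover range(n_samples) exactly once."""
    if n_samples <= 0:
        raise ValueError("n_samples must be positive")
    if batch_size is None:
        batch_size = n_samples  # a single batch holding the whole rollout
    order = list(range(n_samples))
    random.Random(seed).shuffle(order)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]


# 4 steps x 2 envs = 8 flattened samples, split into batches of 3.
batches = list(iter_minibatch_indices(8, batch_size=3, seed=0))
```

The last batch is shorter when batch_size does not divide the number of samples.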

class RawBuffer(size: int, device: str | device = 'cpu')

Bases: object

Linear byte storage backed by torch.UntypedStorage.

Allocate size bytes on device.

Parameters:
  • size – Total storage size in bytes.

  • device – Torch device for the underlying storage.

write_bytes(malloc: tuple[int, int], tensor: Tensor)

Copy tensor bytes into the region described by malloc.

read_bytes(malloc: tuple[int, int], dtype: dtype)

View malloc bytes as a 1D tensor with dtype.

read_into(malloc: tuple[int, int], tensor: Tensor)

Copy bytes from malloc into the start of tensor.

copy_region(src_start: int, length: int, dst_start: int) → None

Copy length bytes from src_start to dst_start.

class SharedRawHeap(n_cells: int, heap_bytes: int, device: str | device = 'cpu')

Bases: object

Shared byte heap with per-cell start_idx and lengths.

Allocate heap storage and per-cell index arrays.

compact() → None

Pack all cell payloads into a contiguous prefix of the heap.

ensure_space(needed: int) → None

Compact until at least needed bytes are free at data_end.

class SlotMetadata(byte_start: int, pos_runs: int, pos_elem: int, run_length: int, payload_bytes: int)

Bases: object

Compression metadata for one observation cell in the raw heap.

byte_start: int
pos_runs: int
pos_elem: int
run_length: int
payload_bytes: int

class BaseGpuBuffer(compression_method: str | None = None, compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, flatten_config: dict | None = None)

Bases: object

Base GPU buffer class wiring compression callables.

Configure compression and decompression callables.

find_gpu_buffer_dtypes(obs_shape: int | tuple, elem_dtype: dtype = torch.uint8, compression_method: str = 'rle') → dict[str, Any]

Find Torch dtypes for GPU buffer compression.

has_zstd() → bool

Return whether the Zstandard backend is available.

Metadata

Shared types for GPU buffer compression.

arr_config_length(arr_configs: dict) → int

Return flattened observation length from size or shape keys.
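As a rough illustration of deriving a flattened length from size or shape keys, the sketch below uses a hypothetical stand-in named flat_length. The key precedence (size before shape) and the acceptance of a bare integer shape are assumptions made for the example, not documented behaviour of arr_config_length.

```python
from math import prod


def flat_length(arr_config: dict) -> int:
    """Hypothetical stand-in: flattened length from 'size' or 'shape'."""
    if "size" in arr_config:  # assumption: an explicit size takes precedence
        return int(arr_config["size"])
    if "shape" in arr_config:
        shape = arr_config["shape"]
        # Treat a bare int as a 1-D shape; otherwise multiply the dimensions.
        return int(shape) if isinstance(shape, int) else int(prod(shape))
    raise KeyError("expected a 'size' or 'shape' key")


# Four stacked 84x84 frames flatten to 4 * 84 * 84 elements.
stacked_frames = flat_length({"shape": (4, 84, 84)})
```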

Raw storage

Untyped GPU/CPU byte storage for packed observation heaps.

write_at(buffer: RawBuffer, byte_start: int, rel_byte_offset: int, tensor: Tensor) → None

Write tensor bytes at byte_start + rel_byte_offset.

read_at(buffer: RawBuffer, byte_start: int, rel_byte_offset: int, elem_count: int, dtype: dtype) → Tensor

Read elem_count elements at byte_start + rel_byte_offset.
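The byte_start + rel_byte_offset addressing used by write_at and read_at can be modelled with a bytearray in place of a RawBuffer. This is a stdlib-only sketch under that assumption: the functions below reuse the names for readability but take a bytearray and an array typecode rather than a Torch tensor and dtype.

```python
from array import array


def write_at(heap: bytearray, byte_start: int, rel_byte_offset: int, payload: bytes) -> None:
    """Copy payload into heap at byte_start + rel_byte_offset."""
    begin = byte_start + rel_byte_offset
    end = begin + len(payload)
    if begin < 0 or end > len(heap):
        raise ValueError("write outside heap")
    heap[begin:end] = payload


def read_at(heap: bytearray, byte_start: int, rel_byte_offset: int,
            elem_count: int, typecode: str) -> array:
    """Read elem_count elements of the given typecode from the same address."""
    out = array(typecode)
    begin = byte_start + rel_byte_offset
    end = begin + elem_count * out.itemsize
    if begin < 0 or end > len(heap):
        raise ValueError("read outside heap")
    out.frombytes(bytes(heap[begin:end]))
    return out


heap = bytearray(32)
runs = array("H", [3, 1, 4])   # 16-bit run lengths
elems = array("B", [7, 0, 9])  # 8-bit element values
cell_start = 8                 # where this cell begins in the heap
write_at(heap, cell_start, 0, runs.tobytes())
write_at(heap, cell_start, len(runs.tobytes()), elems.tobytes())
```

Keeping offsets relative to the cell start means a cell can be moved within the heap by updating only its start position.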

Observation stores

Device-resident observation storage backends.

class DenseObservationStore(buffer_size: int, n_envs: int, flat_len: int, elem_type: dtype, device: str | device)

Bases: object

Store flattened observations in a dense Torch tensor.

Allocate dense observation storage on device.

write(pos: int, env_idx: int, flat_tensor: Tensor)

Write a flattened observation.

read(pos: int, env_idx: int) → Tensor

Read a flattened observation.

read_flat(flat_idx: int) → Tensor

Read by flattened index after flatten().

flatten()

Flatten using the same ordering as SB3 swap_and_flatten.

class RawObservationStore(buffer_size: int, n_envs: int, flat_len: int, elem_type: dtype, device: str | device, field_offset: int, compress: Callable[[...], SlotMetadata], decompress: Callable[[...], Tensor], compression_method: str = 'rle', runs_type: dtype = torch.uint16, shared_heap: SharedRawHeap | None = None, n_fields: int = 1)

Bases: object

Store compressed observations in a packed raw byte heap.

Create heap-backed observation storage.

cell_id(pos: int, env_idx: int) → int

Map buffer coordinates to a global cell id.

write(pos: int, env_idx: int, flat_tensor: Tensor)

Compress and store a flattened observation.

read(pos: int, env_idx: int) → Tensor

Decompress a flattened observation.

read_flat(flat_idx: int) → Tensor

Read by flattened index after flatten().

flatten()

Flatten metadata to match rollout swap_and_flatten ordering.

compact() → None

Pack the shared heap (no-op when this store does not own the heap).

create_observation_store(compression_method: str, buffer_size: int, n_envs: int, flat_len: int, elem_type: dtype, device: str | device, field_offset: int = 0, compress: Callable[[...], SlotMetadata] | None = None, decompress: Callable[[...], Tensor] | None = None, runs_type: dtype = torch.uint16, shared_heap: SharedRawHeap | None = None, n_fields: int = 1)

Create the observation store backend for compression_method.
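Both stores describe flatten() as matching SB3's swap_and_flatten, which swaps the step and environment axes before merging them, so the flattened order is environment-major. The sketch below models that ordering with nested lists; flat_index is a hypothetical helper showing which flat_idx a given (pos, env_idx) would map to under that ordering.

```python
def flat_index(pos: int, env_idx: int, buffer_size: int) -> int:
    """Row of (pos, env_idx) after swapping axes 0 and 1 and merging them."""
    return env_idx * buffer_size + pos


def swap_and_flatten(grid):
    """Flatten a [buffer_size][n_envs] nested list in environment-major order."""
    buffer_size, n_envs = len(grid), len(grid[0])
    return [grid[pos][env] for env in range(n_envs) for pos in range(buffer_size)]


buffer_size, n_envs = 3, 2
# Each cell records its own coordinates so the mapping is easy to check.
grid = [[(pos, env) for env in range(n_envs)] for pos in range(buffer_size)]
flat = swap_and_flatten(grid)
```

All steps of environment 0 come first, followed by all steps of environment 1.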

Base helpers

Base classes and helpers for GPU observation storage.

find_gpu_buffer_dtypes(obs_shape: int | tuple, elem_dtype: dtype = torch.uint8, compression_method: str = 'rle') → dict[str, Any]

Find Torch dtypes for GPU buffer compression.

class BaseGpuBuffer(compression_method: str | None = None, compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, flatten_config: dict | None = None)

Bases: object

Base GPU buffer class wiring compression callables.

Configure compression and decompression callables.

Replay buffer

Replay buffers with device-resident observations.

class GpuReplayBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'none', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, output_dtype: Literal['raw', 'float'] = 'raw', buffer_device: device | str | None = None, heap_bytes: int | None = None, max_slot_bytes: int | None = None)

Bases: ReplayBuffer, BaseGpuBuffer

Replay buffer with observations stored on a Torch device.

Create a replay buffer with device-resident observations.

add(obs: ndarray, next_obs: ndarray, action: ndarray, reward: ndarray, done: ndarray, infos: list[dict[str, Any]]) → None

Add a transition with device-resident observations.

reconstruct_obs(idx: int, env_idx: int) → Tensor

Return the flattened observation at (idx, env_idx).

reconstruct_nextobs(idx: int, env_idx: int) → Tensor

Return the flattened next observation at (idx, env_idx).

Rollout buffer

Rollout buffers with device-resident observations.

class GpuRolloutBuffer(buffer_size: int, observation_space: Space, action_space: Space, device: device | str = 'auto', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, dtypes: dict | None = None, normalize_images: bool = False, compression_method: str = 'none', compression_kwargs: dict | None = None, decompression_kwargs: dict | None = None, buffer_device: device | str | None = None, heap_bytes: int | None = None, max_slot_bytes: int | None = None)

Bases: RolloutBuffer, BaseGpuBuffer

Rollout buffer with observations stored on a Torch device.

Create a rollout buffer with device-resident observations.

reset() → None

Clear rollout storage and reset the write position.

add(obs: ndarray, action: ndarray, reward: ndarray, episode_start: ndarray, value: Tensor, log_prob: Tensor) → None

Add a rollout step with a device-resident observation.

get(batch_size: int | None = None) → Generator[RolloutBufferSamples, None, None]

Yield shuffled rollout minibatches after the buffer is full.

reconstruct_obs(idx: int) → Tensor

Return the flattened observation at flattened index idx.

Utilities

Torch dtype helpers for GPU buffer compression.

find_smallest_dtype(max_val: int, signed: bool = False, fallback: dtype = torch.float32) → dtype

Find the smallest dtype that can hold max_val, for use as runs_type.
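The idea of picking the narrowest integer type for a maximum value can be sketched without Torch. The helper below is hypothetical and returns dtype names as strings; the candidate widths and the rule that fallback is returned when nothing fits are assumptions for illustration, not the documented behaviour of find_smallest_dtype.

```python
_UNSIGNED = [("uint8", 8), ("uint16", 16), ("uint32", 32), ("uint64", 64)]
_SIGNED = [("int8", 8), ("int16", 16), ("int32", 32), ("int64", 64)]


def smallest_int_dtype_name(max_val: int, signed: bool = False,
                            fallback: str = "float32") -> str:
    """Return the name of the narrowest integer type that can hold max_val."""
    for name, bits in (_SIGNED if signed else _UNSIGNED):
        # Signed types lose one bit of positive range to the sign.
        limit = (1 << (bits - 1)) - 1 if signed else (1 << bits) - 1
        if 0 <= max_val <= limit:
            return name
    return fallback  # assumption: used when no integer type is wide enough


# A run can be as long as the whole flattened observation,
# so the run-length type has to hold flat_len.
runs_name = smallest_int_dtype_name(4 * 84 * 84)
```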

torch_dtype_element_size(dtype: dtype) → int

Return the element size in bytes for a Torch dtype.

estimate_max_slot_bytes(flat_len: int, elem_type: dtype, runs_type: dtype, compression_method: str, overalloc_factor: float = 1.5, compresslevel: int | None = None) → int

Estimate per-cell byte capacity for packed heap storage.

Uses the larger of the MsPacman benchmark ratio and the legacy analytic bound so incompressible or noisy frames do not exceed the reserved span.

Parameters:
  • flat_len – Flattened observation length.

  • elem_type – Observation element dtype on the buffer device.

  • runs_type – Run-length dtype used by the legacy analytic bound.

  • compression_method – Codec name or shorthand (e.g. zstd3, zstd-5).

  • overalloc_factor – Headroom multiplier applied to the benchmark ratio.

  • compresslevel – If set, overrides any level parsed from compression_method.

estimate_total_heap_bytes(n_cells: int, flat_len: int, elem_type: dtype, runs_type: dtype, compression_method: str, overalloc_factor: float = 1.5, compresslevel: int | None = None) → int

Estimate total heap capacity for n_cells packed observations.

numpy_dtype_to_torch(dtype: dtype | type | str) → dtype

Convert a NumPy dtype to the matching Torch dtype.

torch_dtype_to_numpy(dtype: dtype) → dtype

Convert a Torch dtype to the matching NumPy dtype.

Compression methods

GPU compression method registry.

GpuCompressionMethods

alias of GpuCompressionMethod

Usage

Replay buffer with RLE on CUDA:

import torch as th
from stable_baselines3 import DQN
from sb3_extra_buffers.gpu_buffers import GpuReplayBuffer, find_gpu_buffer_dtypes
from sb3_extra_buffers.gpu_buffers.utils import numpy_dtype_to_torch

# `env` is an existing image-observation environment (creation not shown).
device = "cuda" if th.cuda.is_available() else "cpu"
dtypes = find_gpu_buffer_dtypes(
    env.observation_space.shape,
    elem_dtype=numpy_dtype_to_torch(env.observation_space.dtype),
    compression_method="rle",
)
model = DQN(
    "CnnPolicy",
    env,
    replay_buffer_class=GpuReplayBuffer,
    replay_buffer_kwargs=dict(
        dtypes=dtypes,
        compression_method="rle",
        buffer_device=device,
        # heap_bytes=...,  # override if estimate_total_heap_bytes is too small
    ),
    device=device,
)

Rollout buffer with optional Zstd (when installed):

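from stable_baselines3 import PPO
from sb3_extra_buffers.gpu_buffers import GpuRolloutBuffer, find_gpu_buffer_dtypes, has_zstd

# `env` and `device` are defined as in the replay example above.
compression_method = "zstd" if has_zstd() else "rle"
dtypes = find_gpu_buffer_dtypes(env.observation_space.shape, compression_method=compression_method)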
model = PPO(
    "CnnPolicy",
    env,
    rollout_buffer_class=GpuRolloutBuffer,
    rollout_buffer_kwargs=dict(
        dtypes=dtypes,
        compression_method=compression_method,
        buffer_device=device,
    ),
    device=device,
)

Heap layout and compaction

Compressed observations share one RawBuffer backed by SharedRawHeap. Each cell has a start_idx and lengths entry; codec metadata (pos_runs, pos_elem, run_length) is stored relative to that cell’s byte offset.
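For readers unfamiliar with run-length encoding, a minimal plain-Python sketch follows. It shows only the general technique: a flattened observation becomes an array of element values and an array of run lengths. A store that packs both arrays into one cell has to record where each begins, which is presumably the role of offsets such as pos_runs and pos_elem; that mapping is an assumption here, and the code does not reproduce the library's byte layout.

```python
from itertools import groupby
from typing import List, Sequence, Tuple


def rle_encode(flat: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Split a flat sequence into (element values, run lengths)."""
    elems: List[int] = []
    runs: List[int] = []
    for value, group in groupby(flat):
        elems.append(value)
        runs.append(sum(1 for _ in group))
    return elems, runs


def rle_decode(elems: Sequence[int], runs: Sequence[int]) -> List[int]:
    """Rebuild the flat sequence by repeating each value by its run length."""
    out: List[int] = []
    for value, count in zip(elems, runs):
        out.extend([value] * count)
    return out


frame = [0, 0, 0, 0, 7, 7, 0, 0, 0, 255]
elems, runs = rle_encode(frame)
```

Frames with long constant regions shrink well; noisy frames can end up larger than the raw data, which is why heap capacity needs headroom.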

Compaction packs live payloads and resets data_end to roughly sum(lengths):

  • GpuReplayBuffer compacts when the write cursor wraps after the buffer is full.

  • GpuRolloutBuffer compacts when the rollout buffer becomes full (before get()).
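The layout and compaction described above can be modelled with a toy heap. This is a plain-Python sketch, not the library's SharedRawHeap: ToyHeap and its write/read methods are hypothetical, payloads are appended at data_end, an overwritten cell leaves dead bytes behind, and compact() slides the live payloads into a contiguous prefix.

```python
from typing import List


class ToyHeap:
    """Plain-Python model of a packed byte heap with per-cell index arrays."""

    def __init__(self, n_cells: int, heap_bytes: int) -> None:
        self.data = bytearray(heap_bytes)
        self.start_idx: List[int] = [0] * n_cells
        self.lengths: List[int] = [0] * n_cells
        self.data_end = 0  # next free byte at the tail of the heap

    def compact(self) -> None:
        """Slide live payloads down so they form a contiguous prefix."""
        live = [c for c in range(len(self.lengths)) if self.lengths[c]]
        live.sort(key=lambda c: self.start_idx[c])
        cursor = 0
        for cell in live:
            start, length = self.start_idx[cell], self.lengths[cell]
            # The right-hand slice is copied first, so overlap is safe.
            self.data[cursor:cursor + length] = self.data[start:start + length]
            self.start_idx[cell] = cursor
            cursor += length
        self.data_end = cursor  # equals sum(lengths) after packing

    def ensure_space(self, needed: int) -> None:
        if len(self.data) - self.data_end < needed:
            self.compact()
        if len(self.data) - self.data_end < needed:
            raise MemoryError("heap too small; increase heap_bytes")

    def write(self, cell: int, payload: bytes) -> None:
        self.lengths[cell] = 0  # the cell's previous payload is now dead
        self.ensure_space(len(payload))
        self.start_idx[cell] = self.data_end
        self.lengths[cell] = len(payload)
        self.data[self.data_end:self.data_end + len(payload)] = payload
        self.data_end += len(payload)

    def read(self, cell: int) -> bytes:
        start = self.start_idx[cell]
        return bytes(self.data[start:start + self.lengths[cell]])


heap = ToyHeap(n_cells=3, heap_bytes=16)
heap.write(0, b"aaaa")
heap.write(1, b"bbbbbb")
heap.write(2, b"cccc")
heap.write(0, b"AAAAA")  # only 2 bytes free, so this triggers compaction
```

In this toy model compaction happens on demand inside write; the buffers above instead compact at fixed points (on wrap, or when the rollout is full).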

Default heap capacity is n_cells * estimate_max_slot_bytes(...) where n_cells depends on buffer geometry (replay vs rollout, n_envs, shared next-obs storage). Per-cell size uses MsPacman Save Mem % ratios from the README benchmark, fitted per codec family and level (see scripts/fit_heap_heuristic.py), scaled by flat_len * elem_size * overalloc_factor (default 1.5). All documented levels are accepted; unmeasured levels use polynomial extrapolation clamped to each family’s benchmark min/max ratio.
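A worked instance of the sizing arithmetic may help. The sketch below covers only the benchmark-ratio term (flat_len * elem_size * ratio * overalloc_factor, times n_cells); it omits the legacy analytic bound that estimate_max_slot_bytes also considers. The compressed_fraction of 0.125 and the buffer geometry are made-up inputs for illustration, not benchmark values, and both function names are hypothetical.

```python
import math


def per_cell_bytes(flat_len: int, elem_size: int, compressed_fraction: float,
                   overalloc_factor: float = 1.5) -> int:
    """Benchmark-ratio term only: expected compressed size plus headroom."""
    return math.ceil(flat_len * elem_size * compressed_fraction * overalloc_factor)


def default_heap_bytes(n_cells: int, flat_len: int, elem_size: int,
                       compressed_fraction: float,
                       overalloc_factor: float = 1.5) -> int:
    """Total capacity: per-cell estimate times the number of cells."""
    return n_cells * per_cell_bytes(flat_len, elem_size,
                                    compressed_fraction, overalloc_factor)


flat_len = 4 * 84 * 84  # four stacked 84x84 uint8 frames
n_cells = 100_000       # assumed geometry: buffer_size * n_envs
cell = per_cell_bytes(flat_len, elem_size=1, compressed_fraction=0.125)
total = default_heap_bytes(n_cells, flat_len, elem_size=1, compressed_fraction=0.125)
```

With these inputs each cell reserves 5,292 bytes and the heap about 529 MB, compared with roughly 2.8 GB for the same observations stored uncompressed.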

Supported level ranges (invalid levels raise ValueError):

  • gzip: 0–9

  • igzip: 0–3

  • zstd: 1–22 or -100 to -1 (bare zstd uses level -3)

  • lz4-frame: negatives and 0–16

  • lz4-block: negatives, 0, and 1–12
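The ranges above translate directly into a small validation routine. The sketch below is a hypothetical helper, not part of the package; it takes an already-separated family name and integer level (it does not parse shorthands) and raises ValueError for anything outside the listed ranges.

```python
def validate_level(family: str, level: int) -> int:
    """Return level if it is valid for family, else raise ValueError."""
    if family == "gzip":
        ok = 0 <= level <= 9
    elif family == "igzip":
        ok = 0 <= level <= 3
    elif family == "zstd":
        ok = 1 <= level <= 22 or -100 <= level <= -1  # 0 is not listed
    elif family == "lz4-frame":
        ok = level <= 16  # any negative level, or 0 through 16
    elif family == "lz4-block":
        ok = level <= 12  # any negative level, or 0 through 12
    else:
        raise ValueError(f"unknown codec family: {family!r}")
    if not ok:
        raise ValueError(f"level {level} is out of range for {family}")
    return level


default_zstd = validate_level("zstd", -3)  # the level bare zstd uses
```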

Re-fit after updating the benchmark table:

uv pip install scipy
uv run scripts/fit_heap_heuristic.py

Copy the printed constants into sb3_extra_buffers.gpu_buffers.size_estimation.

After compaction, data_end reflects actual compressed usage; peak allocation remains heap_bytes.

If compression fails with a heap capacity error, increase heap_bytes. When in doubt, log compressed sizes from your environment and add headroom.
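One way to follow that advice is to compress a sample of real observations offline and size the heap from the worst case. The sketch below is a rough, assumption-laden example: suggest_heap_bytes is a hypothetical helper, zlib from the standard library stands in for whichever codec the buffer actually uses, and the synthetic frames are placeholders for observations collected from your environment.

```python
import math
import zlib
from typing import Iterable


def suggest_heap_bytes(frames: Iterable[bytes], n_cells: int,
                       headroom: float = 1.5) -> int:
    """Size the heap from the largest compressed sample, plus headroom."""
    sizes = [len(zlib.compress(frame, 6)) for frame in frames]
    if not sizes:
        raise ValueError("need at least one sample frame")
    worst = max(sizes)  # conservative: assume every cell is as bad as this
    return math.ceil(worst * headroom) * n_cells


# Placeholder samples: one blank frame and one with more varied content.
samples = [bytes(1024), bytes(range(256)) * 4]
heap_bytes = suggest_heap_bytes(samples, n_cells=1000)
```

The result can then be passed as heap_bytes to GpuReplayBuffer or GpuRolloutBuffer; sizes measured with a different codec or level than the one configured will not transfer exactly.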