Source code for minari.dataset.minari_dataset

from __future__ import annotations

import os
import re
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Union

import gymnasium as gym
import h5py
import numpy as np
from gymnasium import error
from gymnasium.envs.registration import EnvSpec

from minari.data_collector import DataCollectorV0
from minari.dataset.minari_storage import MinariStorage, PathLike


DATASET_ID_RE = re.compile(
    r"(?:(?P<environment>[\w]+?))?(?:-(?P<dataset>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)


def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int | None]:
    """Parse dataset ID string format - ``(env_name-)(dataset_name)(-v(version))``.

    Args:
        dataset_id: The dataset id to parse
    Returns:
        A tuple of environment name, dataset name and version number
    Raises:
        Error: If the dataset id does not match the dataset id regex
    """
    match = DATASET_ID_RE.fullmatch(dataset_id)
    if not match:
        raise error.Error(
            f"Malformed dataset ID: {dataset_id}. (Currently all IDs must be of the form (env_name-)(dataset_name)-v(version). (namespace is optional))"
        )
    env_name, dataset_name, version = match.group("environment", "dataset", "version")
    if version is not None:
        version = int(version)

    return env_name, dataset_name, version

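# Illustrative examples (not part of the module); "door-human-v0" and
# "cartpole-test" are hypothetical dataset ids:
#
#   >>> parse_dataset_id("door-human-v0")
#   ('door', 'human', 0)
#   >>> parse_dataset_id("cartpole-test")
#   ('cartpole', 'test', None)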

def clear_episode_buffer(episode_buffer: Dict, eps_group: h5py.Group) -> h5py.Group:
    """Save an episode dictionary buffer into an HDF5 episode group recursively.

    Args:
        episode_buffer (dict): episode buffer
        eps_group (h5py.Group): HDF5 group to store the episode datasets

    Returns:
        episode group: filled HDF5 episode group
    """
    for key, data in episode_buffer.items():
        if isinstance(data, dict):
            if key in eps_group:
                eps_group_to_clear = eps_group[key]
            else:
                eps_group_to_clear = eps_group.create_group(key)
            clear_episode_buffer(data, eps_group_to_clear)
        else:
            # check the leaf array does not contain NaN values
            assert np.all(np.logical_not(np.isnan(data)))
            # store the array as a chunked HDF5 dataset
            eps_group.create_dataset(key, data=data, chunks=True)

    return eps_group

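# Illustrative sketch (not part of the module; the file name is arbitrary):
# a nested episode buffer is mirrored as nested HDF5 groups, with each leaf
# array stored as a chunked dataset.
#
#   buffer = {
#       "observations": np.zeros((11, 4)),
#       "actions": np.zeros((10, 2)),
#       "infos": {"success": np.zeros((10, 1))},
#   }
#   with h5py.File("episode_demo.hdf5", "w") as f:
#       eps_group = clear_episode_buffer(buffer, f.create_group("episode_0"))
#       # eps_group["infos"]["success"] is now an h5py.Dataset of shape (10, 1)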

class EpisodeData(NamedTuple):
    """Contains the datasets data for a single episode.

    This is the object returned by :class:`minari.MinariDataset.sample_episodes`.
    """

    id: int
    seed: Optional[int]
    total_timesteps: int
    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    terminations: np.ndarray
    truncations: np.ndarray


@dataclass
class MinariDatasetSpec:
    flatten_observations: bool
    flatten_actions: bool
    env_spec: EnvSpec
    total_episodes: int
    total_steps: int
    dataset_id: str
    combined_datasets: List[str]
    observation_space: gym.Space
    action_space: gym.Space
    data_path: str

    # post-init attributes
    env_name: str | None = field(init=False)
    dataset_name: str = field(init=False)
    version: int | None = field(init=False)

    def __post_init__(self):
        """Called after the spec is created to extract the environment name, dataset name and version from the dataset id."""
        self.env_name, self.dataset_name, self.version = parse_dataset_id(
            self.dataset_id
        )


class MinariDataset:
    """Main Minari dataset class to sample data and get metadata information from a dataset."""

    def __init__(
        self,
        data: Union[MinariStorage, PathLike],
        episode_indices: Optional[np.ndarray] = None,
    ):
        """Initialize properties of the Minari Dataset.

        Args:
            data (Union[MinariStorage, PathLike]): source of data.
            episode_indices (Optional[np.ndarray]): slice of episode indices this dataset is pointing to.
        """
        if isinstance(data, MinariStorage):
            self._data = data
        elif (
            isinstance(data, str)
            or isinstance(data, os.PathLike)
            or isinstance(data, bytes)
        ):
            self._data = MinariStorage(data)
        else:
            raise ValueError(f"Unrecognized type {type(data)} for data")

        self._additional_data_id = 0
        if episode_indices is None:
            episode_indices = np.arange(self._data.total_episodes)
        self._episode_indices = episode_indices

        self.spec = MinariDatasetSpec(
            flatten_observations=self._data.flatten_observations,
            flatten_actions=self._data.flatten_actions,
            env_spec=self._data.env_spec,
            total_episodes=self._data.total_episodes,
            total_steps=self._data.total_steps,
            dataset_id=self._data.id,
            combined_datasets=self._data.combined_datasets,
            observation_space=self._data.observation_space,
            action_space=self._data.action_space,
            data_path=str(self._data.data_path),
        )
        self._total_steps = None
        self._generator = np.random.default_rng()

    @property
    def total_episodes(self):
        """Total episodes recorded in the Minari dataset."""
        assert self._episode_indices is not None
        return len(self._episode_indices)

    @property
    def total_steps(self):
        """Total steps recorded in the Minari dataset."""
        if self._total_steps is None:
            t_steps = self._data.apply(
                lambda episode: episode["total_steps"],
                episode_indices=self._episode_indices,
            )
            self._total_steps = sum(t_steps)
        return self._total_steps

    @property
    def episode_indices(self):
        """Indices of the available episodes to sample within the Minari dataset."""
        return self._episode_indices

    def recover_environment(self):
        """Recover the Gymnasium environment used to create the dataset.

        Returns:
            environment: Gymnasium environment
        """
        return gym.make(self._data.env_spec)

    def set_seed(self, seed: int):
        """Set seed for random episode sampling generator."""
        self._generator = np.random.default_rng(seed)

    def filter_episodes(
        self, condition: Callable[[h5py.Group], bool]
    ) -> MinariDataset:
        """Filter the dataset episodes with a condition.

        The condition must be a callable with a single argument, the episode HDF5 group.
        The callable must return a `bool`: True if the condition is met and False otherwise.
        e.g. filtering for episodes that terminate:

        ```
        dataset.filter_episodes(condition=lambda x: x['terminations'][-1])
        ```

        Args:
            condition (Callable[[h5py.Group], bool]): callable that accepts an episode group and returns True if a certain condition is met.
        """
        mask = self._data.apply(condition, episode_indices=self._episode_indices)
        assert self._episode_indices is not None
        return MinariDataset(self._data, episode_indices=self._episode_indices[mask])

    def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]:
        """Sample n episodes from the dataset.

        Args:
            n_episodes (int): number of episodes to sample.
        """
        indices = self._generator.choice(
            self._episode_indices, size=n_episodes, replace=False
        )
        episodes = self._data.get_episodes(indices)
        return list(map(lambda data: EpisodeData(**data), episodes))

    def iterate_episodes(
        self, episode_indices: Optional[List[int]] = None
    ) -> Iterator[EpisodeData]:
        """Iterate over episodes from the dataset.

        Args:
            episode_indices (Optional[List[int]], optional): episode indices to iterate over.
        """
        if episode_indices is None:
            assert self._episode_indices is not None
            assert self._episode_indices.ndim == 1
            episode_indices = self._episode_indices.tolist()
        assert episode_indices is not None

        for episode_index in episode_indices:
            data = self._data.get_episodes([episode_index])[0]
            yield EpisodeData(**data)

    def update_dataset_from_collector_env(self, collector_env: DataCollectorV0):
        """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0).

        This method can be used as a checkpoint when creating a dataset. A new HDF5 file
        called `additional_data_i.hdf5` will be created in the same directory as
        `main_data.hdf5`. Both files are joined together by creating external links to each
        additional episode group: https://docs.h5py.org/en/stable/high/group.html#external-links

        Args:
            collector_env (DataCollectorV0): Collector environment
        """
        # check that collector env has the same characteristics as self._env_spec
        new_data_file_path = os.path.join(
            os.path.split(self.spec.data_path)[0],
            f"additional_data_{self._additional_data_id}.hdf5",
        )

        collector_env.save_to_disk(path=new_data_file_path)

        with h5py.File(new_data_file_path, "r", track_order=True) as new_data_file:
            new_data_total_episodes = new_data_file.attrs["total_episodes"]
            new_data_total_steps = new_data_file.attrs["total_steps"]

        with h5py.File(self._data.data_path, "a", track_order=True) as file:
            last_episode_id = file.attrs["total_episodes"]
            for id in range(new_data_total_episodes):
                file[f"episode_{last_episode_id + id}"] = h5py.ExternalLink(
                    f"additional_data_{self._additional_data_id}.hdf5",
                    f"/episode_{id}",
                )
                file[f"episode_{last_episode_id + id}"].attrs.modify(
                    "id", last_episode_id + id
                )

            # Update metadata of minari dataset
            file.attrs.modify(
                "total_episodes", last_episode_id + new_data_total_episodes
            )
            file.attrs.modify(
                "total_steps", file.attrs["total_steps"] + new_data_total_steps
            )
        self._additional_data_id += 1

    def update_dataset_from_buffer(self, buffer: List[dict]):
        """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers.

        Each episode dictionary buffer must have the following items:
            * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include the initial and final observations.
            * `actions`: np.ndarray of step actions. shape = (total_episode_steps, (action_shape)).
            * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps, 1).
            * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps, 1).
            * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps, 1).

        Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries.

        Args:
            buffer (list[dict]): list of episode dictionary buffers to add to dataset
        """
        additional_steps = 0
        with h5py.File(self.spec.data_path, "a", track_order=True) as file:
            last_episode_id = file.attrs["total_episodes"]
            for i, eps_buff in enumerate(buffer):
                episode_id = last_episode_id + i
                # check episode terminated or truncated
                assert (
                    eps_buff["terminations"][-1] or eps_buff["truncations"][-1]
                ), "Each episode must be terminated or truncated before adding it to a Minari dataset"
                assert len(eps_buff["actions"]) + 1 == len(
                    eps_buff["observations"]
                ), (
                    f"Number of observations {len(eps_buff['observations'])} must have an additional "
                    f"element compared to the number of action steps {len(eps_buff['actions'])}. "
                    f"The initial and final observation must be included"
                )
                seed = eps_buff.pop("seed", None)
                eps_group = clear_episode_buffer(
                    eps_buff, file.create_group(f"episode_{episode_id}")
                )

                eps_group.attrs["id"] = episode_id
                total_steps = len(eps_buff["actions"])
                eps_group.attrs["total_steps"] = total_steps
                additional_steps += total_steps

                if seed is None:
                    eps_group.attrs["seed"] = str(None)
                else:
                    assert isinstance(seed, int)
                    eps_group.attrs["seed"] = seed

                # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata

            file.attrs.modify("total_episodes", last_episode_id + len(buffer))
            file.attrs.modify(
                "total_steps", file.attrs["total_steps"] + additional_steps
            )

    def __iter__(self):
        return self.iterate_episodes()
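

# Illustrative usage sketch (not part of the module): "door-human-v0" is a
# hypothetical dataset id, and minari.load_dataset is assumed to return a
# MinariDataset backed by a local main_data.hdf5 file.
#
#   import minari
#
#   dataset = minari.load_dataset("door-human-v0")
#   dataset.set_seed(42)                       # reproducible sampling
#   episodes = dataset.sample_episodes(10)     # list of EpisodeData
#   terminated = dataset.filter_episodes(
#       condition=lambda ep: ep["terminations"][-1]
#   )
#   for episode in terminated:                 # __iter__ -> iterate_episodes()
#       print(episode.id, episode.total_timesteps)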