Source code for minari.dataset.minari_dataset

from __future__ import annotations

import os
import re
from dataclasses import dataclass, field
from typing import Callable, Iterable, Iterator, List, Optional, Union

import gymnasium as gym
import numpy as np
from gymnasium import error
from gymnasium.envs.registration import EnvSpec

from minari.data_collector import DataCollectorV0
from minari.dataset.minari_storage import MinariStorage, PathLike


DATASET_ID_RE = re.compile(
    r"(?:(?P<environment>[\w]+?))?(?:-(?P<dataset>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)


def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int]:
    """Parse dataset ID string format - ``(env_name-)(dataset_name)(-v(version))``.

    Args:
        dataset_id: The dataset id to parse
    Returns:
        A tuple of environment name, dataset name and version number
    Raises:
        Error: If the dataset id is not valid dataset regex
    """
    match = DATASET_ID_RE.fullmatch(dataset_id)
    if not match:
        raise error.Error(
            f"Malformed dataset ID: {dataset_id}. (Currently all IDs must be of the form (env_name-)(dataset_name)-v(version). (namespace is optional))"
        )
    env_name, dataset_name, version = match.group("environment", "dataset", "version")

    version = int(version)

    return env_name, dataset_name, version


[docs] @dataclass(frozen=True) class EpisodeData: """Contains the datasets data for a single episode. This is the object returned by :class:`minari.MinariDataset.sample_episodes`. """ id: int seed: Optional[int] total_timesteps: int observations: np.ndarray actions: np.ndarray rewards: np.ndarray terminations: np.ndarray truncations: np.ndarray def __repr__(self) -> str: return ( "EpisodeData(" f"id={repr(self.id)}, " f"seed={repr(self.seed)}, " f"total_timesteps={self.total_timesteps}, " f"observations={EpisodeData._repr_space_values(self.observations)}, " f"actions={EpisodeData._repr_space_values(self.actions)}, " f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " f"truncations=ndarray of {len(self.truncations)} bools" ")" ) @staticmethod def _repr_space_values(value): if isinstance(value, np.ndarray): return f"ndarray of shape {value.shape} and dtype {value.dtype}" elif isinstance(value, dict): reprs = [ f"{k}: {EpisodeData._repr_space_values(v)}" for k, v in value.items() ] dict_repr = ", ".join(reprs) return "{" + dict_repr + "}" elif isinstance(value, tuple): reprs = [EpisodeData._repr_space_values(v) for v in value] values_repr = ", ".join(reprs) return "(" + values_repr + ")" else: return repr(value)
@dataclass class MinariDatasetSpec: env_spec: EnvSpec total_episodes: int total_steps: int dataset_id: str combined_datasets: List[str] observation_space: gym.Space action_space: gym.Space data_path: str minari_version: str # post-init attributes env_name: str | None = field(init=False) dataset_name: str = field(init=False) version: int | None = field(init=False) def __post_init__(self): """Calls after the spec is created to extract the environment name, dataset name and version from the dataset id.""" self.env_name, self.dataset_name, self.version = parse_dataset_id( self.dataset_id )
[docs] class MinariDataset: """Main Minari dataset class to sample data and get metadata information from a dataset.""" def __init__( self, data: Union[MinariStorage, PathLike], episode_indices: Optional[np.ndarray] = None, ): """Initialize properties of the Minari Dataset. Args: data (Union[MinariStorage, _PathLike]): source of data. episode_indices (Optiona[np.ndarray]): slice of episode indices this dataset is pointing to. """ if isinstance(data, MinariStorage): self._data = data elif ( isinstance(data, str) or isinstance(data, os.PathLike) or isinstance(data, bytes) ): self._data = MinariStorage(data) else: raise ValueError(f"Unrecognized type {type(data)} for data") self._additional_data_id = 0 if episode_indices is None: episode_indices = np.arange(self._data.total_episodes) total_steps = self._data.total_steps else: total_steps = sum( self._data.apply( lambda episode: episode["total_timesteps"], episode_indices=episode_indices, ) ) self._episode_indices = episode_indices assert self._episode_indices is not None self.spec = MinariDatasetSpec( env_spec=self._data.env_spec, total_episodes=self._episode_indices.size, total_steps=total_steps, dataset_id=self._data.id, combined_datasets=self._data.combined_datasets, observation_space=self._data.observation_space, action_space=self._data.action_space, data_path=str(self._data.data_path), minari_version=str(self._data.minari_version), ) self._total_steps = total_steps self._generator = np.random.default_rng() @property def total_episodes(self): """Total episodes recorded in the Minari dataset.""" assert self._episode_indices is not None return self._episode_indices.size @property def total_steps(self): """Total episodes steps in the Minari dataset.""" return self._total_steps @property def episode_indices(self) -> np.ndarray: """Indices of the available episodes to sample within the Minari dataset.""" return self._episode_indices def recover_environment(self) -> gym.Env: """Recover the Gymnasium environment used to create the dataset. Returns: environment: Gymnasium environment """ return gym.make(self._data.env_spec) def set_seed(self, seed: int): """Set seed for random episode sampling generator.""" self._generator = np.random.default_rng(seed) def filter_episodes( self, condition: Callable[[EpisodeData], bool] ) -> MinariDataset: """Filter the dataset episodes with a condition. The condition must be a callable which takes an `EpisodeData` instance and retutrns a bool. The callable must return a `bool` True if the condition is met and False otherwise. i.e filtering for episodes that terminate: ``` dataset.filter(condition=lambda x: x['terminations'][-1] ) ``` Args: condition (Callable[[EpisodeData], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met. """ def dict_to_episode_data_condition(episode: dict) -> bool: return condition(EpisodeData(**episode)) mask = self._data.apply( dict_to_episode_data_condition, episode_indices=self._episode_indices ) assert self._episode_indices is not None return MinariDataset(self._data, episode_indices=self._episode_indices[mask]) def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]: """Sample n number of episodes from the dataset. Args: n_episodes (Optional[int], optional): number of episodes to sample. """ indices = self._generator.choice( self.episode_indices, size=n_episodes, replace=False ) episodes = self._data.get_episodes(indices) return list(map(lambda data: EpisodeData(**data), episodes)) def iterate_episodes( self, episode_indices: Optional[List[int]] = None ) -> Iterator[EpisodeData]: """Iterate over episodes from the dataset. Args: episode_indices (Optional[List[int]], optional): episode indices to iterate over. """ if episode_indices is None: assert self.episode_indices is not None assert self.episode_indices.ndim == 1 episode_indices = self.episode_indices.tolist() assert episode_indices is not None for episode_index in episode_indices: data = self._data.get_episodes([episode_index])[0] yield EpisodeData(**data) def update_dataset_from_collector_env(self, collector_env: DataCollectorV0): """Add extra data to Minari dataset from collector environment buffers (DataCollectorV0). This method can be used as a checkpoint when creating a dataset. A new HDF5 file will be created with the new dataset file in the same directory as `main_data.hdf5` called `additional_data_i.hdf5`. Both datasets are joined together by creating external links to each additional episode group: https://docs.h5py.org/en/stable/high/group.html#external-links Args: collector_env (DataCollectorV0): Collector environment """ # check that collector env has the same characteristics as self._env_spec new_data_file_path = os.path.join( os.path.split(self.spec.data_path)[0], f"additional_data_{self._additional_data_id}.hdf5", ) old_total_episodes = self._data.total_episodes self._data.update_from_collector_env( collector_env, new_data_file_path, self._additional_data_id ) new_total_episodes = self._data._total_episodes self._additional_data_id += 1 self._episode_indices = np.append( self._episode_indices, np.arange(old_total_episodes, new_total_episodes) ) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes)) self.spec.total_episodes = self._episode_indices.size self.spec.total_steps = sum( self._data.apply( lambda episode: episode["total_timesteps"], episode_indices=self._episode_indices, ) ) def update_dataset_from_buffer(self, buffer: List[dict]): """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers. Each episode dictionary buffer must have the following items: * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. Args: buffer (list[dict]): list of episode dictionary buffers to add to dataset """ old_total_episodes = self._data.total_episodes self._data.update_from_buffer(buffer, self.spec.data_path) new_total_episodes = self._data._total_episodes self._episode_indices = np.append( self._episode_indices, np.arange(old_total_episodes, new_total_episodes) ) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes)) self.spec.total_episodes = self._episode_indices.size self.spec.total_steps = sum( self._data.apply( lambda episode: episode["total_timesteps"], episode_indices=self._episode_indices, ) ) def __iter__(self): return self.iterate_episodes() def __getitem__(self, idx: int) -> EpisodeData: episodes_data = self._data.get_episodes([self.episode_indices[idx]]) assert len(episodes_data) == 1 return EpisodeData(**episodes_data[0]) def __len__(self) -> int: return self.total_episodes