from __future__ import annotations
import importlib.metadata
import os
import re
from dataclasses import dataclass, field
from typing import Callable, Iterable, Iterator, List
import gymnasium as gym
import numpy as np
import numpy.typing as npt
from gymnasium import error, logger
from gymnasium.envs.registration import EnvSpec
from packaging.requirements import InvalidRequirement, Requirement
from packaging.version import Version
from minari.data_collector.episode_buffer import EpisodeBuffer
from minari.dataset.episode_data import EpisodeData
from minari.dataset.minari_storage import MinariStorage, PathLike
VERSION_RE = r"(?:-v(?P<version>\d+))"
DATASET_NAME_RE = r"(?:(?P<dataset>[-_\w]+?))"
NAMESPACE_RE = r"(?:(?P<namespace>[-_\w][-_\w/]*[-_\w]+)\/)"
DATASET_ID_RE = re.compile(rf"^{NAMESPACE_RE}?{DATASET_NAME_RE}{VERSION_RE}?$")

def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int | None]:
    """Parse a dataset ID of the form ``(namespace/)dataset_name(-v(version))``.

    Args:
        dataset_id (str): the dataset id to parse.

    Returns:
        A tuple of namespace (or None), dataset name, and version number (or None).

    Raises:
        Error: if the dataset id does not match the dataset id regex.
    """
match = DATASET_ID_RE.fullmatch(dataset_id)
if not match:
raise error.Error(
f"Malformed dataset ID: {dataset_id}. (IDs must be of the form (namespace/)(dataset_name)-v(version). The namespace is optional.)"
)
namespace, dataset_name, version = match.group("namespace", "dataset", "version")
    # The version suffix is optional in DATASET_ID_RE, so guard against None.
    if version is not None:
        version = int(version)
    if namespace == "":
        namespace = None
return namespace, dataset_name, version
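
# Illustrative parses (the ids below are examples chosen to demonstrate the id
# grammar defined above, not a guarantee that these datasets exist):
#
#   >>> parse_dataset_id("D4RL/door/human-v2")
#   ('D4RL/door', 'human', 2)
#   >>> parse_dataset_id("cartpole-test-v0")
#   (None, 'cartpole-test', 0)
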
def gen_dataset_id(
namespace: str | None,
dataset_name: str,
version: int | None = None,
) -> str:
"""Generate a dataset ID from dataset attributes. Inverse of parse_dataset_id().
Args:
namespace (str | None): name of dataset subdir. Defaults to None.
dataset_name (str): name of the dataset.
version (int | None, optional): Dataset version. Defaults to None, in which case
the version tag will be suppressed.
Returns:
str: A dataset id string of the form ``(namespace/)(dataset_name)(-v(version))``.
The ``namespace`` and ``-v(version)`` are optional.
"""
namespace_str = f"{namespace}/" if namespace is not None else ""
version_str = f"-v{version}" if version is not None else ""
return f"{namespace_str}{dataset_name}{version_str}"
@dataclass
class MinariDatasetSpec:
env_spec: EnvSpec | None
total_episodes: int
total_steps: int
dataset_id: str
combined_datasets: List[str]
observation_space: gym.Space
action_space: gym.Space
data_path: str
minari_version: str
# post-init attributes
namespace: str | None = field(init=False)
dataset_name: str = field(init=False)
version: int | None = field(init=False)
def __post_init__(self):
"""Calls after the spec is created to extract the environment name, dataset name and version from the dataset id."""
(
self.namespace,
self.dataset_name,
self.version,
) = parse_dataset_id(self.dataset_id)
class MinariDataset:
"""Main Minari dataset class to sample data and get metadata information from a dataset."""
def __init__(
self,
data: MinariStorage | PathLike,
episode_indices: npt.NDArray[np.int_] | None = None,
):
"""Initialize properties of the Minari Dataset.
Args:
data (Union[MinariStorage, PathLike]): source of data.
episode_indices (Optional[np.ndarray]): slice of episode indices this dataset is pointing to.
"""
if isinstance(data, MinariStorage):
self._data = data
elif isinstance(data, (str, os.PathLike)):
self._data = MinariStorage.read(data)
else:
raise ValueError(f"Unrecognized type {type(data)} for data")
self._total_steps = None
if episode_indices is None:
episode_indices = np.arange(self._data.total_episodes)
self._total_steps = self._data.total_steps
assert episode_indices is not None
self._episode_indices: npt.NDArray[np.int_] = episode_indices
metadata = self._data.metadata
        def _load_env_spec(key: str) -> EnvSpec | None:
            # Deserialize an EnvSpec stored as JSON, stripping fields that were
            # removed in Gymnasium 1.0.0 for backward compatibility.
            spec_json = metadata.get(key)
            if spec_json is None:
                return None
            assert isinstance(spec_json, str)
            spec_json = (
                spec_json.replace('"order_enforce": true,', "")
                .replace('"apply_api_compatibility": false,', "")
                .replace('"autoreset": false, ', "")
            )
            return EnvSpec.from_json(spec_json)

        self._env_spec = _load_env_spec("env_spec")
        self._eval_env_spec = _load_env_spec("eval_env_spec")
dataset_id = metadata["dataset_id"]
assert isinstance(dataset_id, str)
self._dataset_id = dataset_id
minari_version = metadata["minari_version"]
assert isinstance(minari_version, str)
from minari import __version__, supported_dataset_versions
if minari_version not in supported_dataset_versions:
raise ValueError(
f"The installed Minari version {__version__} does not support the dataset generated by Minari {minari_version}."
f"Supported versions: {supported_dataset_versions}"
)
self._minari_version = minari_version
self._combined_datasets = metadata.get("combined_datasets", [])
self._observation_space = metadata["observation_space"]
self._action_space = metadata["action_space"]
assert isinstance(self._observation_space, gym.spaces.Space)
assert isinstance(self._action_space, gym.spaces.Space)
self._generator = np.random.default_rng()
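
    # Construction sketch (the path and indices below are hypothetical): a
    # dataset can wrap an existing MinariStorage or be read from disk, and a
    # subset view can be created by passing explicit episode indices.
    #
    #   >>> dataset = MinariDataset("/path/to/dataset/data")
    #   >>> subset = MinariDataset(dataset.storage, episode_indices=np.array([0, 2, 5]))
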
def recover_environment(self, eval_env: bool = False, **kwargs) -> gym.Env:
"""Recover the Gymnasium environment used to create the dataset.
Args:
eval_env (bool): if True, the returned Gymnasium environment will be that intended to be used for evaluation. If no eval_env was specified when creating the dataset, the returned environment will be the same as the one used for creating the dataset. Default False.
**kwargs: any other parameter that you want to pass to the `gym.make` function.
Returns:
environment: Gymnasium environment
"""
requirements = self._data.metadata.get("requirements", [])
for req_str in requirements:
try:
req = Requirement(req_str)
except InvalidRequirement:
logger.warn(f"Ignoring malformed requirement `{req_str}`")
continue
try:
installed_version = Version(importlib.metadata.version(req.name))
except importlib.metadata.PackageNotFoundError:
logger.warn(
f'Package {req.name} is not installed. Install it with `pip install "{req_str}"`'
)
else:
if not req.specifier.contains(installed_version):
logger.warn(
f"Installed {req.name} version {installed_version} does not meet the requirement {req.specifier}.\n"
                        f'We recommend installing the required version with `pip install "{req_str}"`'
)
if eval_env:
if self._eval_env_spec is not None:
return gym.make(self._eval_env_spec, **kwargs)
logger.info(
f"`eval_env` has been set to True but the dataset {self._dataset_id} doesn't provide an evaluation environment. Instead, the environment used for collecting the data will be returned: {self._env_spec}"
)
if self.env_spec is None:
raise ValueError("Environment cannot be recovered when env_spec is None")
return gym.make(self.env_spec, **kwargs)
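
    # Usage sketch (assumes the dataset stored a valid env spec and that the
    # packages listed in its requirements metadata are installed):
    #
    #   >>> env = dataset.recover_environment(render_mode="human")
    #   >>> obs, info = env.reset(seed=42)
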
def set_seed(self, seed: int):
"""Set seed for random episode sampling generator."""
self._generator = np.random.default_rng(seed)
def filter_episodes(
self, condition: Callable[[EpisodeData], bool]
) -> MinariDataset:
"""Filter the dataset episodes with a condition.
The condition must be a callable which takes an `EpisodeData` instance and returns a bool.
The callable must return a `bool` True if the condition is met and False otherwise.
i.e filtering for episodes that terminate:
```
dataset.filter(condition=lambda x: x['terminations'][-1] )
```
Args:
condition (Callable[[EpisodeData], bool]): function that gets in input an EpisodeData object and returns True if certain condition is met.
"""
def dict_to_episode_data_condition(episode: dict) -> bool:
return condition(EpisodeData(**episode))
mask = self.storage.apply(
dict_to_episode_data_condition, episode_indices=self.episode_indices
)
assert self.episode_indices is not None
filtered_indices = self.episode_indices[list(mask)]
return MinariDataset(self.storage, episode_indices=filtered_indices)
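
    # Filtering sketch (illustrative): keep only episodes with at least 100
    # steps. The result is a new MinariDataset view over the same storage.
    #
    #   >>> long_episodes = dataset.filter_episodes(lambda ep: len(ep.rewards) >= 100)
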
def sample_episodes(self, n_episodes: int) -> Iterable[EpisodeData]:
"""Sample n number of episodes from the dataset.
Args:
n_episodes (Optional[int], optional): number of episodes to sample.
"""
indices = self._generator.choice(
self.episode_indices, size=n_episodes, replace=False
)
episodes = self.storage.get_episodes(indices)
return list(map(lambda data: EpisodeData(**data), episodes))
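
    # Sampling sketch: seed the generator first for reproducible draws; episodes
    # are drawn uniformly without replacement (the sizes are hypothetical).
    #
    #   >>> dataset.set_seed(0)
    #   >>> batch = dataset.sample_episodes(n_episodes=8)
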
def iterate_episodes(
self, episode_indices: Iterable[int] | None = None
) -> Iterator[EpisodeData]:
"""Iterate over episodes from the dataset.
Args:
episode_indices (Optional[Iterable[int]], optional): episode indices to iterate over.
"""
if episode_indices is None:
assert self.episode_indices is not None
assert self.episode_indices.ndim == 1
episode_indices = self.episode_indices
assert episode_indices is not None
episodes_data = self.storage.get_episodes(episode_indices)
return map(lambda data: EpisodeData(**data), episodes_data)
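
    # Iteration sketch: episodes are produced one at a time, so a dataset can be
    # folded over without materializing every episode in a list first.
    #
    #   >>> returns = [ep.rewards.sum() for ep in dataset.iterate_episodes()]
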
def update_dataset_from_buffer(self, buffer: List[EpisodeBuffer]):
"""Additional data can be added to the Minari Dataset from a list of episode dictionary buffers.
Args:
buffer (list[EpisodeBuffer]): list of episode dictionary buffers to add to dataset
"""
first_id = self.storage.total_episodes
self.storage.update_episodes(buffer)
self.episode_indices = np.append(
self.episode_indices, first_id + np.arange(len(buffer))
)
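
    # Append sketch (assumes `new_buffers` is a hypothetical list of
    # EpisodeBuffer objects, e.g. collected during additional rollouts): new
    # episodes are written to storage and the index view is extended.
    #
    #   >>> dataset.update_dataset_from_buffer(new_buffers)
    #   >>> dataset.total_episodes  # grows by len(new_buffers)
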
def __iter__(self):
return self.iterate_episodes()
def __getitem__(self, idx: int) -> EpisodeData:
episode = self.iterate_episodes([self.episode_indices[idx]])
return next(episode)
def __len__(self) -> int:
return self.total_episodes
@property
def total_episodes(self) -> int:
"""Total number of episodes in the Minari dataset."""
return len(self.episode_indices)
@property
def total_steps(self) -> int:
"""Total episodes steps in the Minari dataset."""
if self._total_steps is None:
self._total_steps = 0
metadatas = self.storage.get_episode_metadata(self.episode_indices)
for m in metadatas:
self._total_steps += m["total_steps"]
return int(self._total_steps)
@property
def episode_indices(self) -> npt.NDArray[np.int_]:
"""Indices of the available episodes to sample within the Minari dataset."""
return self._episode_indices
@episode_indices.setter
def episode_indices(self, new_value: npt.NDArray[np.int_]):
self._total_steps = None # invalidate cache
self._episode_indices = new_value
@property
def observation_space(self):
"""Original observation space of the environment before flatteining (if this is the case)."""
return self._observation_space
@property
def action_space(self):
"""Original action space of the environment before flatteining (if this is the case)."""
return self._action_space
@property
def env_spec(self):
"""Envspec of the environment that has generated the dataset."""
return self._env_spec
@property
def combined_datasets(self) -> List[str]:
"""If this Minari dataset is a combination of other subdatasets, return a list with the subdataset names."""
if self._combined_datasets is None:
return []
else:
return self._combined_datasets
@property
def id(self) -> str:
"""Name of the Minari dataset."""
return self._dataset_id
@property
def minari_version(self) -> str:
"""Version of Minari the dataset is compatible with."""
return self._minari_version
@property
def storage(self) -> MinariStorage:
"""Minari storage managing access to disk."""
return self._data
@property
def spec(self) -> MinariDatasetSpec:
"""Minari dataset specifier."""
return MinariDatasetSpec(
env_spec=self.env_spec,
total_episodes=self._episode_indices.size,
total_steps=self.total_steps,
dataset_id=self.id,
combined_datasets=self.combined_datasets,
observation_space=self.observation_space,
action_space=self.action_space,
data_path=str(self.storage.data_path),
minari_version=str(self.minari_version),
)
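
    # Spec sketch: `dataset.spec` snapshots the dataset metadata; its
    # __post_init__ also exposes the parsed namespace, name, and version.
    #
    #   >>> spec = dataset.spec
    #   >>> spec.namespace, spec.dataset_name, spec.version
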