Source code for minari.dataset.episode_data

from dataclasses import dataclass
from typing import Any

import numpy as np


[docs] @dataclass(frozen=True) class EpisodeData: """Contains the datasets data for a single episode.""" id: int observations: Any actions: Any rewards: np.ndarray terminations: np.ndarray truncations: np.ndarray infos: dict def __len__(self) -> int: return len(self.rewards) def __repr__(self) -> str: return ( "EpisodeData(" f"id={self.id}, " f"total_steps={len(self)}, " f"observations={EpisodeData._repr_space_values(self.observations)}, " f"actions={EpisodeData._repr_space_values(self.actions)}, " f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " f"truncations=ndarray of {len(self.truncations)} bools, " f"infos=dict with the following keys: {list(self.infos.keys())}" ")" ) @staticmethod def _repr_space_values(value): if isinstance(value, np.ndarray): return f"ndarray of shape {value.shape} and dtype {value.dtype}" elif isinstance(value, dict): reprs = [ f"{k}: {EpisodeData._repr_space_values(v)}" for k, v in value.items() ] dict_repr = ", ".join(reprs) return "{" + dict_repr + "}" elif isinstance(value, tuple): reprs = [EpisodeData._repr_space_values(v) for v in value] values_repr = ", ".join(reprs) return "(" + values_repr + ")" else: return repr(value)