Source code for minari.dataset.episode_data
from dataclasses import dataclass
from typing import Any
import numpy as np
[docs]
@dataclass(frozen=True)
class EpisodeData:
    """Immutable record holding the data of one episode of a dataset.

    Attributes:
        id: Integer identifier of the episode.
        observations: Observation data. ``__repr__`` summarises ndarrays,
            dicts and tuples (recursively); anything else is shown via
            ``repr``.
        actions: Action data, summarised the same way as ``observations``.
        rewards: Array of rewards; its length defines ``len(episode)``.
        terminations: Array of termination flags.
        truncations: Array of truncation flags.
        infos: Dictionary of extra information; only its keys are shown
            in ``__repr__``.
    """

    id: int
    observations: Any
    actions: Any
    rewards: np.ndarray
    terminations: np.ndarray
    truncations: np.ndarray
    infos: dict

    def __len__(self) -> int:
        """Return the number of entries in ``rewards``."""
        return len(self.rewards)

    def __repr__(self) -> str:
        """Return a compact, single-line summary of the episode.

        Array contents are never printed; only lengths, shapes and dtypes
        are reported so the output stays short for long episodes.
        """
        summarize = EpisodeData._repr_space_values
        pieces = [
            "EpisodeData(",
            f"id={self.id}, ",
            f"total_steps={len(self)}, ",
            f"observations={summarize(self.observations)}, ",
            f"actions={summarize(self.actions)}, ",
            f"rewards=ndarray of {len(self.rewards)} floats, ",
            f"terminations=ndarray of {len(self.terminations)} bools, ",
            f"truncations=ndarray of {len(self.truncations)} bools, ",
            f"infos=dict with the following keys: {list(self.infos.keys())}",
            ")",
        ]
        return "".join(pieces)

    @staticmethod
    def _repr_space_values(value):
        """Return a short textual description of ``value``.

        ndarrays are described by shape and dtype, dicts and tuples are
        described element by element (recursively), and every other value
        falls back to the built-in ``repr``.
        """
        if isinstance(value, np.ndarray):
            return f"ndarray of shape {value.shape} and dtype {value.dtype}"

        if isinstance(value, dict):
            entries = ", ".join(
                f"{k}: {EpisodeData._repr_space_values(v)}" for k, v in value.items()
            )
            return "{" + entries + "}"

        if isinstance(value, tuple):
            members = ", ".join(
                EpisodeData._repr_space_values(item) for item in value
            )
            return "(" + members + ")"

        return repr(value)