Source code for minari.data_collector.episode_buffer

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Union

from minari.dataset.step_data import StepData


@dataclass(frozen=True)
class EpisodeBuffer:
    """Contains the data of a single episode.

    The dataclass is frozen, so its fields cannot be reassigned, but the list
    containers it holds are mutable: ``add_step_data`` appends to them in place
    and returns a new ``EpisodeBuffer`` that shares those same containers.
    """

    # Optional episode identifier.
    id: Optional[int] = None
    # Optional seed / options recorded with the episode (presumably the values
    # passed to the environment reset -- confirm against the collector).
    seed: Optional[int] = None
    options: Optional[dict] = None
    # Accumulated observations / actions as pytrees. After accumulation each
    # leaf holds a list with one entry per added value (see ``add_step_data``).
    observations: Union[None, list, dict, tuple] = None
    actions: Union[None, list, dict, tuple] = None
    # Per-step scalars; ``add_step_data`` appends one entry to each per call.
    rewards: list = field(default_factory=list)
    terminations: list = field(default_factory=list)
    truncations: list = field(default_factory=list)
    # Accumulated step infos, merged leaf-wise like ``actions``.
    infos: Optional[dict] = None

    def add_step_data(self, step_data: StepData) -> EpisodeBuffer:
        """Add step data dictionary to episode buffer.

        The step's observation, action and info are merged leaf-wise into the
        stored pytrees with ``jax.tree_util.tree_map``. The reward, terminated
        and truncated values are appended to their respective lists.

        Note: this mutates the containers held by ``self``. Reward/termination/
        truncation lists, and any pytree leaf that is already a list, are
        appended to in place, and the returned buffer shares them. After the
        call ``self`` therefore also reflects the new step (presumably a
        deliberate choice to avoid copying the whole episode on every step).

        Args:
            step_data (StepData): dictionary with data for a single step; the
                ``"observation"``, ``"action"``, ``"reward"``, ``"terminated"``,
                ``"truncated"`` and ``"info"`` keys are read.

        Returns:
            EpisodeBuffer: episode buffer with appended data

        Raises:
            ImportError: if ``jax`` is not installed.
        """
        try:
            import jax.tree_util as jtu
        except ImportError as err:
            # Chain the original failure so the traceback shows the real cause.
            raise ImportError(
                'jax is not installed. Please install it using `pip install "minari[create]"`'
            ) from err

        def _append(data, buffer):
            # A leaf already accumulated as a list is extended in place; a
            # single stored value (e.g. a first observation kept unwrapped) is
            # promoted to a two-element list.
            if isinstance(buffer, list):
                buffer.append(data)
                return buffer
            return [buffer, data]

        # With no stored observation yet, keep the step's observation as-is
        # (not wrapped in a list); ``_append`` promotes it on the next call.
        observations = step_data["observation"]
        if self.observations is not None:
            observations = jtu.tree_map(
                _append, step_data["observation"], self.observations
            )

        if self.actions is None:
            actions = jtu.tree_map(lambda x: [x], step_data["action"])
        else:
            actions = jtu.tree_map(_append, step_data["action"], self.actions)

        if self.infos is None:
            infos = jtu.tree_map(lambda x: [x], step_data["info"])
        else:
            infos = jtu.tree_map(_append, step_data["info"], self.infos)

        # In-place appends: these lists are shared with the returned buffer.
        self.rewards.append(step_data["reward"])
        self.terminations.append(step_data["terminated"])
        self.truncations.append(step_data["truncated"])

        return EpisodeBuffer(
            id=self.id,
            seed=self.seed,
            options=self.options,
            observations=observations,
            actions=actions,
            rewards=self.rewards,
            terminations=self.terminations,
            truncations=self.truncations,
            infos=infos,
        )

    def __len__(self) -> int:
        """Buffer length: the number of recorded rewards (one per added step)."""
        return len(self.rewards)