# Dataset Creation

Here we provide code for the standard procedure of collecting data and packaging it as a Minari dataset.

## Environment Setup

To run the full code below, you will need to install the following dependencies. We recommend using a freshly created virtual environment to avoid dependency conflicts.

```
gymnasium-robotics
cython
#minari
git+https://github.com/WillDudley/Gymnasium.git@spec_stack#egg=gymnasium
```

Then run `pip install -e .` from the root of the Minari repository.
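
A minimal sketch of the full setup, assuming a Unix-like shell (the virtual environment name is arbitrary):

```bash
# Create and activate a fresh virtual environment
python -m venv minari-env
source minari-env/bin/activate

# Install the dependencies listed above
pip install gymnasium-robotics cython
pip install "git+https://github.com/WillDudley/Gymnasium.git@spec_stack#egg=gymnasium"

# Install Minari itself in editable mode, from the repository root
pip install -e .
```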

## Full Code

```python
# pyright: basic, reportOptionalMemberAccess=false

import base64
import json
import os

import gymnasium as gym
import numpy as np
from gymnasium.utils.serialize_spec_stack import serialise_spec_stack

import minari
from minari.dataset import MinariDataset

# 1. Get permissions to upload to GCP
GCP_DATASET_ADMIN = os.environ["GCP_DATASET_ADMIN"]

credentials_json = base64.b64decode(GCP_DATASET_ADMIN).decode("utf8").replace("'", '"')
with open("credentials.json", "w") as f:
    f.write(credentials_json)

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./credentials.json"


# 2. Standard Gymnasium procedure to collect data into whatever replay buffer you want
env = gym.make("FetchReach-v3")

environment_stack = serialise_spec_stack(
    env.spec_stack
)  # Get the environment specification stack for reproducibility

env.reset()
dataset_name = "FetchReach_v3_example-dataset"

num_episodes = 4


assert env.spec.max_episode_steps is not None, "Max episode steps must be defined"

# Preallocate the full buffer up front. For FetchReach, the raw observation is
# 10-dimensional and the desired goal is 3-dimensional (concatenated to 13
# below), and actions are 4-dimensional.
replay_buffer = {
    "episode": np.array(
        [[0]] * env.spec.max_episode_steps * num_episodes, dtype=np.int32
    ),
    "observation": np.array(
        [[0] * 13] * env.spec.max_episode_steps * num_episodes,
        dtype=np.float32,
    ),
    "action": np.array(
        [[0] * 4] * env.spec.max_episode_steps * num_episodes, dtype=np.float32
    ),
    "reward": np.array(
        [[0]] * env.spec.max_episode_steps * num_episodes, dtype=np.float32
    ),
    "terminated": np.array(
        [[0]] * env.spec.max_episode_steps * num_episodes, dtype=bool
    ),
    "truncated": np.array(
        [[0]] * env.spec.max_episode_steps * num_episodes, dtype=bool
    ),
}
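
# With num_episodes = 4 and FetchReach's max_episode_steps of 50, each array
# starts with 200 rows; any unused rows are trimmed after collection below.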

total_steps = 0
for episode in range(num_episodes):
    episode_step = 0
    observation, info = env.reset()
    terminated = False
    truncated = False
    while not terminated and not truncated:
        action = env.action_space.sample()  # Random policy; substitute your own policy here
        observation, reward, terminated, truncated, info = env.step(action)
        replay_buffer["episode"][total_steps] = np.array(episode)
        replay_buffer["observation"][total_steps] = np.concatenate(
            (
                np.array(observation["observation"]),
                np.array(observation["desired_goal"]),
            )
        )
        replay_buffer["action"][total_steps] = np.array(action)
        replay_buffer["reward"][total_steps] = np.array(reward)
        replay_buffer["terminated"][total_steps] = np.array(terminated)
        replay_buffer["truncated"][total_steps] = np.array(truncated)
        episode_step += 1
        total_steps += 1

env.close()

replay_buffer["episode"] = replay_buffer["episode"][:total_steps]
replay_buffer["observation"] = replay_buffer["observation"][:total_steps]
replay_buffer["action"] = replay_buffer["action"][:total_steps]
replay_buffer["reward"] = replay_buffer["reward"][:total_steps]
replay_buffer["terminated"] = replay_buffer["terminated"][:total_steps]
replay_buffer["truncated"] = replay_buffer["truncated"][:total_steps]


# 3. Convert the replay buffer to a MinariDataset
dataset = MinariDataset(
    dataset_name=dataset_name,
    algorithm_name="random_policy",
    environment_name="FetchReach-v3",
    environment_stack=json.dumps(environment_stack),
    seed_used=42,  # For the simplicity of this example, we're not actually using a seed. Naughty us!
    code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
    author="WillDudley",
    author_email="wdudley@farama.org",
    observations=replay_buffer["observation"],
    actions=replay_buffer["action"],
    rewards=replay_buffer["reward"],
    terminations=replay_buffer["terminated"],
    truncations=replay_buffer["truncated"],
)
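
# Note: the `episode` column collected above is not passed to MinariDataset;
# episode boundaries are implied by the terminations and truncations arrays.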

print("Dataset generated!")

# 4. Save the dataset locally
dataset.save()

# 5. Upload the dataset to GCP
minari.upload_dataset(dataset_name)
```
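
The script expects `GCP_DATASET_ADMIN` to contain a base64-encoded service-account key (step 1 decodes it back to JSON). A minimal sketch of running it end to end, assuming a GNU/Linux shell; the key and script filenames are illustrative:

```bash
# Encode a service-account key into the environment variable the script expects
export GCP_DATASET_ADMIN="$(base64 -w0 service-account.json)"

# Collect the episodes, build the dataset locally, and upload it to GCP
python create_dataset.py
```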