import os
import pickle
import random
from collections import deque
from pathlib import Path

import einops
import hydra
import numpy as np
import torch
import tqdm
import wandb
from omegaconf import OmegaConf

from datasets.core import TrajectoryEmbeddingDataset, split_traj_datasets
from datasets.vqbet_repro import TrajectorySlicerDataset
from utils.video import VideoRecorder

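# Use EGL for headless MuJoCo rendering unless a backend is already configured.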
if "MUJOCO_GL" not in os.environ: |
|
os.environ["MUJOCO_GL"] = "egl" |
|
|
|
|
|
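# Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs for reproducibility.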
def seed_everything(random_seed: int):
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)


@hydra.main(config_path="eval_configs", version_base="1.2")
def main(cfg):
    print(OmegaConf.to_yaml(cfg))
    seed_everything(cfg.seed)

    encoder = hydra.utils.instantiate(cfg.encoder)
    encoder = encoder.to(cfg.device).eval()

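    # Build the trajectory dataset, split it into train/test trajectories, and wrap
    # each split so observations are embedded by the (frozen) encoder.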
    dataset = hydra.utils.instantiate(cfg.dataset)
    train_data, test_data = split_traj_datasets(
        dataset,
        train_fraction=cfg.train_fraction,
        random_seed=cfg.seed,
    )
    use_libero_goal = cfg.data.get("use_libero_goal", False)
    train_data = TrajectoryEmbeddingDataset(
        encoder, train_data, device=cfg.device, embed_goal=use_libero_goal
    )
    test_data = TrajectoryEmbeddingDataset(
        encoder, test_data, device=cfg.device, embed_goal=use_libero_goal
    )
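    # Slice each embedded trajectory into fixed-length observation/action windows.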
    traj_slicer_kwargs = {
        "window": cfg.data.window_size,
        "action_window": cfg.data.action_window_size,
        "vqbet_get_future_action_chunk": cfg.data.vqbet_get_future_action_chunk,
        "future_conditional": (cfg.data.goal_conditional == "future"),
        "min_future_sep": cfg.data.action_window_size,
        "future_seq_len": cfg.data.future_seq_len,
        "use_libero_goal": use_libero_goal,
    }
    train_data = TrajectorySlicerDataset(train_data, **traj_slicer_kwargs)
    test_data = TrajectorySlicerDataset(test_data, **traj_slicer_kwargs)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=cfg.batch_size, shuffle=True, pin_memory=False
    )
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=cfg.batch_size, shuffle=False, pin_memory=False
    )
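    # Freeze the encoder; only the behavior model is trained below.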
    for param in encoder.parameters():
        param.requires_grad = False
    encoder.eval()

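    # Instantiate the behavior model (cbet_model), its optimizer, and the
    # evaluation environment from the Hydra config.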
    cbet_model = hydra.utils.instantiate(cfg.model).to(cfg.device)
    optimizer = cbet_model.configure_optimizers(
        weight_decay=cfg.optim.weight_decay,
        learning_rate=cfg.optim.lr,
        betas=cfg.optim.betas,
    )
    env = hydra.utils.instantiate(cfg.env.gym)
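    # For LIBERO goal conditioning, cache an embedding of the final frame of every
    # 50th demonstration (one per goal); otherwise goals are a dummy tensor.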
if "use_libero_goal" in cfg.data: |
|
with torch.no_grad(): |
|
|
|
goals_cache = [] |
|
for i in range(10): |
|
idx = i * 50 |
|
last_obs, _, _ = dataset.get_frames(idx, [-1]) |
|
last_obs = last_obs.to(cfg.device) |
|
embd = encoder(last_obs)[0] |
|
embd = einops.rearrange(embd, "V E -> (V E)") |
|
goals_cache.append(embd) |
|
|
|
def goal_fn(goal_idx): |
|
return goals_cache[goal_idx] |
|
else: |
|
empty_tensor = torch.zeros(1) |
|
|
|
def goal_fn(goal_idx): |
|
return empty_tensor |
|
|
|
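    # Set up W&B logging, the run output directory, and a video recorder for rollouts.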
    run = wandb.init(
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        config=OmegaConf.to_container(cfg, resolve=True),
    )
    run_name = run.name or "Offline"
    save_path = Path(cfg.save_path) / run_name
    save_path.mkdir(parents=True, exist_ok=False)
    video = VideoRecorder(dir_name=save_path)

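    # Roll out the current policy in the environment: one rollout per goal index,
    # optionally repeated, returning the average reward plus per-env extra metrics.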
    @torch.no_grad()
    def eval_on_env(
        cfg,
        num_evals=cfg.num_env_evals,
        num_eval_per_goal=1,
        videorecorder=None,
        epoch=None,
    ):
        def embed(enc, obs):
            obs = (
                torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(cfg.device)
            )
            result = enc(obs)
            result = einops.rearrange(result, "1 V E -> (V E)")
            return result

        avg_reward = 0
        action_list = []
        completion_id_list = []
        avg_max_coverage = []
        avg_final_coverage = []
        env.seed(cfg.seed)
        for goal_idx in range(num_evals):
            if videorecorder is not None:
                videorecorder.init(enabled=True)
            for i in range(num_eval_per_goal):
                obs_stack = deque(maxlen=cfg.eval_window_size)
                this_obs = env.reset(goal_idx=goal_idx)
                assert (
                    this_obs.min() >= 0 and this_obs.max() <= 1
                ), "expect 0-1 range observation"
                this_obs_enc = embed(encoder, this_obs)
                obs_stack.append(this_obs_enc)
                done, step, total_reward = False, 0, 0
                goal = goal_fn(goal_idx)
                while not done:
                    obs = torch.stack(tuple(obs_stack)).float().to(cfg.device)
                    goal = torch.as_tensor(goal, dtype=torch.float32, device=cfg.device)
                    goal = goal.unsqueeze(0).repeat(cfg.eval_window_size, 1)
                    action, _, _ = cbet_model(obs.unsqueeze(0), goal.unsqueeze(0), None)
                    action = action[0]
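                    # With action chunking (window > 1), average the overlapping
                    # chunks' predictions for the current step and shift each stored
                    # chunk forward by one step; otherwise execute the latest
                    # single-step prediction.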
                    if cfg.action_window_size > 1:
                        action_list.append(action[-1].cpu().detach().numpy())
                        if len(action_list) > cfg.action_window_size:
                            action_list = action_list[1:]
                        curr_action = np.array(action_list)
                        curr_action = (
                            np.sum(curr_action, axis=0)[0] / curr_action.shape[0]
                        )
                        new_action_list = []
                        for a_chunk in action_list:
                            new_action_list.append(
                                np.concatenate(
                                    (a_chunk[1:], np.zeros((1, a_chunk.shape[1])))
                                )
                            )
                        action_list = new_action_list
                    else:
                        curr_action = action[-1, 0, :].cpu().detach().numpy()

                    this_obs, reward, done, info = env.step(curr_action)
                    this_obs_enc = embed(encoder, this_obs)
                    obs_stack.append(this_obs_enc)

                    if videorecorder is not None and videorecorder.enabled:
                        videorecorder.record(info["image"])
                    step += 1
                    total_reward += reward
                    goal = goal_fn(goal_idx)
                avg_reward += total_reward
                if cfg.env.gym.id == "pusht":
                    env.env._seed += 1
                    avg_max_coverage.append(info["max_coverage"])
                    avg_final_coverage.append(info["final_coverage"])
                elif cfg.env.gym.id == "blockpush":
                    avg_max_coverage.append(info["moved"])
                    avg_final_coverage.append(info["entered"])
                completion_id_list.append(info["all_completions_ids"])
            if videorecorder is not None:
                videorecorder.save("eval_{}_{}.mp4".format(epoch, goal_idx))
        return (
            avg_reward / (num_evals * num_eval_per_goal),
            completion_id_list,
            avg_max_coverage,
            avg_final_coverage,
        )

    metrics_history = []
    reward_history = []

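    # Main loop: periodically roll out in the environment and evaluate on the test
    # split, then run one training epoch over the sliced windows.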
    for epoch in tqdm.trange(cfg.epochs):
        cbet_model.eval()
        if epoch % cfg.eval_on_env_freq == 0:
            avg_reward, completion_id_list, max_coverage, final_coverage = eval_on_env(
                cfg,
                videorecorder=video,
                epoch=epoch,
                num_eval_per_goal=cfg.num_final_eval_per_goal,
            )
            reward_history.append(avg_reward)
            with open("{}/completion_idx_{}.json".format(save_path, epoch), "wb") as fp:
                pickle.dump(completion_id_list, fp)
            wandb.log({"eval_on_env": avg_reward})
            if cfg.env.gym.id in ["pusht", "blockpush"]:
                metric_final = (
                    "final coverage" if cfg.env.gym.id == "pusht" else "entered"
                )
                metric_max = "max coverage" if cfg.env.gym.id == "pusht" else "moved"
                metrics = {
                    f"{metric_final} mean": sum(final_coverage) / len(final_coverage),
                    f"{metric_final} max": max(final_coverage),
                    f"{metric_final} min": min(final_coverage),
                    f"{metric_max} mean": sum(max_coverage) / len(max_coverage),
                    f"{metric_max} max": max(max_coverage),
                    f"{metric_max} min": min(max_coverage),
                }
                wandb.log(metrics)
                metrics_history.append(metrics)

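        # Offline evaluation: accumulate prediction losses over the held-out test split.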
        if epoch % cfg.eval_freq == 0:
            total_loss = 0
            action_diff = 0
            action_diff_tot = 0
            action_diff_mean_res1 = 0
            action_diff_mean_res2 = 0
            action_diff_max = 0
            with torch.no_grad():
                for data in test_loader:
                    obs, act, goal = (x.to(cfg.device) for x in data)
                    assert obs.ndim == 4, "expect N T V E here"
                    obs = einops.rearrange(obs, "N T V E -> N T (V E)")
                    goal = einops.rearrange(goal, "N T V E -> N T (V E)")
                    predicted_act, loss, loss_dict = cbet_model(obs, goal, act)
                    total_loss += loss.item()
                    wandb.log({"eval/{}".format(x): y for (x, y) in loss_dict.items()})
                    action_diff += loss_dict["action_diff"]
                    action_diff_tot += loss_dict["action_diff_tot"]
                    action_diff_mean_res1 += loss_dict["action_diff_mean_res1"]
                    action_diff_mean_res2 += loss_dict["action_diff_mean_res2"]
                    action_diff_max += loss_dict["action_diff_max"]
            print(f"Test loss: {total_loss / len(test_loader)}")
            wandb.log({"eval/epoch_wise_action_diff": action_diff})
            wandb.log({"eval/epoch_wise_action_diff_tot": action_diff_tot})
            wandb.log({"eval/epoch_wise_action_diff_mean_res1": action_diff_mean_res1})
            wandb.log({"eval/epoch_wise_action_diff_mean_res2": action_diff_mean_res2})
            wandb.log({"eval/epoch_wise_action_diff_max": action_diff_max})

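        # One training epoch over the sliced trajectory windows.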
        cbet_model.train()
        for data in tqdm.tqdm(train_loader):
            optimizer.zero_grad()
            obs, act, goal = (x.to(cfg.device) for x in data)
            obs = einops.rearrange(obs, "N T V E -> N T (V E)")
            goal = einops.rearrange(goal, "N T V E -> N T (V E)")
            predicted_act, loss, loss_dict = cbet_model(obs, goal, act)
            wandb.log({"train/{}".format(x): y for (x, y) in loss_dict.items()})
            loss.backward()
            optimizer.step()

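    # Final evaluation over the full set of goals after training completes.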
    avg_reward, completion_id_list, max_coverage, final_coverage = eval_on_env(
        cfg,
        num_evals=cfg.num_final_evals,
        num_eval_per_goal=cfg.num_final_eval_per_goal,
        videorecorder=video,
        epoch=cfg.epochs,
    )
    reward_history.append(avg_reward)
    if cfg.env.gym.id in ["pusht", "blockpush"]:
        metric_final = "final coverage" if cfg.env.gym.id == "pusht" else "entered"
        metric_max = "max coverage" if cfg.env.gym.id == "pusht" else "moved"
        metrics = {
            f"{metric_final} mean": sum(final_coverage) / len(final_coverage),
            f"{metric_final} max": max(final_coverage),
            f"{metric_final} min": min(final_coverage),
            f"{metric_max} mean": sum(max_coverage) / len(max_coverage),
            f"{metric_max} max": max(max_coverage),
            f"{metric_max} min": min(max_coverage),
        }
        wandb.log(metrics)
        metrics_history.append(metrics)

    with open("{}/completion_idx_final.json".format(save_path), "wb") as fp:
        pickle.dump(completion_id_list, fp)
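    # Headline metric: best coverage/entered mean for pusht/blockpush, best average
    # reward for libero_goal, and the final evaluation reward otherwise.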
    if cfg.env.gym.id == "pusht":
        final_eval_on_env = max([x["final coverage mean"] for x in metrics_history])
    elif cfg.env.gym.id == "blockpush":
        final_eval_on_env = max([x["entered mean"] for x in metrics_history])
    elif cfg.env.gym.id == "libero_goal":
        final_eval_on_env = max(reward_history)
    elif cfg.env.gym.id == "kitchen-v0":
        final_eval_on_env = avg_reward
    else:
        final_eval_on_env = avg_reward
    wandb.log({"final_eval_on_env": final_eval_on_env})
    return final_eval_on_env


if __name__ == "__main__":
    main()