import os
import pickle
import random
from collections import deque
from pathlib import Path

import einops
import hydra
import numpy as np
import torch
import tqdm
import wandb
from omegaconf import OmegaConf

from datasets.core import TrajectoryEmbeddingDataset, split_traj_datasets
from datasets.vqbet_repro import TrajectorySlicerDataset
from utils.video import VideoRecorder

if "MUJOCO_GL" not in os.environ:
    os.environ["MUJOCO_GL"] = "egl"


def seed_everything(random_seed: int):
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    random.seed(random_seed)


# No config_name is set here, so select a config at launch time,
# e.g. `--config-name=<your eval config>`.
@hydra.main(config_path="eval_configs", version_base="1.2")
def main(cfg):
    print(OmegaConf.to_yaml(cfg))
    seed_everything(cfg.seed)

    encoder = hydra.utils.instantiate(cfg.encoder)
    encoder = encoder.to(cfg.device).eval()

    dataset = hydra.utils.instantiate(cfg.dataset)
    train_data, test_data = split_traj_datasets(
        dataset,
        train_fraction=cfg.train_fraction,
        random_seed=cfg.seed,
    )
    use_libero_goal = cfg.data.get("use_libero_goal", False)
    train_data = TrajectoryEmbeddingDataset(
        encoder, train_data, device=cfg.device, embed_goal=use_libero_goal
    )
    test_data = TrajectoryEmbeddingDataset(
        encoder, test_data, device=cfg.device, embed_goal=use_libero_goal
    )
    traj_slicer_kwargs = {
        "window": cfg.data.window_size,
        "action_window": cfg.data.action_window_size,
        "vqbet_get_future_action_chunk": cfg.data.vqbet_get_future_action_chunk,
        "future_conditional": (cfg.data.goal_conditional == "future"),
        "min_future_sep": cfg.data.action_window_size,
        "future_seq_len": cfg.data.future_seq_len,
        "use_libero_goal": use_libero_goal,
    }
    train_data = TrajectorySlicerDataset(train_data, **traj_slicer_kwargs)
    test_data = TrajectorySlicerDataset(test_data, **traj_slicer_kwargs)

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=cfg.batch_size, shuffle=True, pin_memory=False
    )
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=cfg.batch_size, shuffle=False, pin_memory=False
    )

    # The encoder is frozen; only the cbet model is trained.
    for param in encoder.parameters():
        param.requires_grad = False
    encoder.eval()

    cbet_model = hydra.utils.instantiate(cfg.model).to(cfg.device)
    optimizer = cbet_model.configure_optimizers(
        weight_decay=cfg.optim.weight_decay,
        learning_rate=cfg.optim.lr,
        betas=cfg.optim.betas,
    )
    env = hydra.utils.instantiate(cfg.env.gym)

    if use_libero_goal:
        with torch.no_grad():
            # Calculate goal embeddings for each task. This assumes 10 tasks whose
            # demos are stored in blocks of 50; the last frame of each task's first
            # demo is embedded and used as that task's goal.
            goals_cache = []
            for i in range(10):
                idx = i * 50
                last_obs, _, _ = dataset.get_frames(idx, [-1])  # 1 V C H W
                last_obs = last_obs.to(cfg.device)
                embd = encoder(last_obs)[0]  # V E
                embd = einops.rearrange(embd, "V E -> (V E)")
                goals_cache.append(embd)

        def goal_fn(goal_idx):
            return goals_cache[goal_idx]

    else:
        empty_tensor = torch.zeros(1)

        def goal_fn(goal_idx):
            return empty_tensor

    run = wandb.init(
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        config=OmegaConf.to_container(cfg, resolve=True),
    )
    run_name = run.name or "Offline"
    save_path = Path(cfg.save_path) / run_name
    save_path.mkdir(parents=True, exist_ok=False)
    video = VideoRecorder(dir_name=save_path)

    @torch.no_grad()
    def eval_on_env(
        cfg,
        num_evals=cfg.num_env_evals,
        num_eval_per_goal=1,
        videorecorder=None,
        epoch=None,
    ):
        def embed(enc, obs):
            obs = (
                torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(cfg.device)
            )  # 1 V C H W
            result = enc(obs)
            result = einops.rearrange(result, "1 V E -> (V E)")
            return result

        avg_reward = 0
        action_list = []
        completion_id_list = []
        avg_max_coverage = []
        avg_final_coverage = []
        env.seed(cfg.seed)
        for goal_idx in range(num_evals):
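            # Roll out the policy num_eval_per_goal times for this goal; all frames
            # for a goal are recorded into a single per-goal video.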
            if videorecorder is not None:
                videorecorder.init(enabled=True)
            for i in range(num_eval_per_goal):
                obs_stack = deque(maxlen=cfg.eval_window_size)
                this_obs = env.reset(goal_idx=goal_idx)  # V C H W
                assert (
                    this_obs.min() >= 0 and this_obs.max() <= 1
                ), "expect 0-1 range observation"
                this_obs_enc = embed(encoder, this_obs)
                obs_stack.append(this_obs_enc)
                done, step, total_reward = False, 0, 0
                goal = goal_fn(goal_idx)  # precomputed goal embedding (V E) or dummy tensor
                while not done:
                    obs = torch.stack(tuple(obs_stack)).float().to(cfg.device)
                    goal = torch.as_tensor(goal, dtype=torch.float32, device=cfg.device)
                    # goal = embed(encoder, goal)
                    goal = goal.unsqueeze(0).repeat(cfg.eval_window_size, 1)
                    action, _, _ = cbet_model(obs.unsqueeze(0), goal.unsqueeze(0), None)
                    action = action[0]  # remove batch dim; always 1
                    if cfg.action_window_size > 1:
                        action_list.append(action[-1].cpu().detach().numpy())
                        if len(action_list) > cfg.action_window_size:
                            action_list = action_list[1:]
                        curr_action = np.array(action_list)
                        curr_action = (
                            np.sum(curr_action, axis=0)[0] / curr_action.shape[0]
                        )
                        new_action_list = []
                        for a_chunk in action_list:
                            new_action_list.append(
                                np.concatenate(
                                    (a_chunk[1:], np.zeros((1, a_chunk.shape[1])))
                                )
                            )
                        action_list = new_action_list
                    else:
                        curr_action = action[-1, 0, :].cpu().detach().numpy()

                    this_obs, reward, done, info = env.step(curr_action)
                    this_obs_enc = embed(encoder, this_obs)
                    obs_stack.append(this_obs_enc)

                    if videorecorder is not None and videorecorder.enabled:
                        videorecorder.record(info["image"])
                    step += 1
                    total_reward += reward
                    goal = goal_fn(goal_idx)
                avg_reward += total_reward
                if cfg.env.gym.id == "pusht":
                    env.env._seed += 1
                    avg_max_coverage.append(info["max_coverage"])
                    avg_final_coverage.append(info["final_coverage"])
                elif cfg.env.gym.id == "blockpush":
                    avg_max_coverage.append(info["moved"])
                    avg_final_coverage.append(info["entered"])
                completion_id_list.append(info["all_completions_ids"])
            if videorecorder is not None:
                videorecorder.save("eval_{}_{}.mp4".format(epoch, goal_idx))
        return (
            avg_reward / (num_evals * num_eval_per_goal),
            completion_id_list,
            avg_max_coverage,
            avg_final_coverage,
        )

    metrics_history = []
    reward_history = []
    for epoch in tqdm.trange(cfg.epochs):
        cbet_model.eval()
        if epoch % cfg.eval_on_env_freq == 0:
            avg_reward, completion_id_list, max_coverage, final_coverage = eval_on_env(
                cfg,
                videorecorder=video,
                epoch=epoch,
                num_eval_per_goal=cfg.num_final_eval_per_goal,
            )
            reward_history.append(avg_reward)
            # NOTE: contents are pickled even though the filename ends in .json
            with open(
                "{}/completion_idx_{}.json".format(save_path, epoch), "wb"
            ) as fp:
                pickle.dump(completion_id_list, fp)
            wandb.log({"eval_on_env": avg_reward})
            if cfg.env.gym.id in ["pusht", "blockpush"]:
                metric_final = (
                    "final coverage" if cfg.env.gym.id == "pusht" else "entered"
                )
                metric_max = "max coverage" if cfg.env.gym.id == "pusht" else "moved"
                metrics = {
                    f"{metric_final} mean": sum(final_coverage) / len(final_coverage),
                    f"{metric_final} max": max(final_coverage),
                    f"{metric_final} min": min(final_coverage),
                    f"{metric_max} mean": sum(max_coverage) / len(max_coverage),
                    f"{metric_max} max": max(max_coverage),
                    f"{metric_max} min": min(max_coverage),
                }
                wandb.log(metrics)
                metrics_history.append(metrics)
        if epoch % cfg.eval_freq == 0:
            total_loss = 0
            action_diff = 0
            action_diff_tot = 0
            action_diff_mean_res1 = 0
            action_diff_mean_res2 = 0
            action_diff_max = 0
            with torch.no_grad():
                for data in test_loader:
                    obs, act, goal = (x.to(cfg.device) for x in data)
                    assert obs.ndim == 4, "expect N T V E here"
                    obs = einops.rearrange(obs, "N T V E -> N T (V E)")
                    goal = einops.rearrange(goal, "N T V E -> N T (V E)")
                    predicted_act, loss, loss_dict = cbet_model(obs, goal, act)
                    total_loss += loss.item()
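                    # Per-batch losses go to wandb here; the action-diff terms are
                    # accumulated and logged once per eval epoch below.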
                    wandb.log({"eval/{}".format(x): y for (x, y) in loss_dict.items()})
                    action_diff += loss_dict["action_diff"]
                    action_diff_tot += loss_dict["action_diff_tot"]
                    action_diff_mean_res1 += loss_dict["action_diff_mean_res1"]
                    action_diff_mean_res2 += loss_dict["action_diff_mean_res2"]
                    action_diff_max += loss_dict["action_diff_max"]
            print(f"Test loss: {total_loss / len(test_loader)}")
            wandb.log({"eval/epoch_wise_action_diff": action_diff})
            wandb.log({"eval/epoch_wise_action_diff_tot": action_diff_tot})
            wandb.log({"eval/epoch_wise_action_diff_mean_res1": action_diff_mean_res1})
            wandb.log({"eval/epoch_wise_action_diff_mean_res2": action_diff_mean_res2})
            wandb.log({"eval/epoch_wise_action_diff_max": action_diff_max})
        cbet_model.train()
        for data in tqdm.tqdm(train_loader):
            optimizer.zero_grad()
            obs, act, goal = (x.to(cfg.device) for x in data)
            obs = einops.rearrange(obs, "N T V E -> N T (V E)")
            goal = einops.rearrange(goal, "N T V E -> N T (V E)")
            predicted_act, loss, loss_dict = cbet_model(obs, goal, act)
            wandb.log({"train/{}".format(x): y for (x, y) in loss_dict.items()})
            loss.backward()
            optimizer.step()

    avg_reward, completion_id_list, max_coverage, final_coverage = eval_on_env(
        cfg,
        num_evals=cfg.num_final_evals,
        num_eval_per_goal=cfg.num_final_eval_per_goal,
        videorecorder=video,
        epoch=cfg.epochs,
    )
    reward_history.append(avg_reward)
    if cfg.env.gym.id in ["pusht", "blockpush"]:
        metric_final = "final coverage" if cfg.env.gym.id == "pusht" else "entered"
        metric_max = "max coverage" if cfg.env.gym.id == "pusht" else "moved"
        metrics = {
            f"{metric_final} mean": sum(final_coverage) / len(final_coverage),
            f"{metric_final} max": max(final_coverage),
            f"{metric_final} min": min(final_coverage),
            f"{metric_max} mean": sum(max_coverage) / len(max_coverage),
            f"{metric_max} max": max(max_coverage),
            f"{metric_max} min": min(max_coverage),
        }
        wandb.log(metrics)
        metrics_history.append(metrics)
    with open("{}/completion_idx_final.json".format(save_path), "wb") as fp:
        pickle.dump(completion_id_list, fp)
    if cfg.env.gym.id == "pusht":
        final_eval_on_env = max([x["final coverage mean"] for x in metrics_history])
    elif cfg.env.gym.id == "blockpush":
        final_eval_on_env = max([x["entered mean"] for x in metrics_history])
    elif cfg.env.gym.id == "libero_goal":
        final_eval_on_env = max(reward_history)
    elif cfg.env.gym.id == "kitchen-v0":
        final_eval_on_env = avg_reward
    wandb.log({"final_eval_on_env": final_eval_on_env})
    return final_eval_on_env


if __name__ == "__main__":
    main()