diff --git a/internnav/agent/internvla_n1_agent.py b/internnav/agent/internvla_n1_agent.py index b96ef7c..dde234b 100644 --- a/internnav/agent/internvla_n1_agent.py +++ b/internnav/agent/internvla_n1_agent.py @@ -226,15 +226,10 @@ def should_infer_s1(self, mode="sync"): raise ValueError("Invalid mode: {}".format(mode)) - def step(self, step_request: StepRequest): - def transfer(obs): - obs = base64.b64decode(obs) - obs = pickle.loads(obs) - return obs - + def step(self, obs): mode = 'sync' # 'sync', 'partial_async', 'full_async' - obs = transfer(step_request)[0] + obs = obs[0] # do not support batch_env currently? rgb = obs['rgb'] depth = obs['depth'] instruction = obs['instruction'] diff --git a/internnav/evaluator/vln_multi_evaluator.py b/internnav/evaluator/vln_multi_evaluator.py index d0dd09c..be46931 100644 --- a/internnav/evaluator/vln_multi_evaluator.py +++ b/internnav/evaluator/vln_multi_evaluator.py @@ -1,16 +1,14 @@ from enum import Enum from pathlib import Path from time import time - import numpy as np - from internnav.configs.evaluator import EvalCfg from internnav.evaluator.base import Evaluator from internnav.evaluator.utils.common import set_seed_model from internnav.evaluator.utils.config import get_lmdb_path from internnav.evaluator.utils.data_collector import DataCollector from internnav.evaluator.utils.dataset import ResultLogger, split_data -from internnav.evaluator.utils.eval import generate_episode, serialize_obs +from internnav.evaluator.utils.eval import generate_episode from internnav.projects.dataloader.resumable import ResumablePathKeyDataloader from internnav.utils import common_log_util, progress_log_multi_util from internnav.utils.common_log_util import common_logger as log @@ -51,19 +49,17 @@ class VlnMultiEvaluator(Evaluator): def __init__(self, config: EvalCfg): self.task_name = config.task.task_name if not Path(get_lmdb_path(self.task_name)).exists(): - split_data(config.dataset.dataset_settings) - self.result_logger = ResultLogger(config.dataset.dataset_settings) + split_data(config.dataset) + self.result_logger = ResultLogger(config.dataset) common_log_util.init(self.task_name) - self.dataloader = ResumablePathKeyDataloader(**config.dataset.dataset_settings) + self.dataloader = ResumablePathKeyDataloader(config.dataset.dataset_type, **config.dataset.dataset_settings) self.dataset_name = Path(config.dataset.dataset_settings['base_data_dir']).name progress_log_multi_util.init(self.task_name, self.dataloader.size) self.total_path_num = self.dataloader.size progress_log_multi_util.progress_logger_multi.info( f'start eval dataset: {self.task_name}, total_path:{self.dataloader.size}' # noqa: E501 ) - - self.robot_flash = config.task.robot_flash - + # generate episode episodes = generate_episode(self.dataloader, config) config.task.task_settings.update({'episodes': episodes}) self.env_num = config.task.task_settings['env_num'] @@ -90,6 +86,8 @@ def __init__(self, config: EvalCfg): super().__init__(config) set_seed_model(0) self.data_collector = DataCollector(self.dataloader.lmdb_path) + self.robot_flash = config.task.robot_flash + @property def ignore_obs_attr(self): @@ -135,9 +133,8 @@ def get_action(self, obs, action): ) obs[fake_obs_index] = self.fake_obs obs = self.remove_obs_attr(obs) - obs_trans = serialize_obs(obs) if not np.logical_and.reduce(self.runner_status == runner_status_code.WARM_UP): - action = self.agent.step(obs_trans) + action = self.agent.step(obs) log.info(f'now action:{len(action)} ,{action}, fake_obs_index:{fake_obs_index}') action = transform_action_batch(action, self.robot_flash) # change warm_up @@ -274,7 +271,6 @@ def eval(self): while self.env.is_running(): obs, action = self.get_action(obs, action) - print(f"step action: {action}") obs, terminated = self.env_step(action) env_term, reset_info = self.terminate_ops(obs, reset_info, terminated) if env_term: diff --git a/requirements/isaac_requirements.txt b/requirements/isaac_requirements.txt index c7580da..45f9590 100644 --- a/requirements/isaac_requirements.txt +++ b/requirements/isaac_requirements.txt @@ -176,4 +176,5 @@ websockets==12.0 wrapt==1.16.0 yarl==1.9.4 zipp==3.23.0 +gradio -e git+https://github.com/real-stanford/diffusion_policy.git@5ba07ac6661db573af695b419a7947ecb704690f#egg=diffusion_policy