diff --git a/ml-agents/mlagents/plugins/stats_writer.py b/ml-agents/mlagents/plugins/stats_writer.py index ddacff5960..17acefd32e 100644 --- a/ml-agents/mlagents/plugins/stats_writer.py +++ b/ml-agents/mlagents/plugins/stats_writer.py @@ -31,6 +31,7 @@ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: TensorboardWriter( checkpoint_settings.write_path, clear_past_data=not checkpoint_settings.resume, + hidden_keys=["Is Training", "Step"], ), GaugeWriter(), ConsoleWriter(), diff --git a/ml-agents/mlagents/trainers/stats.py b/ml-agents/mlagents/trainers/stats.py index b27c35e633..78dc33893e 100644 --- a/ml-agents/mlagents/trainers/stats.py +++ b/ml-agents/mlagents/trainers/stats.py @@ -1,6 +1,6 @@ from collections import defaultdict from enum import Enum -from typing import List, Dict, NamedTuple, Any +from typing import List, Dict, NamedTuple, Any, Optional import numpy as np import abc import os @@ -14,7 +14,6 @@ from torch.utils.tensorboard import SummaryWriter from mlagents.torch_utils.globals import get_rank - logger = get_logger(__name__) @@ -212,7 +211,12 @@ def add_property( class TensorboardWriter(StatsWriter): - def __init__(self, base_dir: str, clear_past_data: bool = False): + def __init__( + self, + base_dir: str, + clear_past_data: bool = False, + hidden_keys: Optional[List[str]] = None, + ): """ A StatsWriter that writes to a Tensorboard summary. @@ -220,16 +224,21 @@ def __init__(self, base_dir: str, clear_past_data: bool = False): {base_dir}/{category} directory. :param clear_past_data: Whether or not to clean up existing Tensorboard files associated with the base_dir and category. + :param hidden_keys: If provided, Tensorboard Writer won't write statistics identified with these Keys in + Tensorboard summary. """ self.summary_writers: Dict[str, SummaryWriter] = {} self.base_dir: str = base_dir self._clear_past_data = clear_past_data + self.hidden_keys: List[str] = hidden_keys if hidden_keys is not None else [] def write_stats( self, category: str, values: Dict[str, StatsSummary], step: int ) -> None: self._maybe_create_summary_writer(category) for key, value in values.items(): + if key in self.hidden_keys: + continue self.summary_writers[category].add_scalar( f"{key}", value.aggregated_value, step ) diff --git a/ml-agents/mlagents/trainers/tests/test_stats.py b/ml-agents/mlagents/trainers/tests/test_stats.py index ae0698d32a..0d1dd1b19d 100644 --- a/ml-agents/mlagents/trainers/tests/test_stats.py +++ b/ml-agents/mlagents/trainers/tests/test_stats.py @@ -129,6 +129,31 @@ def test_tensorboard_writer_clear(tmp_path): assert len(os.listdir(os.path.join(tmp_path, "category1"))) == 1 +@mock.patch("mlagents.trainers.stats.SummaryWriter") +def test_tensorboard_writer_hidden_keys(mock_summary): + # Test write_stats + category = "category1" + with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir: + tb_writer = TensorboardWriter( + base_dir, clear_past_data=False, hidden_keys="hiddenKey" + ) + statssummary1 = StatsSummary( + full_dist=[1.0], aggregation_method=StatsAggregationMethod.AVERAGE + ) + tb_writer.write_stats("category1", {"hiddenKey": statssummary1}, 10) + + # Test that the filewriter has been created and the directory has been created. + filewriter_dir = "{basedir}/{category}".format( + basedir=base_dir, category=category + ) + assert os.path.exists(filewriter_dir) + mock_summary.assert_called_once_with(filewriter_dir) + + # Test that the filewriter was not written to since we used the hidden key. + mock_summary.return_value.add_scalar.assert_not_called() + mock_summary.return_value.flush.assert_not_called() + + def test_gauge_stat_writer_sanitize(): assert GaugeWriter.sanitize_string("Policy/Learning Rate") == "Policy.LearningRate" assert ( diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py index 78a9d7b130..59ed2bbedc 100644 --- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py +++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py @@ -211,6 +211,7 @@ def _increment_step(self, n_steps: int, name_behavior_id: str) -> None: p = self.get_policy(name_behavior_id) if p: p.increment_step(n_steps) + self.stats_reporter.set_stat("Step", float(self.get_step)) def _get_next_interval_step(self, interval: int) -> int: """