|
30 | 30 | from mlagents_envs.side_channel.side_channel import SideChannel |
31 | 31 | from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig |
32 | 32 | from mlagents_envs.exception import UnityEnvironmentException |
| 33 | +from mlagents_envs.timers import hierarchical_timer |
33 | 34 | from mlagents.logging_util import create_logger |
34 | 35 |
|
35 | 36 |
|
@@ -248,76 +249,81 @@ def run_training(run_seed: int, options: RunOptions) -> None: |
248 | 249 | :param run_seed: Random seed used for training. |
249 | 250 | :param run_options: Command line arguments for training. |
250 | 251 | """ |
251 | | - # Recognize and use docker volume if one is passed as an argument |
252 | | - if not options.docker_target_name: |
253 | | - model_path = f"./models/{options.run_id}" |
254 | | - summaries_dir = "./summaries" |
255 | | - else: |
256 | | - model_path = f"/{options.docker_target_name}/models/{options.run_id}" |
257 | | - summaries_dir = f"/{options.docker_target_name}/summaries" |
258 | | - port = options.base_port |
259 | | - |
260 | | - # Configure CSV, Tensorboard Writers and StatsReporter |
261 | | - # We assume reward and episode length are needed in the CSV. |
262 | | - csv_writer = CSVWriter( |
263 | | - summaries_dir, |
264 | | - required_fields=["Environment/Cumulative Reward", "Environment/Episode Length"], |
265 | | - ) |
266 | | - tb_writer = TensorboardWriter(summaries_dir) |
267 | | - gauge_write = GaugeWriter() |
268 | | - StatsReporter.add_writer(tb_writer) |
269 | | - StatsReporter.add_writer(csv_writer) |
270 | | - StatsReporter.add_writer(gauge_write) |
271 | | - |
272 | | - if options.env_path is None: |
273 | | - port = UnityEnvironment.DEFAULT_EDITOR_PORT |
274 | | - env_factory = create_environment_factory( |
275 | | - options.env_path, |
276 | | - options.docker_target_name, |
277 | | - options.no_graphics, |
278 | | - run_seed, |
279 | | - port, |
280 | | - options.env_args, |
281 | | - ) |
282 | | - engine_config = EngineConfig( |
283 | | - options.width, |
284 | | - options.height, |
285 | | - options.quality_level, |
286 | | - options.time_scale, |
287 | | - options.target_frame_rate, |
288 | | - ) |
289 | | - env_manager = SubprocessEnvManager(env_factory, engine_config, options.num_envs) |
290 | | - maybe_meta_curriculum = try_create_meta_curriculum( |
291 | | - options.curriculum_config, env_manager, options.lesson |
292 | | - ) |
293 | | - sampler_manager, resampling_interval = create_sampler_manager( |
294 | | - options.sampler_config, run_seed |
295 | | - ) |
296 | | - trainer_factory = TrainerFactory( |
297 | | - options.trainer_config, |
298 | | - summaries_dir, |
299 | | - options.run_id, |
300 | | - model_path, |
301 | | - options.keep_checkpoints, |
302 | | - options.train_model, |
303 | | - options.load_model, |
304 | | - run_seed, |
305 | | - maybe_meta_curriculum, |
306 | | - options.multi_gpu, |
307 | | - ) |
308 | | - # Create controller and begin training. |
309 | | - tc = TrainerController( |
310 | | - trainer_factory, |
311 | | - model_path, |
312 | | - summaries_dir, |
313 | | - options.run_id, |
314 | | - options.save_freq, |
315 | | - maybe_meta_curriculum, |
316 | | - options.train_model, |
317 | | - run_seed, |
318 | | - sampler_manager, |
319 | | - resampling_interval, |
320 | | - ) |
| 252 | + with hierarchical_timer("run_training.setup"): |
| 253 | + # Recognize and use docker volume if one is passed as an argument |
| 254 | + if not options.docker_target_name: |
| 255 | + model_path = f"./models/{options.run_id}" |
| 256 | + summaries_dir = "./summaries" |
| 257 | + else: |
| 258 | + model_path = f"/{options.docker_target_name}/models/{options.run_id}" |
| 259 | + summaries_dir = f"/{options.docker_target_name}/summaries" |
| 260 | + port = options.base_port |
| 261 | + |
| 262 | + # Configure CSV, Tensorboard Writers and StatsReporter |
| 263 | + # We assume reward and episode length are needed in the CSV. |
| 264 | + csv_writer = CSVWriter( |
| 265 | + summaries_dir, |
| 266 | + required_fields=[ |
| 267 | + "Environment/Cumulative Reward", |
| 268 | + "Environment/Episode Length", |
| 269 | + ], |
| 270 | + ) |
| 271 | + tb_writer = TensorboardWriter(summaries_dir) |
| 272 | + gauge_write = GaugeWriter() |
| 273 | + StatsReporter.add_writer(tb_writer) |
| 274 | + StatsReporter.add_writer(csv_writer) |
| 275 | + StatsReporter.add_writer(gauge_write) |
| 276 | + |
| 277 | + if options.env_path is None: |
| 278 | + port = UnityEnvironment.DEFAULT_EDITOR_PORT |
| 279 | + env_factory = create_environment_factory( |
| 280 | + options.env_path, |
| 281 | + options.docker_target_name, |
| 282 | + options.no_graphics, |
| 283 | + run_seed, |
| 284 | + port, |
| 285 | + options.env_args, |
| 286 | + ) |
| 287 | + engine_config = EngineConfig( |
| 288 | + options.width, |
| 289 | + options.height, |
| 290 | + options.quality_level, |
| 291 | + options.time_scale, |
| 292 | + options.target_frame_rate, |
| 293 | + ) |
| 294 | + env_manager = SubprocessEnvManager(env_factory, engine_config, options.num_envs) |
| 295 | + maybe_meta_curriculum = try_create_meta_curriculum( |
| 296 | + options.curriculum_config, env_manager, options.lesson |
| 297 | + ) |
| 298 | + sampler_manager, resampling_interval = create_sampler_manager( |
| 299 | + options.sampler_config, run_seed |
| 300 | + ) |
| 301 | + trainer_factory = TrainerFactory( |
| 302 | + options.trainer_config, |
| 303 | + summaries_dir, |
| 304 | + options.run_id, |
| 305 | + model_path, |
| 306 | + options.keep_checkpoints, |
| 307 | + options.train_model, |
| 308 | + options.load_model, |
| 309 | + run_seed, |
| 310 | + maybe_meta_curriculum, |
| 311 | + options.multi_gpu, |
| 312 | + ) |
| 313 | + # Create controller and begin training. |
| 314 | + tc = TrainerController( |
| 315 | + trainer_factory, |
| 316 | + model_path, |
| 317 | + summaries_dir, |
| 318 | + options.run_id, |
| 319 | + options.save_freq, |
| 320 | + maybe_meta_curriculum, |
| 321 | + options.train_model, |
| 322 | + run_seed, |
| 323 | + sampler_manager, |
| 324 | + resampling_interval, |
| 325 | + ) |
| 326 | + |
321 | 327 | # Begin training |
322 | 328 | try: |
323 | 329 | tc.start_learning(env_manager) |
|
0 commit comments