11# Copyright The Lightning AI team.
22# Licensed under the Apache License, Version 2.0 (the "License");
3- # you may not use this file except in compliance with the License.
3+ # You may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
55#
66# http://www.apache.org/licenses/LICENSE-2.0
1313
1414import logging
1515import os
16- import sys
16+ import re
17+ import threading
18+ import time
1719from functools import lru_cache
1820
19- from litdata .constants import _PRINT_DEBUG_LOGS
20- from litdata .utilities .env import _DistributedEnv , _WorkerEnv
21+ from litdata .utilities .env import _DistributedEnv , _is_in_dataloader_worker , _WorkerEnv
2122
22- # Create the root logger for the library
23- root_logger = logging .getLogger ("litdata" )
23+
24+ class TimedFlushFileHandler (logging .FileHandler ):
25+ """FileHandler that flushes every N seconds in a background thread."""
26+
27+ def __init__ (self , filename , mode = "a" , flush_interval = 2 ):
28+ super ().__init__ (filename , mode )
29+ self .flush_interval = flush_interval
30+ self ._stop_event = threading .Event ()
31+ t = threading .Thread (target = self ._flusher , daemon = True , name = "TimedFlushFileHandler._flusher" )
32+ t .start ()
33+
34+ def _flusher (self ):
35+ while not self ._stop_event .is_set ():
36+ time .sleep (self .flush_interval )
37+ self .flush ()
38+
39+ def close (self ):
40+ self ._stop_event .set ()
41+ self .flush ()
42+ super ().close ()
43+
44+
45+ class EnvConfigFilter (logging .Filter ):
46+ """A logging filter that reads its configuration from environment variables."""
47+
48+ def __init__ (self ):
49+ super ().__init__ ()
50+ self .name_re = re .compile (r"name:\s*([^;]+);" )
51+
52+ def _get_name_from_msg (self , msg ):
53+ match = self .name_re .search (msg )
54+ return match .group (1 ).strip () if match else None
55+
56+ def filter (self , record ):
57+ """Determine if a log record should be processed by checking env vars."""
58+ is_iterating_dataset_enabled = os .getenv ("LITDATA_LOG_ITERATING_DATASET" , "True" ).lower () == "true"
59+ is_getitem_enabled = os .getenv ("LITDATA_LOG_GETITEM" , "True" ).lower () == "true"
60+ is_item_loader_enabled = os .getenv ("LITDATA_LOG_ITEM_LOADER" , "True" ).lower () == "true"
61+
62+ log_name = self ._get_name_from_msg (record .getMessage ())
63+
64+ if log_name :
65+ if not is_iterating_dataset_enabled and log_name .startswith ("iterating_dataset" ):
66+ return False
67+ if not is_getitem_enabled and log_name .startswith ("getitem_dataset_for_chunk_index" ):
68+ return False
69+ if not is_item_loader_enabled and log_name .startswith ("item_loader" ):
70+ return False
71+
72+ return True
2473
2574
2675def get_logger_level (level : str ) -> int :
27- """Get the log level from the level string."""
2876 level = level .upper ()
2977 if level in logging ._nameToLevel :
3078 return logging ._nameToLevel [level ]
31- raise ValueError (f"Invalid log level: { level } . Valid levels: { list ( logging . _nameToLevel . keys ()) } . " )
79+ raise ValueError (f"Invalid log level: { level } " )
3280
3381
3482class LitDataLogger :
35- def __init__ (self , name : str ):
83+ _instance = None
84+ _lock = threading .Lock ()
85+
86+ def __new__ (cls , * args , ** kwargs ):
87+ if cls ._instance is None :
88+ with cls ._lock :
89+ if cls ._instance is None :
90+ cls ._instance = super ().__new__ (cls )
91+ return cls ._instance
92+
93+ def __init__ (self , name = "litdata" , flush_interval = 2 ):
94+ if hasattr (self , "logger" ):
95+ return # Already initialized
96+
3697 self .logger = logging .getLogger (name )
98+ self .logger .propagate = False
3799 self .log_file , self .log_level = self .get_log_file_and_level ()
38- self .setup_logger ()
100+ self .flush_interval = flush_interval
101+ self ._setup_logger ()
39102
40103 @staticmethod
41- def get_log_file_and_level () -> tuple [ str , int ] :
104+ def get_log_file_and_level ():
42105 log_file = os .getenv ("LITDATA_LOG_FILE" , "litdata_debug.log" )
43106 log_lvl = os .getenv ("LITDATA_LOG_LEVEL" , "DEBUG" )
107+ return log_file , get_logger_level (log_lvl )
44108
45- log_lvl = get_logger_level (log_lvl )
46-
47- return log_file , log_lvl
48-
49- def setup_logger (self ) -> None :
50- """Configures logging by adding handlers and formatting."""
51- if len (self .logger .handlers ) > 0 : # Avoid duplicate handlers
109+ def _setup_logger (self ):
110+ if self .logger .handlers :
52111 return
53-
54112 self .logger .setLevel (self .log_level )
113+ formatter = logging .Formatter ("ts:%(created)s;PID:%(process)d; TID:%(thread)d; %(message)s" )
114+ handler = TimedFlushFileHandler (self .log_file , flush_interval = self .flush_interval )
115+ handler .setFormatter (formatter )
116+ handler .setLevel (self .log_level )
117+ self .logger .addHandler (handler )
55118
56- # Console handler
57- console_handler = logging .StreamHandler (sys .stdout )
58- console_handler .setLevel (self .log_level )
59-
60- # File handler
61- file_handler = logging .FileHandler (self .log_file )
62- file_handler .setLevel (self .log_level )
119+ self .logger .filters = [f for f in self .logger .filters if not isinstance (f , EnvConfigFilter )]
120+ self .logger .addFilter (EnvConfigFilter ())
63121
64- # Log format
65- formatter = logging .Formatter (
66- "ts:%(created)s; logger_name:%(name)s; level:%(levelname)s; PID:%(process)d; TID:%(thread)d; %(message)s"
67- )
68- # ENV - f"{WORLD_SIZE, GLOBAL_RANK, NNODES, LOCAL_RANK, NODE_RANK}"
69- console_handler .setFormatter (formatter )
70- file_handler .setFormatter (formatter )
122+ def get_logger (self ):
123+ return self .logger
71124
72- # Attach handlers
73- if _PRINT_DEBUG_LOGS :
74- self .logger .addHandler (console_handler )
75- self .logger .addHandler (file_handler )
76125
77-
78- def enable_tracer () -> None :
126+ def enable_tracer (
127+ flush_interval : int = 5 , item_loader = True , iterating_dataset = True , getitem_dataset_for_chunk_index = True
128+ ) -> logging .Logger :
129+ """Convenience function to enable and configure litdata logging.
130+ This function SETS the environment variables that control the logging behavior.
131+ """
79132 os .environ ["LITDATA_LOG_FILE" ] = "litdata_debug.log"
80- LitDataLogger ("litdata" )
133+ os .environ ["LITDATA_LOG_ITEM_LOADER" ] = str (item_loader )
134+ os .environ ["LITDATA_LOG_ITERATING_DATASET" ] = str (iterating_dataset )
135+ os .environ ["LITDATA_LOG_GETITEM" ] = str (getitem_dataset_for_chunk_index )
136+
137+ master_logger = LitDataLogger (flush_interval = flush_interval ).get_logger ()
138+ return master_logger
81139
82140
83141def _get_log_msg (data : dict ) -> str :
84142 log_msg = ""
85-
86143 if "name" not in data or "ph" not in data :
87144 raise ValueError (f"Missing required keys in data dictionary. Required keys: 'name', 'ph'. Received: { data } " )
88-
89145 env_info_data = env_info ()
90146 data .update (env_info_data )
91-
92147 for key , value in data .items ():
93148 log_msg += f"{ key } : { value } ;"
94149 return log_msg
95150
96151
97- @lru_cache (maxsize = 1 )
98152def env_info () -> dict :
99- dist_env = _DistributedEnv . detect ()
100- worker_env = _WorkerEnv . detect () # will all threads read the same value if decorate this function with `@cache`
153+ if _is_in_dataloader_worker ():
154+ return _cached_env_info ()
101155
156+ dist_env = _DistributedEnv .detect ()
157+ worker_env = _WorkerEnv .detect ()
102158 return {
103159 "dist_world_size" : dist_env .world_size ,
104160 "dist_global_rank" : dist_env .global_rank ,
@@ -108,16 +164,20 @@ def env_info() -> dict:
108164 }
109165
110166
111- # -> Chrome tracing colors
112- # url: https://chromium.googlesource.com/external/trace-viewer/+/bf55211014397cf0ebcd9e7090de1c4f84fc3ac0/tracing/tracing/ui/base/color_scheme.html
113-
114- # # ------
167+ @lru_cache (maxsize = 1 )
168+ def _cached_env_info () -> dict :
169+ dist_env = _DistributedEnv .detect ()
170+ worker_env = _WorkerEnv .detect ()
171+ return {
172+ "dist_world_size" : dist_env .world_size ,
173+ "dist_global_rank" : dist_env .global_rank ,
174+ "dist_num_nodes" : dist_env .num_nodes ,
175+ "worker_world_size" : worker_env .world_size ,
176+ "worker_rank" : worker_env .rank ,
177+ }
115178
116179
117- # thread_state_iowait: {r: 182, g: 125, b: 143},
118- # thread_state_running: {r: 126, g: 200, b: 148},
119- # thread_state_runnable: {r: 133, g: 160, b: 210},
120- # ....
180+ # Chrome trace colors
121181class ChromeTraceColors :
122182 PINK = "thread_state_iowait"
123183 GREEN = "thread_state_running"
0 commit comments