
Commit 362eec2

Authored by robTheBuildr, pre-commit-ci[bot], and tchaton
Better support for streaming optimized dataset (#727)
* Fix: Force delete prior to force download
* fix
* update
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update
* update
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update
* update
* update
* update
* update
* update
* update
* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
1 parent df92bf8 commit 362eec2

File tree: 16 files changed (+276 / -468 lines)


README.md

Lines changed: 0 additions & 18 deletions
@@ -341,21 +341,11 @@ storage_options = {
 
 dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
 
-# s5cmd compatible storage options for a custom S3-compatible endpoint
-# Note: If s5cmd is installed, it will be used by default for S3 operations. If you prefer not to use s5cmd, you can disable it by setting the environment variable: `DISABLE_S5CMD=1`
-storage_options = {
-    "AWS_ACCESS_KEY_ID": "your_access_key_id",
-    "AWS_SECRET_ACCESS_KEY": "your_secret_access_key",
-    "S3_ENDPOINT_URL": "your_endpoint_url",  # Required only for custom endpoints
-}
 
 
 dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
 ```
 
-Alternative: Using `s5cmd` for S3 Operations
-
-
 Also, you can specify a custom cache directory when initializing your dataset. This is useful when you want to store the cache in a specific location.
 ```python
 from litdata import StreamingDataset
@@ -543,21 +533,13 @@ aws_storage_options={
 }
 dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
 
-
-# Read data from AWS S3 using s5cmd
-# Note: If s5cmd is installed, it will be used by default for S3 operations. If you prefer not to use s5cmd, you can disable it by setting the environment variable: `DISABLE_S5CMD=1`
 aws_storage_options={
     "AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'],
     "AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'],
     "S3_ENDPOINT_URL": os.environ['AWS_ENDPOINT_URL'],  # Required only for custom endpoints
 }
 dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
 
-# Read Data from AWS S3 with Unsigned Request using s5cmd
-aws_storage_options={
-    "AWS_NO_SIGN_REQUEST": "Yes" # Required for unsigned requests
-    "S3_ENDPOINT_URL": os.environ['AWS_ENDPOINT_URL'],  # Required only for custom endpoints
-}
 dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)
 
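With the s5cmd-specific examples removed, the README's remaining recipe is the plain `storage_options` path shown in the surrounding context lines. A minimal sketch of that usage (the bucket URI, credentials, and endpoint are placeholders):

```python
import os

from litdata import StreamingDataset

# Placeholder credentials and endpoint; supply your own values.
storage_options = {
    "AWS_ACCESS_KEY_ID": os.environ["AWS_ACCESS_KEY_ID"],
    "AWS_SECRET_ACCESS_KEY": os.environ["AWS_SECRET_ACCESS_KEY"],
    "S3_ENDPOINT_URL": os.environ["AWS_ENDPOINT_URL"],  # required only for custom endpoints
}

dataset = StreamingDataset("s3://my-bucket/my-data", storage_options=storage_options)
sample = dataset[0]  # items stream from object storage on access
```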

pyproject.toml

Lines changed: 17 additions & 13 deletions
@@ -25,6 +25,7 @@ line-length = 120
 exclude = [
     ".git",
     "docs",
+    "src/litdata/debugger.py",
     "src/litdata/utilities/_pytree.py",
 ]
 # Enable Pyflakes `E` and `F` codes by default.
@@ -65,20 +66,22 @@ lint.per-file-ignores."examples/**" = [
 ]
 lint.per-file-ignores."setup.py" = [ "D100", "SIM115" ]
 lint.per-file-ignores."src/**" = [
-    "D100", # Missing docstring in public module
-    "D101", # todo: Missing docstring in public class
-    "D102", # todo: Missing docstring in public method
-    "D103", # todo: Missing docstring in public function
-    "D104", # Missing docstring in public package
-    "D105", # todo: Missing docstring in magic method
-    "D107", # todo: Missing docstring in __init__
-    "D205", # todo: 1 blank line required between summary line and description
+    "D100", # Missing docstring in public module
+    "D101", # todo: Missing docstring in public class
+    "D102", # todo: Missing docstring in public method
+    "D103", # todo: Missing docstring in public function
+    "D104", # Missing docstring in public package
+    "D105", # todo: Missing docstring in magic method
+    "D107", # todo: Missing docstring in __init__
+    "D205", # todo: 1 blank line required between summary line and description
     "D401",
-    "D404", # todo: First line should be in imperative mood; try rephrasing
-    "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected.
-    "S602", # todo: `subprocess` call with `shell=True` identified, security issue
-    "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
-    "S607", # todo: Starting a process with a partial executable path
+    "D404", # todo: First line should be in imperative mood; try rephrasing
+    "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected.
+    "S602", # todo: `subprocess` call with `shell=True` identified, security issue
+    "S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
+    "S607", # todo: Starting a process with a partial executable path
+    "UP006", # UP006 Use `list` instead of `List` for type annotation
+    "UP035", # UP035 `typing.Tuple` is deprecated, use `tuple` instead
 ]
 lint.per-file-ignores."tests/**" = [
     "D100",
@@ -166,6 +169,7 @@ exclude = [
     "src/litdata/imports.py",
     "src/litdata/imports.py",
     "src/litdata/processing/data_processor.py",
+    "src/litdata/debugger.py",
 ]
 install_types = "True"
 non_interactive = "True"
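For context, the newly ignored UP006/UP035 rules target `typing.List`/`typing.Tuple` style annotations. A small, hypothetical snippet of the pattern they would otherwise flag under `src/**` (the function itself is illustrative, not from the codebase):

```python
from typing import List, Tuple  # UP035: importing List/Tuple from typing is deprecated


def chunk_bounds(pairs: List[Tuple[int, int]]) -> List[int]:  # UP006: prefer list[...] / tuple[...]
    # On Python 3.9+ Ruff would rewrite these to list[tuple[int, int]] and list[int];
    # ignoring UP006/UP035 under src/** keeps the older typing spellings allowed.
    return [end - start for start, end in pairs]
```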

requirements/test.txt

Lines changed: 0 additions & 1 deletion
@@ -16,5 +16,4 @@ polars >1.0.0
 lightning
 transformers <4.53.0
 zstd
-s5cmd >=0.2.0
 soundfile >=0.13.0 # required for torchaudio backend

src/litdata/constants.py

Lines changed: 0 additions & 1 deletion
@@ -55,7 +55,6 @@
 
 _MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
 _FORCE_DOWNLOAD_TIME = int(os.getenv("FORCE_DOWNLOAD_TIME", "30"))
-_DISABLE_S5CMD = bool(int(os.getenv("DISABLE_S5CMD", "0")))
 
 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
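The neighbouring constants are still read from environment variables when `litdata.constants` is imported, so the timeouts can be tuned without code changes. A minimal sketch, assuming the variables are set before the first litdata import (the override values are illustrative):

```python
import os

# Must be set before litdata is imported, because the constants are read at import time.
os.environ["MAX_WAIT_TIME"] = "60"        # default: 120 (seconds)
os.environ["FORCE_DOWNLOAD_TIME"] = "15"  # default: 30 (seconds)

from litdata.constants import _FORCE_DOWNLOAD_TIME, _MAX_WAIT_TIME

print(_MAX_WAIT_TIME, _FORCE_DOWNLOAD_TIME)  # 60 15
```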

src/litdata/debugger.py

Lines changed: 114 additions & 54 deletions
@@ -1,6 +1,6 @@
 # Copyright The Lightning AI team.
 # Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
+# You may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
@@ -13,92 +13,148 @@
 
 import logging
 import os
-import sys
+import re
+import threading
+import time
 from functools import lru_cache
 
-from litdata.constants import _PRINT_DEBUG_LOGS
-from litdata.utilities.env import _DistributedEnv, _WorkerEnv
+from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
 
-# Create the root logger for the library
-root_logger = logging.getLogger("litdata")
+
+class TimedFlushFileHandler(logging.FileHandler):
+    """FileHandler that flushes every N seconds in a background thread."""
+
+    def __init__(self, filename, mode="a", flush_interval=2):
+        super().__init__(filename, mode)
+        self.flush_interval = flush_interval
+        self._stop_event = threading.Event()
+        t = threading.Thread(target=self._flusher, daemon=True, name="TimedFlushFileHandler._flusher")
+        t.start()
+
+    def _flusher(self):
+        while not self._stop_event.is_set():
+            time.sleep(self.flush_interval)
+            self.flush()
+
+    def close(self):
+        self._stop_event.set()
+        self.flush()
+        super().close()
+
+
+class EnvConfigFilter(logging.Filter):
+    """A logging filter that reads its configuration from environment variables."""
+
+    def __init__(self):
+        super().__init__()
+        self.name_re = re.compile(r"name:\s*([^;]+);")
+
+    def _get_name_from_msg(self, msg):
+        match = self.name_re.search(msg)
+        return match.group(1).strip() if match else None
+
+    def filter(self, record):
+        """Determine if a log record should be processed by checking env vars."""
+        is_iterating_dataset_enabled = os.getenv("LITDATA_LOG_ITERATING_DATASET", "True").lower() == "true"
+        is_getitem_enabled = os.getenv("LITDATA_LOG_GETITEM", "True").lower() == "true"
+        is_item_loader_enabled = os.getenv("LITDATA_LOG_ITEM_LOADER", "True").lower() == "true"
+
+        log_name = self._get_name_from_msg(record.getMessage())
+
+        if log_name:
+            if not is_iterating_dataset_enabled and log_name.startswith("iterating_dataset"):
+                return False
+            if not is_getitem_enabled and log_name.startswith("getitem_dataset_for_chunk_index"):
+                return False
+            if not is_item_loader_enabled and log_name.startswith("item_loader"):
+                return False
+
+        return True
 
 
 def get_logger_level(level: str) -> int:
-    """Get the log level from the level string."""
     level = level.upper()
     if level in logging._nameToLevel:
         return logging._nameToLevel[level]
-    raise ValueError(f"Invalid log level: {level}. Valid levels: {list(logging._nameToLevel.keys())}.")
+    raise ValueError(f"Invalid log level: {level}")
 
 
 class LitDataLogger:
-    def __init__(self, name: str):
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self, name="litdata", flush_interval=2):
+        if hasattr(self, "logger"):
+            return  # Already initialized
+
         self.logger = logging.getLogger(name)
+        self.logger.propagate = False
         self.log_file, self.log_level = self.get_log_file_and_level()
-        self.setup_logger()
+        self.flush_interval = flush_interval
+        self._setup_logger()
 
     @staticmethod
-    def get_log_file_and_level() -> tuple[str, int]:
+    def get_log_file_and_level():
         log_file = os.getenv("LITDATA_LOG_FILE", "litdata_debug.log")
         log_lvl = os.getenv("LITDATA_LOG_LEVEL", "DEBUG")
+        return log_file, get_logger_level(log_lvl)
 
-        log_lvl = get_logger_level(log_lvl)
-
-        return log_file, log_lvl
-
-    def setup_logger(self) -> None:
-        """Configures logging by adding handlers and formatting."""
-        if len(self.logger.handlers) > 0:  # Avoid duplicate handlers
+    def _setup_logger(self):
+        if self.logger.handlers:
             return
-
         self.logger.setLevel(self.log_level)
+        formatter = logging.Formatter("ts:%(created)s;PID:%(process)d; TID:%(thread)d; %(message)s")
+        handler = TimedFlushFileHandler(self.log_file, flush_interval=self.flush_interval)
+        handler.setFormatter(formatter)
+        handler.setLevel(self.log_level)
+        self.logger.addHandler(handler)
 
-        # Console handler
-        console_handler = logging.StreamHandler(sys.stdout)
-        console_handler.setLevel(self.log_level)
-
-        # File handler
-        file_handler = logging.FileHandler(self.log_file)
-        file_handler.setLevel(self.log_level)
+        self.logger.filters = [f for f in self.logger.filters if not isinstance(f, EnvConfigFilter)]
+        self.logger.addFilter(EnvConfigFilter())
 
-        # Log format
-        formatter = logging.Formatter(
-            "ts:%(created)s; logger_name:%(name)s; level:%(levelname)s; PID:%(process)d; TID:%(thread)d; %(message)s"
-        )
-        # ENV - f"{WORLD_SIZE, GLOBAL_RANK, NNODES, LOCAL_RANK, NODE_RANK}"
-        console_handler.setFormatter(formatter)
-        file_handler.setFormatter(formatter)
+    def get_logger(self):
+        return self.logger
 
-        # Attach handlers
-        if _PRINT_DEBUG_LOGS:
-            self.logger.addHandler(console_handler)
-        self.logger.addHandler(file_handler)
 
-
-def enable_tracer() -> None:
+def enable_tracer(
+    flush_interval: int = 5, item_loader=True, iterating_dataset=True, getitem_dataset_for_chunk_index=True
+) -> logging.Logger:
+    """Convenience function to enable and configure litdata logging.
+    This function SETS the environment variables that control the logging behavior.
+    """
    os.environ["LITDATA_LOG_FILE"] = "litdata_debug.log"
-    LitDataLogger("litdata")
+    os.environ["LITDATA_LOG_ITEM_LOADER"] = str(item_loader)
+    os.environ["LITDATA_LOG_ITERATING_DATASET"] = str(iterating_dataset)
+    os.environ["LITDATA_LOG_GETITEM"] = str(getitem_dataset_for_chunk_index)
+
+    master_logger = LitDataLogger(flush_interval=flush_interval).get_logger()
+    return master_logger
 
 
 def _get_log_msg(data: dict) -> str:
     log_msg = ""
-
     if "name" not in data or "ph" not in data:
         raise ValueError(f"Missing required keys in data dictionary. Required keys: 'name', 'ph'. Received: {data}")
-
     env_info_data = env_info()
     data.update(env_info_data)
-
     for key, value in data.items():
         log_msg += f"{key}: {value};"
     return log_msg
 
 
-@lru_cache(maxsize=1)
 def env_info() -> dict:
-    dist_env = _DistributedEnv.detect()
-    worker_env = _WorkerEnv.detect()  # will all threads read the same value if decorate this function with `@cache`
+    if _is_in_dataloader_worker():
+        return _cached_env_info()
 
+    dist_env = _DistributedEnv.detect()
+    worker_env = _WorkerEnv.detect()
     return {
         "dist_world_size": dist_env.world_size,
         "dist_global_rank": dist_env.global_rank,
@@ -108,16 +164,20 @@ def env_info() -> dict:
     }
 
 
-# -> Chrome tracing colors
-# url: https://chromium.googlesource.com/external/trace-viewer/+/bf55211014397cf0ebcd9e7090de1c4f84fc3ac0/tracing/tracing/ui/base/color_scheme.html
-
-# # ------
+@lru_cache(maxsize=1)
+def _cached_env_info() -> dict:
+    dist_env = _DistributedEnv.detect()
+    worker_env = _WorkerEnv.detect()
+    return {
+        "dist_world_size": dist_env.world_size,
+        "dist_global_rank": dist_env.global_rank,
+        "dist_num_nodes": dist_env.num_nodes,
+        "worker_world_size": worker_env.world_size,
+        "worker_rank": worker_env.rank,
+    }
 
 
-# thread_state_iowait: {r: 182, g: 125, b: 143},
-# thread_state_running: {r: 126, g: 200, b: 148},
-# thread_state_runnable: {r: 133, g: 160, b: 210},
-# ....
+# Chrome trace colors
 class ChromeTraceColors:
     PINK = "thread_state_iowait"
     GREEN = "thread_state_running"
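Based on the new `enable_tracer` signature above, a rough sketch of how the tracer might be switched on and selectively filtered; the dataset URI is a placeholder and the loop exists only to generate a few events:

```python
from litdata import StreamingDataset
from litdata.debugger import enable_tracer

# Sets LITDATA_LOG_FILE, LITDATA_LOG_ITEM_LOADER, LITDATA_LOG_ITERATING_DATASET and
# LITDATA_LOG_GETITEM, then returns the singleton "litdata" logger.
logger = enable_tracer(flush_interval=2, item_loader=False)  # suppress item_loader events

dataset = StreamingDataset("s3://my-bucket/my-data")  # placeholder URI
for _ in dataset:
    break  # records are appended to litdata_debug.log and flushed every `flush_interval` seconds
```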

src/litdata/streaming/compression.py

Lines changed: 2 additions & 2 deletions
@@ -62,9 +62,9 @@ def compress(self, data: bytes) -> bytes:
     def decompress(self, data: bytes) -> bytes:
         import zstd
 
-        logger.debug(_get_log_msg({"name": "Decompressing data", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
+        logger.debug(_get_log_msg({"name": "decompress", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
         decompressed_data = zstd.decompress(data)
-        logger.debug(_get_log_msg({"name": "Decompressed data", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
+        logger.debug(_get_log_msg({"name": "decompress", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
         return decompressed_data
 
     @classmethod
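The renamed events keep an identical name across the "B" (begin) and "E" (end) phases so a trace viewer can pair the two records into a single duration span. A hedged sketch of the same pattern wrapped around another operation; `traced_compress` is a hypothetical helper, not part of the library:

```python
import logging

from litdata.debugger import ChromeTraceColors, _get_log_msg

logger = logging.getLogger("litdata")


def traced_compress(codec, data: bytes) -> bytes:
    # Hypothetical helper: emit a begin record, do the work, then an end record with
    # the *same* name so the two "ph" phases can be joined into one duration event.
    logger.debug(_get_log_msg({"name": "compress", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
    out = codec.compress(data)
    logger.debug(_get_log_msg({"name": "compress", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
    return out
```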
