diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 3ae063d..4b95f93 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -8,14 +8,26 @@ on:
     branches:
     - main
 
+# These permissions are needed to interact with AWS S3 via GitHub's OIDC Token endpoint
+permissions:
+  id-token: write
+  contents: read
+  pull-requests: read
+
 jobs:
   unit-tests:
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: true
       matrix:
+        # The full test-suite is only run with os=ubuntu and py=3.12
         python-version: [
-          '3.8', '3.9', '3.10', '3.11', '3.12', '3.13'
+          '3.8',
+          '3.9',
+          '3.10',
+          '3.11',
+          '3.12',
+          '3.13'
         ]
         os: [ubuntu-latest]
         EXTRA: [false]  # used to force includes to get included
@@ -23,11 +35,12 @@ jobs:
          - python-version: '3.8'
            os: ubuntu-22.04  # oldest version on github actions
            EXTRA: true
-          - python-version: '3.13'
-            os: ubuntu-latest
-            env:
-              LOGLEVEL=DEBUG
-            EXTRA: true
+          # disabled (duplicated matrix entry)
+          # - python-version: '3.13'
+          #   os: ubuntu-latest
+          #   env:
+          #     LOGLEVEL=DEBUG
+          #   EXTRA: true
          - python-version: '3.13'
            os: macos-latest
            EXTRA: true
@@ -57,9 +70,43 @@ jobs:
     - name: Install cdx_toolkit
       run: pip install .[test]
 
+    - name: Configure AWS credentials from OIDC (disabled for forks)
+      if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'push'
+      uses: aws-actions/configure-aws-credentials@v4
+      with:
+        role-to-assume: arn:aws:iam::837454214164:role/GitHubActions-Role
+        aws-region: us-east-1
+
+    - name: Disable S3 unit tests for Python 3.8 (boto3 requires Python 3.9+)
+      if: ${{ startsWith(matrix.python-version, '3.8') }}
+      uses: actions/github-script@v7
+      with:
+        script: |
+          core.exportVariable('CDXT_DISABLE_S3_TESTS', '1')
+
+    - name: Set environment variables for faster unit tests (requests are mocked)
+      uses: actions/github-script@v7
+      with:
+        script: |
+          core.exportVariable('CDXT_MAX_ERRORS', '2')
+          core.exportVariable('CDXT_WARNING_AFTER_N_ERRORS', '2')
+          core.exportVariable('CDXT_DEFAULT_MIN_RETRY_INTERVAL', '0.01')
+          core.exportVariable('CDXT_CC_INDEX_MIN_RETRY_INTERVAL', '0.01')
+          core.exportVariable('CDXT_CC_DATA_MIN_RETRY_INTERVAL', '0.01')
+          core.exportVariable('CDXT_IA_MIN_RETRY_INTERVAL', '0.01')
+          core.exportVariable('LOGLEVEL', 'DEBUG')
+
+    # - name: Run tests (only feature)
+    #   run: |
+    #     # make test_coverage
+    #     pytest -rA -s --doctest-modules --cov-report=xml --cov-append --cov cdx_toolkit tests/filter_warc tests/filter_cdx -v -v
+    #     coverage report
+
     - name: Run tests
       run: |
-        make test_coverage
+        # make test_coverage
+        pytest -rA -s --doctest-modules --cov-report=xml --cov-append --cov cdx_toolkit tests/ -v -v
+        coverage report
 
     - name: Upload coverage to Codecov
       uses: codecov/codecov-action@v4
diff --git a/README.md b/README.md
index 9d794ff..92e06b6 100644
--- a/README.md
+++ b/README.md
@@ -256,6 +256,112 @@ get the most recent N captures:
 --limit and limit= will return the oldest N captures. With the
 'mixed' ordering, a large enough limit= will get close to returning
 the most recent N captures.
 
+## Filtering CDX files
+
+The `cdxt` command line tool can be used to filter CDX files against a
+given whitelist of URLs or SURTs: the filtering process keeps every CDX
+entry that matches at least one entry in the whitelist and discards all
+other CDX entries.
+
+For matching, all URLs are converted into SURTs. A match occurs when
+the SURT of a CDX entry starts with one of the prefixes defined in the
+SURT whitelist.
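+
+As an illustration, the snippet below is a minimal sketch of this
+matching logic, based on the `url_is_in` helpers that the filter uses
+internally (the whitelist URL, the CDX SURT, and the resulting match
+are made-up examples):
+
+```python
+from url_is_in import SURTMatcher, convert_url_to_surt_with_wildcard
+
+# Whitelist entries given as URLs are first converted into SURT prefixes
+whitelist_urls = ['https://example.com/news/']
+prefixes = [convert_url_to_surt_with_wildcard(url) for url in whitelist_urls]
+
+matcher = SURTMatcher(prefixes, match_subdomains=True)
+
+# The first column of a CDX line is the SURT of the captured URL
+cdx_surt = 'com,example)/news/2024/01/article.html'
+print(matcher.is_in(cdx_surt))  # True when the SURT starts with a whitelisted prefix
+```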
+
+The CDX filter can read and write files on local and remote file
+systems, such as S3 buckets. Multiple input files can be selected
+using a glob pattern.
+
+```
+$ cdxt filter_cdx <input_base_path> <filter_file> <output_base_path> \
+    [--filter-type <url|surt>] \
+    [--input-glob <pattern>] \
+    [--overwrite] \
+    [--parallel <n>]
+```
+
+By default, we use an [fsspec](https://filesystem-spec.readthedocs.io/en/latest/index.html)
+implementation to read from and write to local or remote file systems.
+For better S3 read/write throughput, there is also a dedicated implementation
+using [aioboto3](https://github.com/terricain/aioboto3) that you can enable with
+the `--implementation=aioboto3` argument. With aioboto3, we achieved roughly
+80 requests per second on an AWS EC2 c5n.xlarge instance.
+
+You can add one or more files, optionally with metadata, as resource records to
+the extracted WARC. For instance, this is useful to preserve the CDX filter
+inputs, e.g., the whitelist. To do this, provide the
+corresponding file paths via the arguments `--write-paths-as-resource-records=s3://my-s3-bucket/path/to/my-url-whitelist.txt`
+and `--write-paths-as-resource-records-metadata=s3://my-s3-bucket/path/to/metadata.json`.
+The metadata file is optional and may contain the following optional fields:
+
+```json
+{
+    "warc_content_type": "str",
+    "uri": "str",
+    "http_headers": {"k": "v"},
+    "warc_headers_dict": {"k": "v"}
+}
+```
+
+Here is an example of a metadata JSON file:
+
+```json
+{
+    "uri": "filter_cdx.gz",
+    "warc_content_type": "application/cdx"
+}
+```
+
+The full WARC extraction command could look like this:
+
+```
+$ cdxt -v --cc warc_by_cdx \
+    --cdx-path s3://my-s3-bucket/filtered-cdxs --cdx-glob "*.gz" \
+    --prefix /local/path/filtered-warcs/ \
+    --warc-download-prefix=s3://commoncrawl \
+    --creator foo --operator bob \
+    --write-paths-as-resource-records=s3://my-s3-bucket/path/to/my-url-whitelist.txt \
+    --write-paths-as-resource-records-metadata=s3://my-s3-bucket/path/to/metadata.json
+```
+
 ## TODO
 
 Content downloading needs help with charset issues, preferably
diff --git a/cdx_toolkit/__init__.py b/cdx_toolkit/__init__.py
index ea716c7..417323c 100644
--- a/cdx_toolkit/__init__.py
+++ b/cdx_toolkit/__init__.py
@@ -1,7 +1,6 @@
 import logging
 import json
 from collections.abc import MutableMapping
-import sys
 import warnings
 
 try:
@@ -37,7 +36,7 @@ def showNumPages(r):
     elif isinstance(j, int):  # ia always returns text, parsed as a json int
         pages = j
     else:
-        raise ValueError('surprised by showNumPages value of '+str(j))
+        raise ValueError('surprised by showNumPages value of ' + str(j))
 
     return pages
 
@@ -80,18 +79,19 @@ def cdx_to_captures(resp, wb=None, warc_download_prefix=None):
             lines = json.loads(text)
             fields = lines.pop(0)
         except (json.decoder.JSONDecodeError, KeyError, IndexError):  # pragma: no cover
-            raise ValueError('cannot decode response, first bytes are '+repr(text[:50]))
+            raise ValueError('cannot decode response, first bytes are ' + repr(text[:50]))
         ret = munge_fields(fields, lines)
         return [CaptureObject(r, wb=wb, warc_download_prefix=warc_download_prefix) for r in ret]
 
-    raise ValueError('cannot decode response, first bytes are '+repr(text[:50]))  # pragma: no cover
+    raise ValueError('cannot decode response, first bytes are ' + repr(text[:50]))  # pragma: no cover
 
 
 class CaptureObject(MutableMapping):
-    '''
+    """
     Represents a single capture of a webpage, plus less-visible info
     about how to fetch the content.
- ''' + """ + def __init__(self, data, wb=None, warc_download_prefix=None): self.data = data self.wb = wb @@ -129,9 +129,9 @@ def content(self): @property def text(self): - ''' + """ Eventually this function will do something with the character set, but not yet. - ''' + """ return self.content.decode('utf-8', errors='replace') # the remaining code treats self.data like a dict @@ -176,8 +176,9 @@ def get_more(self): if self.page == 0 and len(self.index_list) > 0 and self.endpoint < len(self.index_list): LOGGER.info('get_more: fetching cdx from %s', self.index_list[self.endpoint]) - status, objs = self.cdxfetcher.get_for_iter(self.endpoint, self.page, - params=self.params, index_list=self.index_list) + status, objs = self.cdxfetcher.get_for_iter( + self.endpoint, self.page, params=self.params, index_list=self.index_list + ) if status == 'last endpoint': LOGGER.debug('get_more: I have reached the end') return # caller will raise StopIteration @@ -207,7 +208,16 @@ def __next__(self): class CDXFetcher: - def __init__(self, source='cc', crawl=None, wb=None, warc_download_prefix=None, cc_mirror=None, cc_sort='mixed', loglevel=None): + def __init__( + self, + source='cc', + crawl=None, + wb=None, + warc_download_prefix=None, + cc_mirror=None, + cc_sort='mixed', + loglevel=None, + ): self.source = source self.crawl = crawl self.cc_sort = cc_sort @@ -236,7 +246,14 @@ def __init__(self, source='cc', crawl=None, wb=None, warc_download_prefix=None, LOGGER.setLevel(level=loglevel) def customize_index_list(self, params): - if self.source == 'cc' and (self.crawl or 'crawl' in params or 'from' in params or 'from_ts' in params or 'to' in params or 'closest' in params): + if self.source == 'cc' and ( + self.crawl + or 'crawl' in params + or 'from' in params + or 'from_ts' in params + or 'to' in params + or 'closest' in params + ): LOGGER.info('making a custom cc index list') if self.crawl and 'crawl' not in params: params['crawl'] = self.crawl @@ -269,7 +286,9 @@ def get(self, url, **kwargs): ret = [] for endpoint in index_list: resp = myrequests_get(endpoint, params=params, cdx=True) - objs = cdx_to_captures(resp, wb=self.wb, warc_download_prefix=self.warc_download_prefix) # turns 400 and 404 into [] + objs = cdx_to_captures( + resp, wb=self.wb, warc_download_prefix=self.warc_download_prefix + ) # turns 400 and 404 into [] ret.extend(objs) if 'limit' in params: params['limit'] -= len(objs) @@ -297,15 +316,14 @@ def iter(self, url, **kwargs): def items(self, url, **kwargs): # pragma: no cover warnings.warn( - 'cdx.items() has been renamed to cdx.iter() and will be removed in cdx_toolkit 1.0', - FutureWarning + 'cdx.items() has been renamed to cdx.iter() and will be removed in cdx_toolkit 1.0', FutureWarning ) return self.iter(url, **kwargs) def get_for_iter(self, endpoint, page, params={}, index_list=None): - ''' + """ Specalized get for the iterator - ''' + """ if endpoint >= len(index_list): return 'last endpoint', [] if params.get('limit', -1) == 0: @@ -325,12 +343,12 @@ def get_for_iter(self, endpoint, page, params={}, index_list=None): return 'ok', ret def get_size_estimate(self, url, as_pages=False, **kwargs): - ''' + """ Get the number of pages that match url useful additional args: matchType='host' pageSize=1 or, url can end with * or start with *. 
to set the matchType - ''' + """ if 'details' in kwargs: details = True del kwargs['details'] diff --git a/cdx_toolkit/cli.py b/cdx_toolkit/cli.py index 6ffa393..dc266ce 100644 --- a/cdx_toolkit/cli.py +++ b/cdx_toolkit/cli.py @@ -6,7 +6,15 @@ import os import cdx_toolkit -from cdx_toolkit.commoncrawl import normalize_crawl + +from cdx_toolkit.utils import get_version, setup + +from cdx_toolkit.filter_cdx.command import run_filter_cdx +from cdx_toolkit.filter_cdx.args import add_filter_cdx_args + +from cdx_toolkit.filter_warc.command import run_warcer_by_cdx +from cdx_toolkit.filter_warc.args import add_warcer_by_cdx_args + LOGGER = logging.getLogger(__name__) @@ -54,6 +62,14 @@ def main(args=None): warc.add_argument('url') warc.set_defaults(func=warcer) + warc_by_cdx = subparsers.add_parser('warc_by_cdx', help='iterate over capture content based on an CDX index file, creating a warc') + add_warcer_by_cdx_args(warc_by_cdx) + warc_by_cdx.set_defaults(func=run_warcer_by_cdx) + + filter_cdx = subparsers.add_parser('filter_cdx', help='Filter CDX files based on SURT prefixes whitelist') + add_filter_cdx_args(filter_cdx) + filter_cdx.set_defaults(func=run_filter_cdx) + size = subparsers.add_parser('size', help='imprecise count of how many results are available') size.add_argument('--details', action='store_true', help='show details of each subindex') size.add_argument('url') @@ -89,48 +105,6 @@ def set_loglevel(cmd): LOGGER.info('set loglevel to %s', str(loglevel)) -def get_version(): - return cdx_toolkit.__version__ - - -def setup(cmd): - kwargs = {} - kwargs['source'] = 'cc' if cmd.crawl else cmd.cc or cmd.ia or cmd.source or None - if kwargs['source'] is None: - raise ValueError('must specify --cc, --ia, or a --source') - if cmd.wb: - kwargs['wb'] = cmd.wb - if cmd.cc_mirror: - kwargs['cc_mirror'] = cmd.cc_mirror - if cmd.crawl: - kwargs['crawl'] = normalize_crawl([cmd.crawl]) # currently a string, not a list - if getattr(cmd, 'warc_download_prefix', None) is not None: - kwargs['warc_download_prefix'] = cmd.warc_download_prefix - - cdx = cdx_toolkit.CDXFetcher(**kwargs) - - kwargs = {} - if cmd.limit: - kwargs['limit'] = cmd.limit - if 'from' in vars(cmd) and vars(cmd)['from']: # python, uh, from is a reserved word - kwargs['from_ts'] = vars(cmd)['from'] - if cmd.to: - kwargs['to'] = cmd.to - if cmd.closest: - if not cmd.get: # pragma: no cover - LOGGER.info('note: --closest works best with --get') - kwargs['closest'] = cmd.closest - if cmd.filter: - kwargs['filter'] = cmd.filter - - if cmd.cmd == 'warc' and cmd.size: - kwargs['size'] = cmd.size - - if cmd.cmd == 'size' and cmd.details: - kwargs['details'] = cmd.details - - return cdx, kwargs - def winnow_fields(cmd, fields, obj): if cmd.all_fields: @@ -213,9 +187,15 @@ def warcer(cmd, cmdline): LOGGER.warning('revisit record being resolved for url %s %s', url, timestamp) writer.write_record(record) + writer.close() + def sizer(cmd, cmdline): cdx, kwargs = setup(cmd) size = cdx.get_size_estimate(cmd.url, **kwargs) print(size) + + +if __name__ == "__main__": + main() diff --git a/cdx_toolkit/commoncrawl.py b/cdx_toolkit/commoncrawl.py index 6834217..7d75690 100644 --- a/cdx_toolkit/commoncrawl.py +++ b/cdx_toolkit/commoncrawl.py @@ -11,6 +11,7 @@ from .myrequests import myrequests_get from .timeutils import time_to_timestamp, timestamp_to_time, pad_timestamp, pad_timestamp_up, cc_index_to_time, cc_index_to_time_special +from .settings import get_mock_time LOGGER = logging.getLogger(__name__) @@ -70,7 +71,7 @@ def get_cc_endpoints(cc_mirror): 
url = cc_mirror.rstrip('/') + '/collinfo.json' r = myrequests_get(url) if r.status_code != 200: - raise RuntimeError('error {} getting list of cc indices from {}'.format(r.status_code, collinfo)) # pragma: no cover + raise RuntimeError('error {} getting list of cc indices from {}'.format(r.status_code, url)) # pragma: no cover set_collinfo_cache(cc_mirror, r.text) col = r.json() @@ -119,9 +120,13 @@ def apply_cc_defaults(params, crawl_present=False, now=None): LOGGER.info('to but no from_ts, setting from_ts=%s', params['from_ts']) else: if not now: - # now is passed in by tests. if not set, use actual now. - # XXX could be changed to mock - now = time.time() + # Check for test/override time first + mock_time = get_mock_time() + if mock_time: + now = mock_time + else: + # now is passed in by tests. if not set, use actual now. + now = time.time() params['from_ts'] = time_to_timestamp(now - year) LOGGER.info('no from or to, setting default 1 year ago from_ts=%s', params['from_ts']) else: diff --git a/cdx_toolkit/filter_cdx/__init__.py b/cdx_toolkit/filter_cdx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cdx_toolkit/filter_cdx/args.py b/cdx_toolkit/filter_cdx/args.py new file mode 100644 index 0000000..03d6f9e --- /dev/null +++ b/cdx_toolkit/filter_cdx/args.py @@ -0,0 +1,40 @@ +import argparse + + +def add_filter_cdx_args(parser: argparse.ArgumentParser): + """Add command line arguments.""" + parser.add_argument( + 'input_base_path', + help='Base directory path on the local file system or remote URL for one or multiple CDX files (e.g., URL to S3 bucket)', + ) + parser.add_argument( + 'filter_file', + help='Path to file containing URL or SURT prefixes to filter for (one per line)', + ) + parser.add_argument( + 'output_base_path', + help='Base directory path for output files (directory structure will be replicated from input_base_path)', + ) + parser.add_argument( + '--filter-type', + type=str, + default='url', + help='Type of filter entries (options: `url` or `surt`, defaults to `url`)', + ) + parser.add_argument( + '--input-glob', + help="Glob pattern relative to input_base_path (e.g., '**/*.cdx.gz' or 'collections/*/indexes/*.gz')", + ) + parser.add_argument( + '--overwrite', + action='store_true', + help='Allow overwriting existing output files', + ) + parser.add_argument( + '--parallel', + type=int, + default=1, + help='Number of parallel workers for processing multiple input files (default: 1, sequential processing)', + ) + + return parser diff --git a/cdx_toolkit/filter_cdx/cdx_filter.py b/cdx_toolkit/filter_cdx/cdx_filter.py new file mode 100644 index 0000000..245b87b --- /dev/null +++ b/cdx_toolkit/filter_cdx/cdx_filter.py @@ -0,0 +1,134 @@ +import logging + +from url_is_in import SURTMatcher +import fsspec + +from multiprocessing import Pool +from typing import List, Tuple + + +logger = logging.getLogger(__name__) + + +def _filter_single_cdx_file( + input_path: str, + output_path: str, + matcher: SURTMatcher, + limit: int = 0, + log_every_n: int = 100_000, +) -> Tuple[str, str, int, int, int]: + """Process a single input/output file pair. 
Returns (lines_n, included_n).""" + lines_n = 0 + included_n = 0 + errors_n = 0 + + logger.info('Reading index from %s', input_path) + logger.info('Writing filter output to %s', output_path) + + try: + + # Input/output from local or remote file system + input_fs, input_fs_path = fsspec.url_to_fs(input_path) + output_fs, output_fs_path = fsspec.url_to_fs(output_path) + + # Make sure output directory exists + output_fs.makedirs(output_fs._parent(output_fs_path), exist_ok=True) + + # Read and write compressed file if needed + compression = 'gzip' if input_fs_path.endswith('.gz') else None + + with output_fs.open(output_fs_path, 'w', compression=compression) as output_f: + with input_fs.open(input_fs_path, 'rt', compression=compression) as input_f: + for i, line in enumerate(input_f, 1): + try: + # Read CDX line + surt_length = line.find(' ') # we do not need to parse the full line + record_surt = line[:surt_length] + lines_n += 1 + + # Use SURT matcher + include_record = matcher.is_in(record_surt) + + if include_record: + output_f.write(line) + included_n += 1 + + if limit > 0 and included_n >= limit: + logger.info('Limit reached at %i from %s', limit, input_path) + break + + if (i % log_every_n) == 0: + logger.info(f'Lines completed: {i:,} (matched: {included_n:,}) from {input_path}') + + except Exception as e: + logger.error(f"Line processing error: {e}") + errors_n += 1 + + # Delete file if empty + if included_n == 0: + logger.warning('Output file is empty, removing it: %s', output_fs_path) + output_fs.rm(output_fs_path) + + except Exception as e: + logger.error(f"File processing error: {e}") + errors_n += 1 + + return input_path, output_path, lines_n, included_n, errors_n + + +def _filter_single_cdx_file_args(kwargs: dict) -> Tuple[str, str, int, int, int]: + """Wrapper function to unpack arguments for multiprocessing.""" + + return _filter_single_cdx_file(**kwargs) + + +def filter_cdx( + matcher: SURTMatcher, + input_paths: List[str], + output_paths: List[str], + n_parallel: int = 1, + limit: int = 0, + total_lines_n: int = 0, + total_included_n: int = 0, + total_errors_n: int = 0, + log_every_n: int = 100_000, +) -> Tuple[int, int, int]: + """Filter CDX files from input paths using a matcher to output paths.""" + + # Parallel processing + logger.info('Filtering with %i processes in parallel (limit: %i)', n_parallel, limit) + + # Prepare arguments for each task (input_path, output_path, matcher, limit) + task_args = [dict( + input_path=input_path, + output_path=output_path, matcher=matcher, limit=limit, log_every_n=log_every_n) + for input_path, output_path in zip(input_paths, output_paths)] + + pool = None + try: + pool = Pool(processes=n_parallel) + # Use imap for better interrupt handling + for input_path, _, lines_n, included_n, errors_n in pool.imap(_filter_single_cdx_file_args, task_args): + total_lines_n += lines_n + total_included_n += included_n + total_errors_n += errors_n + + logger.info(f'File statistics: included {total_included_n} / {total_lines_n} lines: {input_path}') + + except KeyboardInterrupt: + logger.warning('Process interrupted by user (Ctrl+C). 
Terminating running tasks...') + if pool: + pool.terminate() + pool.join() + logger.info('All tasks terminated.') + except Exception as exc: + logger.error(f'Error during parallel processing: {exc}') + total_errors_n += 1 + finally: + if pool: + pool.close() + pool.join() + + logger.warning(f"Filter CDX errors: {total_errors_n}") + + return total_lines_n, total_included_n, total_errors_n \ No newline at end of file diff --git a/cdx_toolkit/filter_cdx/command.py b/cdx_toolkit/filter_cdx/command.py new file mode 100644 index 0000000..0ca10ce --- /dev/null +++ b/cdx_toolkit/filter_cdx/command.py @@ -0,0 +1,83 @@ +from cdx_toolkit.filter_cdx.path_utils import validate_resolved_paths + +import logging +import fsspec +from url_is_in import SURTMatcher, convert_url_to_surt_with_wildcard + + +import sys +import time + +from cdx_toolkit.filter_cdx.cdx_filter import filter_cdx +from cdx_toolkit.filter_cdx.path_utils import resolve_paths + +logger = logging.getLogger(__name__) + + +def run_filter_cdx(args, cmdline: str): + """Filter CDX index files based on a given URL or SURT whitelist. + + - If a URL filter is provided, it is converted to a SURT filter. + - A index entry's SURT must start with one of the SURTs from the whitelist to be considered. + - All other index entries are discarded. + - All input/output paths can be local or remote paths (S3, ...) and compressed (*.gz). + """ + logger.info('Filtering CDX files based on whitelist') + + # Start timing + start_time = time.time() + + # Resolve input and output paths using glob pattern + # This should support glob via S3 (e.g., to fetch the indices from s3://commoncrawl/cc-index/collections/* ...) + input_paths, output_paths = resolve_paths( + input_base_path=args.input_base_path, + input_glob=args.input_glob, + output_base_path=args.output_base_path, + ) + validate_resolved_paths(output_paths, args.overwrite) + + logger.info(f'Found {len(input_paths)} files matching pattern: {args.input_base_path}/{args.input_glob}') + + # Load URL or SURT prefixes from file (each line is a surt) + filter_fs, filter_fs_path = fsspec.url_to_fs(args.filter_file) + logger.info('Loading whitelist from %s', filter_fs_path) + + if not filter_fs.exists(filter_fs_path): # Check that surts file exists + logger.error(f'Filter file not found: {filter_fs_path}') + sys.exit(1) + + with filter_fs.open(filter_fs_path, 'rt') as input_f: + include_prefixes = [line.strip() for line in input_f.readlines()] + + logger.info(f'Loaded {len(include_prefixes):,} filter entries') + + # Convert URLs to SURTs + if args.filter_type == 'url': + include_prefixes = [convert_url_to_surt_with_wildcard(item_url) for item_url in include_prefixes] + + matcher = SURTMatcher(include_prefixes, match_subdomains=True) + + limit = 0 if args.limit is None else args.limit + + # Process files in parallel + total_lines_n, total_included_n, total_errors_n = filter_cdx( + matcher=matcher, + input_paths=input_paths, + output_paths=output_paths, + limit=limit, + n_parallel=max(1, args.parallel), + ) + + # Calculate ratio safely to avoid division by zero + ratio = total_included_n / total_lines_n if total_lines_n > 0 else 0.0 + logger.info(f'Filter statistics: {total_included_n} / {total_lines_n} lines ({ratio:.4f})') + logger.info(f'Errors: {total_errors_n}') + + if limit > 0 and total_included_n >= 0: + logger.info(f'Limit reached at {limit}') + + # End timing and log execution time + end_time = time.time() + execution_time = end_time - start_time + + logger.info(f'Script execution time: {execution_time:.3f} 
seconds') \ No newline at end of file diff --git a/cdx_toolkit/filter_cdx/path_utils.py b/cdx_toolkit/filter_cdx/path_utils.py new file mode 100644 index 0000000..08237b9 --- /dev/null +++ b/cdx_toolkit/filter_cdx/path_utils.py @@ -0,0 +1,68 @@ +import logging +from typing import Optional + +import fsspec + + +import os +import sys + + +logger = logging.getLogger(__name__) + + +def resolve_paths(input_base_path: str, input_glob: Optional[str], output_base_path: str): + """Resolve input paths from glob pattern and generate corresponding output paths.""" + # Use fsspec to handle local and remote file systems + input_fs, input_fs_base_path = fsspec.url_to_fs(input_base_path) + + if input_glob is None: + # No glob pattern + input_fs_file_paths = [input_fs_base_path] + else: + input_full_glob = input_fs_base_path + input_glob + + # Get input files from glob pattern + input_fs_file_paths = sorted(input_fs.glob(input_full_glob)) + if not input_fs_file_paths: + logger.error(f'No files found matching glob pattern: {input_full_glob}') + sys.exit(1) + + # Generate corresponding output paths + output_file_paths = [] + input_file_paths = [] + for input_path in input_fs_file_paths: + # Get relative path from input_base_path without last slash + rel_path = input_path[len(input_fs_base_path) + 1 :] + + # Create corresponding full input and output path + # Use forward slashes for URL paths (S3, HTTP, etc.) to ensure cross-platform compatibility + if '://' in output_base_path: + output_file_paths.append(output_base_path + '/' + rel_path) + else: + # Normalize path separators for local filesystem + normalized_rel_path = rel_path.replace('/', os.sep) + output_file_paths.append(os.path.join(output_base_path, normalized_rel_path)) + + if '://' in input_base_path: + input_file_paths.append(input_base_path + '/' + rel_path) + else: + # Normalize path separators for local filesystem + normalized_rel_path = rel_path.replace('/', os.sep) + input_file_paths.append(os.path.join(input_base_path, normalized_rel_path)) + + return input_file_paths, output_file_paths + + +def validate_resolved_paths(output_paths, overwrite): + """Validate resolved output paths and create directories if needed.""" + # Check if output files exist and overwrite flag + if not overwrite: + output_fs, _ = fsspec.url_to_fs(output_paths[0]) + for output_path in output_paths: + if output_fs.exists(output_path): + logger.error(f'Output file already exists: {output_path}. Use --overwrite to overwrite existing files.') + sys.exit(1) + + # Make sure directory exists + output_fs.makedirs(output_fs._parent(output_path), exist_ok=True) diff --git a/cdx_toolkit/filter_warc/__init__.py b/cdx_toolkit/filter_warc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py new file mode 100644 index 0000000..c1d87f5 --- /dev/null +++ b/cdx_toolkit/filter_warc/args.py @@ -0,0 +1,115 @@ +import logging +import argparse + + +logger = logging.getLogger(__name__) + + +def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): + parser.add_argument( + '--cdx-path', + type=str, + default=None, + help='Path to CDX index file (local or remote, e.g. S3). Required if target source is set to `cdx`.', + ) + parser.add_argument( + '--cdx-glob', + type=str, + default=None, + help='a glob pattern for read from multiple CDX indices', + ) + parser.add_argument( + '--athena-hostnames', + type=str, + nargs="+", + default=None, + help='Hostnames to filter for via Athena (whitelist). 
Required if target source is set to `athena`.', + ) + parser.add_argument( + '--athena-database', + type=str, + default=None, + help='Athena database. Required if target source is set to `athena`.', + ) + parser.add_argument( + '--athena-s3-output', + type=str, + default=None, + help='Athena S3 output location. Required if target source is set to `athena`.', + ) + parser.add_argument('--prefix', default='TEST', help='prefix for the output warc filename') + parser.add_argument( + '--subprefix', + type=str, + default=None, + help='subprefix for the warc filename, default None', + ) + parser.add_argument( + '--size', + type=int, + default=1000000000, + help='target for the warc filesize in bytes', + ) + parser.add_argument( + '--creator', + action='store', + help='creator of the warc: person, organization, service', + ) + parser.add_argument('--operator', action='store', help='a person, if the creator is an organization') + parser.add_argument( + '--description', + action='store', + help='the `description` field in the `warcinfo` record (auto-generated if not set)', + ) + parser.add_argument( + '--is-part-of', + action='store', + help='the `isPartOf` field in the `warcinfo` record (auto-generated if not set)', + ) + parser.add_argument( + '--warc-download-prefix', + action='store', + help='prefix for downloading content, automatically set for CC', + default='https://data.commoncrawl.org', + ) + parser.add_argument( + '--write-paths-as-resource-records', # --write-index-as-record + nargs='*', + help='Paths to multiple files. File content is written to as a resource record to each the WARC file', + ) + parser.add_argument( + '--write-paths-as-resource-records-metadata', + nargs='*', + help='Paths to multiple metadata files (JSON) for resource records from `--write-paths-as-resource-records`', + ) + parser.add_argument( + '--parallel', + type=int, + default=1, + help='Number of parallel workers for reading and writing WARC records (default: 1, sequential processing)', + ) + parser.add_argument( + '--parallel_readers', + type=int, + default=None, + help='Number of parallel workers for reading WARC records (default: same as `parallel`)', + ) + parser.add_argument( + '--parallel_writers', + type=int, + default=None, + help='Number of parallel workers for writing WARC records (default: same as `parallel`)', + ) + parser.add_argument( + '--log_every_n', + type=int, + default=1000, + help='Every N extracted record a log message is emitted (0 = no record logs)', + ) + parser.add_argument( + '--target-source', + action='store', + default='cdx', + help='Source from that the filter targets are loaded (available options: `cdx`, `athena`; defaults to `cdx`)', + ) + return parser diff --git a/cdx_toolkit/filter_warc/athena_job_generator.py b/cdx_toolkit/filter_warc/athena_job_generator.py new file mode 100644 index 0000000..2911125 --- /dev/null +++ b/cdx_toolkit/filter_warc/athena_job_generator.py @@ -0,0 +1,139 @@ +import asyncio +import logging +import time +from typing import Any, Iterable, List + +from cdx_toolkit.filter_warc.data_classes import RangeJob + + +logger = logging.getLogger(__name__) + + +async def get_range_jobs_from_athena( + client, + database: str, + s3_output_location: str, + job_queue: asyncio.Queue, + queue_stop_object: Any, + url_host_names: List[str], + warc_download_prefix: str, + num_fetchers: int, + limit: int = 0, + max_wait_time: int = 300, +) -> int: + """Generate range job based on an Athena query. 
+ + CommonCrawl provides an index via AWS Athena that we can use to + find the file names, offsets, and byte lengths for WARC filtering. + + See https://commoncrawl.org/blog/index-to-warc-files-and-urls-in-columnar-format + + This function queries Athena (AWS access required), puts a RangeJob + (WARC files and offets) for each result and pushed them to the + asyncio queue.""" + + logger.info('Range index limit: %i', limit) + count = 0 + + # Build query + tlds = set([url.split('.')[-1] for url in url_host_names]) # unique TLDs + query_tlds = ' OR '.join([f" url_host_tld = '{tld}'" for tld in tlds]) + query_hostnames = ' OR '.join([f" url_host_name = '{host_name}'" for host_name in url_host_names]) + + # TODO there should be more filteres (dates/crawls/...) + + query_limit = f'LIMIT {limit}' if limit > 0 else '' + + query = f""" + SELECT + warc_filename, warc_record_offset, warc_record_length + FROM ccindex + WHERE subset = 'warc' + AND ({query_tlds}) -- help the query optimizer + AND ({query_hostnames}) + {query_limit}""" + + logger.info('Executing Athena query...') + + # Start query execution + response = client.start_query_execution( + QueryString=query, + QueryExecutionContext={'Database': database}, + ResultConfiguration={'OutputLocation': s3_output_location}, + ) + + query_execution_id = response['QueryExecutionId'] + + logger.info(f'Query execution started. ID: {query_execution_id}') + status = _wait_for_query_completion(client, query_execution_id, max_wait_time) + + if status == 'SUCCEEDED': + for range_job in _get_query_results(client, query_execution_id, warc_download_prefix): + await job_queue.put(range_job) + count += 1 + else: + raise Exception(f'Query failed with status: {status}') + + # Signal fetchers to stop + for _ in range(num_fetchers): + await job_queue.put(queue_stop_object) + + logger.info('Athena query enqueued %d jobs', count) + + return count + + +def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: + """Wait for query to complete and return final status""" + start_time = time.time() + + while time.time() - start_time < max_wait_time: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + + status = response['QueryExecution']['Status']['State'] + logger.info(f'Query status: {status}') + + if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + if status == 'FAILED': + error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') + logger.info(f'Query failed: {error_reason}') + return status + + time.sleep(2) + + raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') + + +def _get_query_results(client, query_execution_id: str, warc_download_prefix: str) -> Iterable[RangeJob]: + """Retrieve query results and convert to pandas DataFrame""" + # Get query results + paginator = client.get_paginator('get_query_results') + page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) + column_names = None + + for page in page_iterator: + rows = page['ResultSet']['Rows'] + + # Get column names from first page + if column_names is None and rows: + column_names = [col['VarCharValue'] for col in rows[0]['Data']] + rows = rows[1:] # Skip header row + + # Process data rows + for row in rows: + row_data = [] + for cell in row['Data']: + value = cell.get('VarCharValue', None) + row_data.append(value) + + row = dict(zip(column_names, row_data)) + + warc_url = warc_download_prefix + "/" + row['warc_filename'] + + yield RangeJob(url=warc_url, 
offset=int(row['warc_record_offset']), length=int(row['warc_record_length'])) + + +def get_databases(client) -> list: + """Get list of available databases""" + response = client.list_databases(CatalogName='AwsDataCatalog') + return [db['Name'] for db in response['DatabaseList']] diff --git a/cdx_toolkit/filter_warc/cdx_utils.py b/cdx_toolkit/filter_warc/cdx_utils.py new file mode 100644 index 0000000..2a0c4ec --- /dev/null +++ b/cdx_toolkit/filter_warc/cdx_utils.py @@ -0,0 +1,93 @@ +import json +from pathlib import Path + +import sys +from typing import Iterable, List, Optional, Tuple, Union + +import fsspec +import logging + + +logger = logging.getLogger(__name__) + + +def get_index_as_string_from_path( + index_path: Union[str, Path], index_fs: Optional[fsspec.AbstractFileSystem] = None +) -> str: + """Fetch (and decompress) index content as string from local or remote path.""" + logger.info('Fetching index from %s ...', index_path) + if index_fs is None: + index_fs, index_fs_path = fsspec.url_to_fs(index_path) + else: + index_fs_path = index_path + + compression = 'gzip' if index_fs_path.endswith('.gz') else None + + with index_fs.open(index_fs_path, 'rt', compression=compression) as f: + return f.read() + + +def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: + cols = line.split(' ', maxsplit=2) + + if len(cols) == 3: + # NOTE: We assume the following format (CC-CDX format): + # + # IA follows a different CDX specification from https://iipc.github.io/warc-specifications/specifications/cdx-format/cdx-2015/ + # > The default first line of a CDX file is: + # > CDX A b e a m s c k r V v D d g M n + data = json.loads(cols[2]) + data['timestamp'] = cols[1] + else: + raise ValueError(f'Cannot parse line: {line}') + + filename = data['filename'] + offset = int(data['offset']) + length = int(data['length']) + + warc_url = warc_download_prefix + '/' + filename + + return (warc_url, offset, length) + + +def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int]]: + """ + Iterate CDX records from a file path (gzipped; local or remote). 
+ """ + logger.info('Reading CDX from %s', index_path) + + with fsspec.open(index_path, 'rt', compression='gzip' if index_path.endswith('.gz') else None) as f: + for line in f: + try: + yield read_cdx_line(line, warc_download_prefix) + except Exception: + # Skip malformed lines + logger.error('Invalid CDX line: %s', line) + continue + + logger.info(f'CDX completed from {index_path}') + + +def get_cdx_paths(index_path: str, index_glob: Optional[str] = None) -> List[str]: + """Find CDX index paths using glob pattern.""" + if index_glob is None: + # Read from a single index + index_paths = [index_path] + else: + # Prepare index paths + index_fs, index_fs_path = fsspec.url_to_fs(index_path) + + # Fetch multiple indicies via glob + full_glob = index_fs_path + index_glob + + logger.info('glob pattern from %s (%s)', full_glob, index_fs.protocol) + + index_paths = sorted(index_fs.glob(full_glob)) + + logger.info('glob pattern found %i index files in %s', len(index_paths), index_fs_path) + + if not index_paths: + logger.error('no index files found via glob') + sys.exit(1) + + return index_paths diff --git a/cdx_toolkit/filter_warc/command.py b/cdx_toolkit/filter_warc/command.py new file mode 100644 index 0000000..614673f --- /dev/null +++ b/cdx_toolkit/filter_warc/command.py @@ -0,0 +1,113 @@ +from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths +from cdx_toolkit.filter_warc.warc_filter import WARCFilter +from cdx_toolkit.utils import get_version + + +import fsspec + + +import time +import logging + +logger = logging.getLogger(__name__) + + +def run_warcer_by_cdx(args, cmdline): + """Like warcer but fetches WARC records based on one or more CDX index files. + + The CDX files can be filtered using the `filter_cdx` commands based a given URL/SURT list. + + Approach: + - Iterate over one or more CDX files to extract capture object (file, offset, length) + - Fetch WARC record based on capture object + - Write to new WARC file with metadata including resource record with index. + - The CDX resource record is written to the WARC directly before for response records that matches to the CDX. 
+ """ + logger.info('Filtering WARC files based on CDX') + + # Start timing + start_time = time.time() + + write_paths_as_resource_records = args.write_paths_as_resource_records + write_paths_as_resource_records_metadata = args.write_paths_as_resource_records_metadata + + if write_paths_as_resource_records and write_paths_as_resource_records_metadata: + if len(write_paths_as_resource_records) != len(write_paths_as_resource_records_metadata): + raise ValueError('Number of paths to resource records must be equal to metadata paths.') + + if not write_paths_as_resource_records and write_paths_as_resource_records_metadata: + raise ValueError('Metadata paths are set but resource records paths are missing.') + + if args.is_part_of: + ispartof = args.is_part_of + else: + ispartof = args.prefix + if args.subprefix: + ispartof += '-' + args.subprefix + + info = { + 'software': 'pypi_cdx_toolkit/' + get_version(), + 'isPartOf': ispartof, + 'description': args.description + if args.description + else 'warc extraction based on CDX generated with: ' + cmdline, + 'format': 'WARC file version 1.0', + } + if args.creator: + info['creator'] = args.creator + if args.operator: + info['operator'] = args.operator + + # writer_kwargs = {} + # if 'size' in kwargs: + # writer_kwargs['size'] = kwargs['size'] + # del kwargs['size'] + + n_parallel = args.parallel + log_every_n = args.log_every_n + limit = 0 if args.limit is None else args.limit + prefix_path = str(args.prefix) + prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) + + # make sure the base dir exists + prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) + + # target source handling + if args.target_source == 'cdx': + cdx_paths = get_cdx_paths( + args.cdx_path, + args.cdx_glob, + ) + elif args.target_source == "athena": + # no extra handling required + cdx_paths = None + else: + raise ValueError(f'Invalid target source specified: {args.target_source} (available: cdx, athena)') + + warc_filter = WARCFilter( + target_source=args.target_source, + cdx_paths=cdx_paths, + athena_database=args.athena_database, + athena_s3_output_location=args.athena_s3_output, + athena_hostnames=args.athena_hostnames, + prefix_path=prefix_path, + writer_info=info, + writer_subprefix=args.subprefix, + write_paths_as_resource_records=write_paths_as_resource_records, + write_paths_as_resource_records_metadata=write_paths_as_resource_records_metadata, + record_limit=limit, + log_every_n=log_every_n, + warc_download_prefix=args.warc_download_prefix, + n_parallel=n_parallel, + max_file_size=args.size, + # writer_kwargs=writer_kwargs, + ) + records_n = warc_filter.filter() + + logger.info('WARC records extracted: %i', records_n) + + # End timing and log execution time + end_time = time.time() + execution_time = end_time - start_time + + logger.info(f'Script execution time: {execution_time:.3f} seconds') diff --git a/cdx_toolkit/filter_warc/data_classes.py b/cdx_toolkit/filter_warc/data_classes.py new file mode 100644 index 0000000..8f36325 --- /dev/null +++ b/cdx_toolkit/filter_warc/data_classes.py @@ -0,0 +1,96 @@ +import time +from dataclasses import dataclass + +from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri, with_retries +from typing import Tuple + +from cdx_toolkit.myrequests import myrequests_get + + +@dataclass +class ThroughputTracker: + """Track throughput metrics for fetchers and consumers.""" + + start_time: float = 0.0 + total_bytes: int = 0 + total_requests: int = 0 + total_records: int = 0 + + def start(self): + 
self.start_time = time.time() + + def add(self, bytes_count: int = 0, records_count: int = 0, requests_count: int = 1): + self.total_bytes += bytes_count + self.total_requests += requests_count + self.total_records += records_count + + def get_stats(self) -> dict: + elapsed = time.time() - self.start_time + + return { + 'elapsed': elapsed, + 'total_bytes': self.total_bytes, + 'total_requests': self.total_requests, + 'total_records': self.total_records, + 'bytes_per_sec': self.total_bytes / elapsed if elapsed > 0 else 0, + 'mb_per_sec': (self.total_bytes / elapsed) / (1024 * 1024) if elapsed > 0 else 0, + 'requests_per_sec': self.total_requests / elapsed if elapsed > 0 else 0, + 'records_per_sec': self.total_records / elapsed if elapsed > 0 else 0, + } + + +@dataclass(frozen=True) +class RangeJob: + """Defines a S3 or HTTP range read request.""" + + url: str + offset: int + length: int + records_count: int = 1 + + def is_s3(self): + return is_s3_url(self.url) + + def get_s3_bucket_and_key(self) -> Tuple[str, str]: + if self.is_s3(): + return parse_s3_uri(self.url) + else: + raise ValueError('Cannot get bucket and key from a HTTP job') + + async def ranged_get_bytes( + self, + max_attempts: int, + base_backoff_seconds: float, + s3_client=None, + ) -> bytes: + """Ranged get request to S3 with retries and backoff or HTTP.""" + offset = self.offset + length = self.length + + end = offset + length - 1 # inclusive + + if self.is_s3(): + # read from S3 + bucket, key = self.get_s3_bucket_and_key() + resp = await with_retries( + lambda: s3_client.get_object(Bucket=bucket, Key=key, Range=f'bytes={offset}-{end}'), + op_name=f'ranged_get {bucket}/{key}[{offset}:{end}]', + max_attempts=max_attempts, + base_backoff_seconds=base_backoff_seconds, + ) + return await resp['Body'].read() + + else: + # read from HTTP + headers = {'Range': 'bytes={}-{}'.format(offset, end)} + + resp = myrequests_get(self.url, headers=headers) + return resp.content + + +@dataclass(frozen=True) +class RangePayload: + """Bytes output from S3 or HTTP range read.""" + + job: RangeJob + data: bytes diff --git a/cdx_toolkit/filter_warc/local_writer.py b/cdx_toolkit/filter_warc/local_writer.py new file mode 100644 index 0000000..e85b052 --- /dev/null +++ b/cdx_toolkit/filter_warc/local_writer.py @@ -0,0 +1,42 @@ +import aiofiles + + +class LocalFileWriter: + """Async writer for local file system using aiofiles.""" + + def __init__( + self, + file_path: str, + buffer_size: int = 8192, + mode: str = 'wb' + ): + self.file_path = file_path + self.buffer_size = buffer_size + self.mode = mode + self.file_handle = None + self.buffer = bytearray() + + async def start(self): + self.file_handle = await aiofiles.open(self.file_path, self.mode) + + async def write(self, data: bytes): + self.buffer.extend(data) + if len(self.buffer) >= self.buffer_size: + await self._flush() + + async def _flush(self): + if self.buffer and self.file_handle: + await self.file_handle.write(bytes(self.buffer)) + await self.file_handle.flush() + self.buffer.clear() + + async def close(self): + try: + if self.buffer: + await self._flush() + if self.file_handle: + await self.file_handle.close() + except Exception: + if self.file_handle: + await self.file_handle.close() + raise diff --git a/cdx_toolkit/filter_warc/s3_utils.py b/cdx_toolkit/filter_warc/s3_utils.py new file mode 100644 index 0000000..1a0cf7c --- /dev/null +++ b/cdx_toolkit/filter_warc/s3_utils.py @@ -0,0 +1,55 @@ +import asyncio +import logging +from typing import Tuple +from os import urandom + +from 
botocore.exceptions import ClientError, EndpointConnectionError + + +logger = logging.getLogger(__name__) + + +def parse_s3_uri(uri: str) -> Tuple[str, str]: + """Parse a S3 URI and return bucket and prefix.""" + if not uri.startswith('s3://'): + raise ValueError(f'Not an S3 URI: {uri}') + rest = uri[5:] + i = rest.find('/') + if i <= 0 or i == len(rest) - 1: + raise ValueError(f'Malformed S3 URI: {uri}') + return rest[:i], rest[i+1:] + + +def is_s3_url(url: str) -> bool: + return url.startswith('s3:/') + + +async def with_retries(coro_factory, *, op_name: str, max_attempts: int, base_backoff_seconds: float): + """Perform operation with retries and backoff.""" + last_exc = None + for attempt in range(1, max_attempts + 1): + try: + return await coro_factory() + except (TimeoutError, ClientError, EndpointConnectionError) as exc: + last_exc = exc + if attempt >= max_attempts: + logger.error('%s failed after %d attempts: %r', op_name, attempt, exc) + break + sleep_s = _backoff(attempt, base_backoff_seconds) + logger.warning( + '%s failed (attempt %d/%d) - retrying in %.2fs', + op_name, + attempt, + max_attempts, + sleep_s, + ) + await asyncio.sleep(sleep_s) + raise last_exc + + +def _backoff(attempt: int, base_backoff_seconds: float) -> float: + """Time to sleep based on number of attempts""" + base = base_backoff_seconds * (2 ** (attempt - 1)) + + # Add random jitter between 80-120% of base delay + return max(0.05, base * (0.8 + 0.4 * urandom(1)[0] / 255)) diff --git a/cdx_toolkit/filter_warc/s3_writer.py b/cdx_toolkit/filter_warc/s3_writer.py new file mode 100644 index 0000000..b67e646 --- /dev/null +++ b/cdx_toolkit/filter_warc/s3_writer.py @@ -0,0 +1,174 @@ +import logging +from typing import List, Dict, Optional + +from cdx_toolkit.filter_warc.s3_utils import ( + with_retries, +) + +logger = logging.getLogger(__name__) + + +async def mpu_create( + s3, + bucket: str, + key: str, + *, + max_attempts: int, + base_backoff_seconds: float, +): + """Create multi part upload to S3.""" + kwargs = {'Bucket': bucket, 'Key': key} + resp = await with_retries( + lambda: s3.create_multipart_upload(**kwargs), + op_name=f'create_multipart_upload {bucket}/{key}', + max_attempts=max_attempts, + base_backoff_seconds=base_backoff_seconds, + ) + return resp['UploadId'] + + +async def mpu_upload_part( + s3, + bucket: str, + key: str, + upload_id: str, + part_number: int, + body: bytes, + max_attempts: int, + base_backoff_seconds: float, +) -> str: + """Upload a part of a multi-part upload to S3.""" + resp = await with_retries( + lambda: s3.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=part_number, + Body=body, + ), + op_name=f'upload_part {bucket}/{key}#{part_number}', + max_attempts=max_attempts, + base_backoff_seconds=base_backoff_seconds, + ) + return resp['ETag'] + + +async def mpu_complete( + s3, + bucket: str, + key: str, + upload_id: str, + parts: List[Dict], + max_attempts: int, + base_backoff_seconds: float, +): + """Send complete for multi-part upload.""" + await with_retries( + lambda: s3.complete_multipart_upload( + Bucket=bucket, Key=key, UploadId=upload_id, MultipartUpload={'Parts': parts} + ), + op_name=f'complete_multipart_upload {bucket}/{key}', + max_attempts=max_attempts, + base_backoff_seconds=base_backoff_seconds, + ) + + +async def mpu_abort(s3, bucket: str, key: str, upload_id: str): + """Abort mult-part upload.""" + try: + await s3.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + except Exception: + logger.exception('Failed to abort MPU 
%s on %s/%s', upload_id, bucket, key) + + +class S3ShardWriter: + """Manages one MPU: buffers bytes, uploads >=5 MiB parts, completes on close.""" + + def __init__( + self, + s3_client, + shard_key: str, + dest_bucket: str, + content_type: Optional[str], + min_part_size: int, + max_attempts: int, + base_backoff_seconds: float, + ): + self.s3_client = s3_client + self.shard_key = shard_key + self.dest_bucket = dest_bucket + self.content_type = content_type + self.min_part_size = min_part_size + self.max_attempts = max_attempts + self.base_backoff_seconds = base_backoff_seconds + self.upload_id: Optional[str] = None + self.part_number = 1 + self.parts: List[Dict] = [] + self.buffer = bytearray() + + async def start(self): + self.upload_id = await mpu_create( + self.s3_client, + self.dest_bucket, + self.shard_key, + max_attempts=self.max_attempts, + base_backoff_seconds=self.base_backoff_seconds, + ) + logger.info('Started MPU for %s (UploadId=%s)', self.shard_key, self.upload_id) + + async def _flush_full_parts(self): + while len(self.buffer) >= self.min_part_size: + chunk = self.buffer[: self.min_part_size] + del self.buffer[: self.min_part_size] + etag = await mpu_upload_part( + self.s3_client, + self.dest_bucket, + self.shard_key, + self.upload_id, + self.part_number, + bytes(chunk), + self.max_attempts, + self.base_backoff_seconds, + ) + self.parts.append({'PartNumber': self.part_number, 'ETag': etag}) + self.part_number += 1 + + async def write(self, data: bytes): + self.buffer.extend(data) + await self._flush_full_parts() + + async def close(self): + try: + if self.buffer: + etag = await mpu_upload_part( + self.s3_client, + self.dest_bucket, + self.shard_key, + self.upload_id, + self.part_number, + bytes(self.buffer), + self.max_attempts, + self.base_backoff_seconds, + ) + self.parts.append({'PartNumber': self.part_number, 'ETag': etag}) + self.part_number += 1 + self.buffer.clear() + + if self.parts: + await mpu_complete( + self.s3_client, + self.dest_bucket, + self.shard_key, + self.upload_id, + self.parts, + self.max_attempts, + self.base_backoff_seconds, + ) + logger.info('Completed MPU for %s with %d parts.', self.shard_key, len(self.parts)) + except Exception: + logger.exception('Completing MPU failed for %s; attempting abort.', self.shard_key) + if self.upload_id: + await mpu_abort(self.s3_client, self.dest_bucket, self.shard_key, self.upload_id) + raise + + diff --git a/cdx_toolkit/filter_warc/warc_filter.py b/cdx_toolkit/filter_warc/warc_filter.py new file mode 100644 index 0000000..b5781e6 --- /dev/null +++ b/cdx_toolkit/filter_warc/warc_filter.py @@ -0,0 +1,718 @@ +import asyncio +import logging +import statistics +import sys +from typing import List, Literal, Optional, Dict + + +from botocore.config import Config + +from cdx_toolkit.filter_warc.athena_job_generator import get_range_jobs_from_athena +from cdx_toolkit.filter_warc.s3_utils import ( + is_s3_url, +) +from cdx_toolkit.filter_warc.data_classes import RangeJob, RangePayload, ThroughputTracker +from cdx_toolkit.filter_warc.warc_utils import create_new_writer_with_header +from cdx_toolkit.filter_warc.cdx_utils import ( + iter_cdx_index_from_path, +) +from cdx_toolkit.filter_warc.warc_utils import get_bytes_from_warc_record, get_resource_record_from_path + + +_STOP = object() + +logger = logging.getLogger(__name__) + +TargetSourceType = Literal['cdx', 'athena'] + + +class WARCFilter: + """Filter or extract specific records from WARC files based on CDX indexes. 
+ + The WARC filter uses a three stage listner-producer-consumer pattern. + + Filter targets: + - CDX index files from local or remote file system, containing paths to WARC files and positions of target records. + + WARC reader: + - HTTP range reads + - S3 range reads + + WARC writer: + - Local file system + - S3 using multi-part uploads + """ + + def __init__( + self, + prefix_path: str, + writer_info: Dict, + target_source: TargetSourceType = 'cdx', + cdx_paths: Optional[List[str]] = None, + athena_database: Optional[str] = None, + athena_hostnames: Optional[List[str]] = None, + athena_s3_output_location: Optional[str] = None, + writer_subprefix: Optional[str] = None, + write_paths_as_resource_records: Optional[List[str]] = None, + write_paths_as_resource_records_metadata: Optional[List[str]] = None, + record_limit: int = 0, + log_every_n: int = 1000, + warc_download_prefix: Optional[str] = None, + n_parallel: int = 1, + n_parallel_readers: Optional[int] = None, + n_parallel_writers: Optional[int] = None, + max_attempts: int = 5, + base_backoff_seconds: float = 0.5, + # writer_kwargs: Optional[Dict] = None, + range_jobs_queue_size: int = 1000, + warc_records_queue_size: int = 200, + fetcher_to_consumer_ratio: int = 6, + aws_region_name: str = 'us-east-1', + warc_version: str = '1.0', + content_type: Optional[str] = None, + min_part_size: int = 5 * 1024 * 1024, # 5 MiB (for upload) + max_file_size: Optional[int] = 1 * 1024 * 1024 * 1024, # 1 GiB (for WARC outputs) + ): + """Initialize the WARC filter. + + Args: + target_source: Source of filter targets (Athena query or CDX files). + cdx_paths: List of paths to CDX index files. + athena_database: Database for Athena query. + athena_hostnames: Hostnames for Athena query. + athena_s3_output_location: S3 output location for Athena query. + prefix_path: Output path prefix for filtered WARC files. + writer_info: Dictionary containing writer metadata. + writer_subprefix: Optional subprefix for writer output paths. + write_paths_as_resource_records: Optional list of file paths to write as resource records. + write_paths_as_resource_records_metadata: Optional list of metadata paths for resource records. + record_limit: Maximum number of records to process (0 for unlimited). + log_every_n: Log progress every N records. + warc_download_prefix: Optional prefix to prepend to WARC URLs. + n_parallel: Number of parallel workers (default for readers/writers). + n_parallel_readers: Number of parallel reader tasks (overrides n_parallel). + n_parallel_writers: Number of parallel writer tasks (overrides n_parallel). + max_attempts: Maximum retry attempts for failed operations. + base_backoff_seconds: Base backoff time in seconds for retries. + writer_kwargs: Optional additional kwargs for writers. + range_jobs_queue_size: Maximum size of range jobs queue. + warc_records_queue_size: Maximum size of WARC records queue. + fetcher_to_consumer_ratio: Ratio of readers to writers for auto-scaling. + aws_region_name: AWS region name for S3 operations. + warc_version: WARC format version (e.g., '1.0' or '1.1'). + content_type: Optional content type for WARC output. + min_part_size: Minimum part byte size for multipart uploads (default: 5 MiB). + max_file_size: Maximum byte size for individual WARC files (default: 1 GiB). 
+ """ + self.cdx_paths = cdx_paths + self.target_source: TargetSourceType = target_source + self.athena_database = athena_database + self.athena_s3_output_location = athena_s3_output_location + self.athena_hostnames = athena_hostnames + self.prefix_path = prefix_path + self.writer_info = writer_info + self.writer_subprefix = writer_subprefix + self.write_paths_as_resource_records = write_paths_as_resource_records + self.write_paths_as_resource_records_metadata = write_paths_as_resource_records_metadata + self.record_limit = record_limit + self.log_every_n = log_every_n + self.warc_download_prefix = warc_download_prefix + + # self.writer_kwargs = writer_kwargs + self.range_jobs_queue_size = range_jobs_queue_size + self.warc_records_queue_size = warc_records_queue_size + self.aws_region_name = aws_region_name + self.fetcher_to_consumer_ratio = fetcher_to_consumer_ratio + self.max_attempts = max_attempts + self.base_backoff_seconds = base_backoff_seconds + + self.n_parallel = n_parallel + self.num_readers = n_parallel_readers if n_parallel_readers is not None else n_parallel + self.num_writers = ( + n_parallel_writers + if n_parallel_writers is not None + else max(int(self.num_readers / self.fetcher_to_consumer_ratio), 1) + ) + + # self.gzip = self.cdx_paths[0].endswith('.gz') if self.cdx_paths else False + self.gzip = True + + self.warc_version = warc_version + self.content_type = content_type + self.min_part_size = min_part_size + self.max_file_size = max_file_size + + def filter(self) -> int: + """Perform the filtering process (calls async method via asyncio.run). + + Returns: + int: Number of records written, or -1 if interrupted. + """ + try: + return asyncio.run(self.filter_async()) + except KeyboardInterrupt: + logger.warning('Interrupted by user.') + + return -1 + + def needs_aws(self) -> bool: + """Returns true if AWS (S3/Athena) is needed at any stage. + + Returns: + bool: True if AWS client is needed for any operation. + """ + return ( + self.target_source == 'athena' # stage 1 + or (self.cdx_paths is not None and len(self.cdx_paths) > 0 and is_s3_url(self.cdx_paths[0])) # stage 1 + or is_s3_url(self.warc_download_prefix) # stage 3 + or is_s3_url(self.prefix_path) # stage 3 + ) + + def get_boto3_base_config(self) -> Dict: + """Get boto3 base configuration for AWS client. + + Returns: + Dict: Boto3 base configuration object with retry and timeout settings. + """ + # Calculate max connections based on parallelism + # Each reader + writer needs connections, plus some overhead for retries + # max_pool_connections = max(50, (self.num_readers + self.num_writers) * 2) + + return dict( + region_name=self.aws_region_name, + retries={ + 'max_attempts': max(2, self.max_attempts), + 'mode': 'adaptive', # Better than 'standard' for variable workloads + }, + ) + + async def get_aws_clients(self) -> Optional[Dict]: + """Return S3/Athena clients for job/read/write if needed. + + Returns: + Optional[aioboto3.Session.client]: S3/Athena client context manager if S3/Athena is needed, None otherwise. + + Raises: + SystemExit: If S3 is needed but Python version is < 3.9. 
+ """ + if self.needs_aws(): + if sys.version_info.major < 3 or (sys.version_info.major >= 3 and sys.version_info.minor < 9): + logger.error('Reading and writing to S3 requires Python version >= 3.9') + sys.exit(1) + + import aioboto3 + import boto3 + + session = aioboto3.Session() + + # Lightweight config for CDX index reads + job_config = Config( + max_pool_connections=5, + read_timeout=60, + **self.get_boto3_base_config(), + ) + + if self.target_source == 'athena': + # Athena does not need an async client + job_client = boto3.client('athena', config=job_config) + else: + job_client = session.client('s3', config=job_config) + + # High-throughput config for range reads + read_config = Config( + max_pool_connections=self.num_readers * 3, + read_timeout=300, + tcp_keepalive=True, + **self.get_boto3_base_config(), + ) + + # Optimized config for multipart uploads + write_config = Config( + max_pool_connections=self.num_writers * 4, + read_timeout=120, + connect_timeout=10, + **self.get_boto3_base_config(), + ) + + return { + 'job': job_client, + 'read': session.client('s3', config=read_config), + 'write': session.client('s3', config=write_config), + } + else: + return None + + async def filter_async(self) -> int: + """Filter process using a three stage approach (job generator, warc reader, warc writer). + + Returns: + int: Number of records written. + """ + range_jobs_queue: asyncio.Queue = asyncio.Queue(maxsize=self.range_jobs_queue_size) + warc_records_queue: asyncio.Queue = asyncio.Queue(maxsize=self.warc_records_queue_size) + + if self.needs_aws(): + clients = await self.get_aws_clients() + + # Handle mixed async/sync clients - Athena client is sync, S3 clients are async + if self.target_source == 'athena': + job_aws_client = clients['job'] # Sync client, no context manager needed + async with clients['read'] as read_aws_client, clients['write'] as write_aws_client: + return await self._run_filter_pipeline( + range_jobs_queue=range_jobs_queue, + warc_records_queue=warc_records_queue, + job_aws_client=job_aws_client, + read_s3_client=read_aws_client, + write_s3_client=write_aws_client, + ) + else: + async with clients['job'] as job_aws_client, clients['read'] as read_aws_client, clients[ + 'write' + ] as write_aws_client: + return await self._run_filter_pipeline( + range_jobs_queue=range_jobs_queue, + warc_records_queue=warc_records_queue, + job_aws_client=job_aws_client, + read_s3_client=read_aws_client, + write_s3_client=write_aws_client, + ) + else: + return await self._run_filter_pipeline( + range_jobs_queue=range_jobs_queue, + warc_records_queue=warc_records_queue, + ) + + async def _run_filter_pipeline( + self, + range_jobs_queue: asyncio.Queue, + warc_records_queue: asyncio.Queue, + job_aws_client=None, + read_s3_client=None, + write_s3_client=None, + ) -> int: + """Run the actual filter pipeline with or without S3 client. + + Args: + range_jobs_queue: Queue for range jobs from CDX index. + warc_records_queue: Queue for WARC record payloads. + job_aws_client: Optional AWS (S3/Athena) client for jobs generation. + read_s3_client: Optional S3 client for reads from S3. + write_s3_client: Optional S3 client for writes S3. + + Returns: + int: Number of records written. 
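+
+        The stages communicate through two bounded asyncio queues, roughly:
+
+            CDX / Athena job generator
+              -> range_jobs_queue   (RangeJob: WARC url, offset, length)
+              -> reader tasks performing ranged GETs
+              -> warc_records_queue (RangePayload: job + raw record bytes)
+              -> writer tasks appending to rotating WARC output files
+
+        ``_STOP`` sentinels shut the stages down in order.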
+ """ + # Fetch file paths and ranges (offset, length) from index files + logger.info('Starting job generator, %d WARC readers, %d WARC writers', self.num_readers, self.num_writers) + + # Generate range jobs from different target sources + if self.target_source == 'cdx': + job_generators = asyncio.create_task( + self.generate_range_jobs_from_cdx( + range_jobs_queue, + s3_client=job_aws_client, + ) + ) + elif self.target_source == 'athena': + job_generators = asyncio.create_task( + get_range_jobs_from_athena( + client=job_aws_client, + database=self.athena_database, + s3_output_location=self.athena_s3_output_location, + job_queue=range_jobs_queue, + queue_stop_object=_STOP, + url_host_names=self.athena_hostnames, + warc_download_prefix=self.warc_download_prefix, + num_fetchers=self.num_readers, + limit=self.record_limit, + ) + ) + else: + raise ValueError(f'Invalid target source: {self.target_source}') + + # Read WARC records based on file paths and ranges + warc_readers = [ + asyncio.create_task( + self.read_warc_records( + reader_id=i, + range_jobs_queue=range_jobs_queue, + warc_records_queue=warc_records_queue, + s3_client=read_s3_client, + ) + ) + for i in range(self.num_readers) + ] + + # Write WARC records + warc_writers = [ + asyncio.create_task( + self.write_warc_records( + writer_id=i, + warc_records_queue=warc_records_queue, + s3_client=write_s3_client, + ) + ) + for i in range(self.num_writers) + ] + + # Start writer coordination task + writer_coordinator = asyncio.create_task(self._coordinate_writer_shutdown(warc_readers, warc_records_queue)) + + await job_generators + logger.info('Range jobs submitted, monitoring readers and writers') + + # Wait for all tasks to complete + readers_results = await asyncio.gather(*warc_readers) + writers_results = await asyncio.gather(*warc_writers) + await writer_coordinator + + readers_records = sum([result['stats']['total_records'] for result in readers_results]) + readers_mb_per_sec = self.num_readers * statistics.mean( + [result['stats']['mb_per_sec'] for result in readers_results] + ) + readers_records_per_sec = self.num_readers * statistics.mean( + [result['stats']['records_per_sec'] for result in readers_results] + ) + + logger.info(f'All WARC readers completed: {readers_records} records') + logger.info(f'Total reader throughput: {readers_mb_per_sec:.2f} MB/s; {readers_records_per_sec:.2f} rec/s') + + writers_records = sum([result['stats']['total_records'] for result in writers_results]) + writers_mb_per_sec = self.num_writers * statistics.mean( + [result['stats']['mb_per_sec'] for result in writers_results] + ) + writers_records_per_sec = self.num_writers * statistics.mean( + [result['stats']['records_per_sec'] for result in writers_results] + ) + + logger.info(f'All WARC writers completed: {writers_records} records') + logger.info(f'Total writer throughput: {writers_mb_per_sec:.2f} MB/s; {writers_records_per_sec:.2f} rec/s') + + return writers_records + + async def _coordinate_writer_shutdown(self, warc_readers: List[asyncio.Task], warc_records_queue: asyncio.Queue): + """Coordinate efficient shutdown of writers as readers complete. + + This prevents writers from waiting unnecessarily when all readers are done + and the records queue is being drained. 
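+
+        Once every reader has finished, the coordinator puts one ``_STOP``
+        sentinel per writer on the records queue; simplified, it behaves like:
+
+            await asyncio.gather(*warc_readers)
+            for _ in range(self.num_writers):
+                await warc_records_queue.put(_STOP)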
+ """ + completed_readers = 0 + + # Monitor reader completion + while completed_readers < len(warc_readers): + # Wait for any reader to complete + done, pending = await asyncio.wait( + warc_readers, + return_when=asyncio.FIRST_COMPLETED, + timeout=1.0, # Check periodically + ) + + if done: + completed_readers = len(warc_readers) - len(pending) + logger.debug(f'Readers completed: {completed_readers}/{len(warc_readers)}') + + # All readers completed - signal writers to stop + logger.info('All readers completed, signaling writers to stop') + + # Send stop signals to all writers + for _ in range(self.num_writers): + await warc_records_queue.put(_STOP) + + async def generate_range_jobs_from_single_cdx( + self, + cdx_path: str, + range_jobs_queue: asyncio.Queue, + count: int = 0, + ) -> int: + """Read a CDX file and generate range jobs based on URLs and offsets.""" + for warc_url, offset, length in iter_cdx_index_from_path( + cdx_path, warc_download_prefix=self.warc_download_prefix + ): + # Convert the CDX record back to a RangeJob + job = RangeJob(url=warc_url, offset=offset, length=length, records_count=1) + await range_jobs_queue.put(job) + count += 1 + + if self.record_limit > 0 and count >= self.record_limit: + logger.warning('Index limit reached at %i', count) + break + + return count + + async def generate_range_jobs_from_cdx( + self, + range_jobs_queue: asyncio.Queue, + s3_client=None, + ): + """Read the CDX paths, parse lines -> RangeJob (WARC files and offets) -> key_queue. + + Args: + range_jobs_queue: Queue to put RangeJob objects into. + s3_client: Optional S3 client for reading CDX indexes from S3. + """ + + logger.info('Range index limit: %i', self.record_limit) + count = 0 + + # Iterate over index files + # TODO this could be done in parallel + for index_path in self.cdx_paths: + # Fetch range queries from index + try: + count += await self.generate_range_jobs_from_single_cdx( + cdx_path=index_path, + range_jobs_queue=range_jobs_queue, + count=count, + ) + + except Exception as e: + logger.error('Failed to read CDX index from %s: %s', index_path, e) + + if self.record_limit > 0 and count >= self.record_limit: + logger.warning('Limit reached at %i', count) + break + + # signal fetchers to stop + for _ in range(self.num_readers): + await range_jobs_queue.put(_STOP) + + logger.info('Enqueued %d jobs from %s', count, index_path) + + async def read_warc_records( + self, + reader_id: int, + range_jobs_queue: asyncio.Queue, + warc_records_queue: asyncio.Queue, + s3_client=None, + ) -> dict: + """Read WARC records based on range jobs -> enqueue RangePayload. + + Args: + reader_id: Unique identifier for this reader task. + range_jobs_queue: Queue to read RangeJob objects from. + warc_records_queue: Queue to put RangePayload objects into. + s3_client: Optional S3 client for reading WARC files from S3. + + Returns: + dict: Statistics dictionary with reader_id and throughput stats. + """ + tracker = ThroughputTracker() + tracker.start() + counter = 0 + + while True: + job = await range_jobs_queue.get() + try: + if job is _STOP: + stats = tracker.get_stats() + logger.info( + 'WARC Reader %d stopping. 
Stats: %.1fs, %d requests, %.1f MB, %.2f MB/s, %.2f req/s', + reader_id, + stats['elapsed'], + stats['total_requests'], + stats['total_bytes'] / (1024 * 1024), + stats['mb_per_sec'], + stats['requests_per_sec'], + ) + break # Exit loop, but still execute finally block + assert isinstance(job, RangeJob) + data = await job.ranged_get_bytes( + self.max_attempts, + self.base_backoff_seconds, + s3_client=s3_client, + ) + tracker.add(bytes_count=len(data), records_count=job.records_count) + counter += 1 + + # Log progress every N items + self.log_reader(reader_id=reader_id, counter=counter, tracker=tracker) + + await warc_records_queue.put(RangePayload(job=job, data=data)) + except Exception: + logger.exception( + 'WARC Reader %d failed on %s/%s [%d,%d]', + reader_id, + getattr(job, 'bucket', '?'), + getattr(job, 'key', '?'), + getattr(job, 'offset', -1), + getattr(job, 'length', -1), + ) + finally: + range_jobs_queue.task_done() + + return {'reader_id': reader_id, 'stats': tracker.get_stats()} + + async def write_resource_records(self, writer, warcinfo_id: str) -> int: + """Write WARC resource records based on paths""" + resource_records_size = 0 + + logger.info(f'Writing {len(self.write_paths_as_resource_records)} resource records to WARC ... ') + + # Resource records are written at the beginning the WARC file. + for i, resource_record_path in enumerate(self.write_paths_as_resource_records): + logger.info(f'Writing resource record from {resource_record_path} ...') + resource_record = get_resource_record_from_path( + file_path=resource_record_path, + metadata_path=( + self.write_paths_as_resource_records_metadata[i] + if self.write_paths_as_resource_records_metadata + else None + ), + warcinfo_id=warcinfo_id, + ) + record_data = get_bytes_from_warc_record(resource_record, warc_version=self.warc_version, gzip=self.gzip) + await writer.write(record_data) + + # Keep track but do not rotate resource records + resource_records_size += len(record_data) + + logger.info(f'Resource records added: {len(self.write_paths_as_resource_records)}') + + return resource_records_size + + async def write_warc_records( + self, + writer_id: int, + warc_records_queue: asyncio.Queue, + s3_client=None, + ) -> dict: + """Write WARC records. Each writer owns ONE shard MPU and appends ranges to it. + + Args: + writer_id: Unique identifier for this writer task. + warc_records_queue: Queue to read RangePayload objects from. + s3_client: Optional S3 client for writing WARC files to S3. + + Returns: + dict: Statistics dictionary with writer_id and throughput stats. 
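+
+        Each writer owns its own sequence of output files and rotates to a new
+        file once ``max_file_size`` would be exceeded. For example, with a
+        ``prefix_path`` ending in ``TEST`` and no ``writer_subprefix``, writer 0
+        produces files named:
+
+            TEST-000000-001.extracted.warc.gz
+            TEST-000000-002.extracted.warc.gz   # after the first rotation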
+ """ + # File rotation tracking + current_file_sequence = 1 + current_file_size = 0 + + new_writer_kwargs = dict( + s3_client=s3_client, + writer_id=writer_id, + output_path_prefix=self.prefix_path, + max_attempts=self.max_attempts, + base_backoff_seconds=self.base_backoff_seconds, + writer_info=self.writer_info, + warc_version=self.warc_version, + writer_subprefix=self.writer_subprefix, + gzip=self.gzip, + content_type=self.content_type, + min_part_size=self.min_part_size, + ) + + # Initialize first writer with header + writer, header_size, warcinfo_id = await create_new_writer_with_header( + sequence=current_file_sequence, + **new_writer_kwargs, + ) + current_file_size = header_size + + tracker = ThroughputTracker() + tracker.start() + counter = 0 + + # Resource records + if self.write_paths_as_resource_records: + current_file_size += await self.write_resource_records(writer, warcinfo_id=warcinfo_id) + + # Response records + try: + while True: + item = await warc_records_queue.get() + counter += 1 + try: + if item is _STOP: + stats = tracker.get_stats() + logger.info( + 'WARC writer %d stopping. Stats: %.1fs, %d items, %.1f MB written, %.2f MB/s write speed', + writer_id, + stats['elapsed'], + stats['total_requests'], + stats['total_bytes'] / (1024 * 1024), + stats['mb_per_sec'], + ) + should_stop = True + else: + should_stop = False + assert isinstance(item, RangePayload) + + # Check if we need to rotate files due to size limit + writer, current_file_sequence, current_file_size = await self.rotate_files( + writer=writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=len(item.data), + **new_writer_kwargs, + ) + + # Write actual response record + await writer.write(item.data) + current_file_size += len(item.data) + tracker.add(bytes_count=len(item.data), records_count=item.job.records_count) + + # Log progress every N items + self.log_writer(writer_id=writer_id, counter=counter, tracker=tracker) + + except Exception: + logger.exception('WARC writer %d failed on %s', writer_id, getattr(item, 'job', None)) + should_stop = False + finally: + warc_records_queue.task_done() + + if should_stop: + break + finally: + await writer.close() + + return {'writer_id': writer_id, 'stats': tracker.get_stats()} + + def log_reader(self, reader_id: int, counter: int, tracker: ThroughputTracker): + """Log progress every N items.""" + if self.log_every_n > 0 and counter % self.log_every_n == 0: + stats = tracker.get_stats() + logger.info( + 'WARC Reader %d: %d items, %.1f MB, %.2f MB/s, %.2f req/s', + reader_id, + counter, + stats['total_bytes'] / (1024 * 1024), + stats['mb_per_sec'], + stats['requests_per_sec'], + ) + + def log_writer(self, writer_id: int, counter: int, tracker: ThroughputTracker): + """Log progress every N items.""" + if self.log_every_n > 0 and counter % self.log_every_n == 0: + stats = tracker.get_stats() + logger.info( + 'WARC Writer %d: %d items, %.1f MB written, %.2f MB/s', + writer_id, + counter, + stats['total_bytes'] / (1024 * 1024), + stats['mb_per_sec'], + ) + + async def rotate_files( + self, writer, current_file_sequence: int, current_file_size: int, added_byte_size: int, **new_writer_kwargs + ): + """Check if we need to rotate files due to size limit and perform rotation if needed.""" + if self.max_file_size and current_file_size + added_byte_size > self.max_file_size: + await writer.close() + current_file_sequence += 1 + + writer, header_size, warcinfo_id = await create_new_writer_with_header( + 
sequence=current_file_sequence, + **new_writer_kwargs, + ) + + current_file_size = header_size + logger.info(f'Rotated to new WARC file sequence {current_file_sequence} due to size limit') + + # Resource records also to new files + if self.write_paths_as_resource_records: + current_file_size += await self.write_resource_records(writer, warcinfo_id=warcinfo_id) + + return writer, current_file_sequence, current_file_size diff --git a/cdx_toolkit/filter_warc/warc_utils.py b/cdx_toolkit/filter_warc/warc_utils.py new file mode 100644 index 0000000..e7d4b80 --- /dev/null +++ b/cdx_toolkit/filter_warc/warc_utils.py @@ -0,0 +1,179 @@ +from io import BytesIO +import json +import logging +from pathlib import Path +import fsspec +from warcio.recordloader import ArcWarcRecord +from warcio import WARCWriter + +from typing import Dict, Optional, Tuple, Union + +import mimetypes + +from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri +from cdx_toolkit.filter_warc.local_writer import LocalFileWriter +from cdx_toolkit.filter_warc.s3_writer import S3ShardWriter + +logger = logging.getLogger(__name__) + + +def get_bytes_from_warc_record( + record, + warc_version: str = '1.0', + gzip: bool = False, +): + """Get byte representation of WARC record.""" + buffer = BytesIO() + warc_writer = WARCWriter(buffer, gzip=gzip, warc_version=warc_version) + warc_writer.write_record(record) + + return buffer.getvalue() + + +def get_resource_record_from_path( + file_path: Union[str, Path], + warcinfo_id: str, + metadata_path: Optional[Union[str, Path]] = None, +) -> ArcWarcRecord: + """Build WARC resource record for file path and metdata path. + + The metadata file must be a valid JSON and can have the following fields: + - warc_content_type: str + - uri: str + - http_headers: dict + - warc_headers_dict: str + + If uri is not provided as metadata, the file_path is used. + If warc_content_type is not provided as metadata, the type is guessed. 
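+
+    Example (the paths below are placeholders):
+
+        record = get_resource_record_from_path(
+            file_path='s3://my-bucket/path/to/url-whitelist.txt',
+            warcinfo_id=warcinfo_id,  # WARC-Record-ID of the associated warcinfo record
+            metadata_path='s3://my-bucket/path/to/metadata.json',  # optional JSON with the fields above
+        )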
+ """ + # Cast to string + file_path = str(file_path) + + with fsspec.open(file_path, 'rb') as f: + file_bytes = BytesIO(f.read()) + + if metadata_path: + # Load metadata from path + metadata_path = str(metadata_path) + + if not metadata_path.endswith('.json'): + raise ValueError('Metadata must be provided JSON (file path ends with *.json)') + + with fsspec.open(metadata_path) as f: + metadata = json.load(f) + + warc_content_type = metadata.get('warc_content_type', None) + uri = metadata.get('uri', None) + http_headers = metadata.get('http_headers', None) + warc_headers_dict = metadata.get('warc_headers_dict', {}) + else: + # Without metdata + warc_content_type = None + uri = None + http_headers = None + warc_headers_dict = {} + + if warc_content_type is None: + warc_content_type = mimetypes.guess_type(file_path)[0] + + if uri is None: + uri = file_path + + # Set WARC-Warcinfo-ID + warc_headers_dict['WARC-Warcinfo-ID'] = warcinfo_id + + return WARCWriter(None).create_warc_record( + uri=uri, + record_type='resource', + payload=file_bytes, + http_headers=http_headers, + warc_content_type=warc_content_type, + warc_headers_dict=warc_headers_dict, + ) + + +def generate_warc_filename( + dest_prefix: str, + writer_id: int, + sequence: int, + writer_subprefix: Optional[str] = None, + gzip: bool = False, +) -> str: + """Generate a WARC file name.""" + file_name = dest_prefix + '-' + if writer_subprefix is not None: + file_name += writer_subprefix + '-' + file_name += '{:06d}-{:03d}'.format(writer_id, sequence) + '.extracted.warc' + if gzip: + file_name += '.gz' + + return file_name + + +async def create_new_writer_with_header( + writer_id: int, + sequence: int, + output_path_prefix: str, + max_attempts: int, + base_backoff_seconds: float, + min_part_size: int, + writer_info: Dict, + warc_version: str = '1.0', + writer_subprefix: Optional[str] = None, + gzip: bool = False, + content_type: Optional[str] = None, + s3_client=None, +) -> Tuple[Union[S3ShardWriter, LocalFileWriter], int, str]: + """Create a new WARC writer (local or S3) including file header.""" + if is_s3_url(output_path_prefix): + dest_bucket, dest_prefix = parse_s3_uri(output_path_prefix) + + filename = generate_warc_filename( + dest_prefix=dest_prefix, + writer_id=writer_id, + sequence=sequence, + writer_subprefix=writer_subprefix, + gzip=gzip, + ) + + new_writer = S3ShardWriter( + s3_client, + filename, + dest_bucket, + content_type, + min_part_size, + max_attempts, + base_backoff_seconds, + ) + + else: + # local file system + filename = generate_warc_filename( + dest_prefix=output_path_prefix, + writer_id=writer_id, + sequence=sequence, + writer_subprefix=writer_subprefix, + gzip=gzip, + ) + + new_writer = LocalFileWriter( + file_path=filename, + ) + + logger.debug('Initialzing new WARC writer for {filename}') + + # Initialize writer + await new_writer.start() + + # Write WARC header + buffer = BytesIO() + warc_writer = WARCWriter(buffer, gzip=gzip, warc_version=warc_version) + warcinfo = warc_writer.create_warcinfo_record(filename, writer_info) + warc_writer.write_record(warcinfo) + header_data = buffer.getvalue() + await new_writer.write(header_data) + + # WARC-Warcinfo-ID indicates the WARC-Record-ID of the associated ‘warcinfo’ record + warcinfo_id = warcinfo.rec_headers.get('WARC-Record-ID') + + return new_writer, len(header_data), warcinfo_id diff --git a/cdx_toolkit/myrequests.py b/cdx_toolkit/myrequests.py index f34b856..6e8d684 100644 --- a/cdx_toolkit/myrequests.py +++ b/cdx_toolkit/myrequests.py @@ -1,9 +1,18 @@ +from 
typing import Optional import requests import logging import time from urllib.parse import urlparse from . import __version__ +from .settings import ( + DEFAULT_MIN_RETRY_INTERVAL, + CC_DATA_MIN_RETRY_INTERVAL, + CC_INDEX_MIN_RETRY_INTERVAL, + IA_MIN_RETRY_INTERVAL, + MAX_ERRORS, + WARNING_AFTER_N_ERRORS, +) LOGGER = logging.getLogger(__name__) @@ -15,7 +24,7 @@ def dns_fatal(hostname): - '''We have a dns error, should we fail immediately or not?''' + """We have a dns error, should we fail immediately or not?""" if hostname not in previously_seen_hostnames: return True @@ -23,19 +32,19 @@ def dns_fatal(hostname): retry_info = { 'default': { 'next_fetch': 0, - 'minimum_interval': 3.0, + 'minimum_interval': DEFAULT_MIN_RETRY_INTERVAL, }, 'index.commoncrawl.org': { 'next_fetch': 0, - 'minimum_interval': 1.0, + 'minimum_interval': CC_INDEX_MIN_RETRY_INTERVAL, }, 'data.commoncrawl.org': { 'next_fetch': 0, - 'minimum_interval': 0.55, + 'minimum_interval': CC_DATA_MIN_RETRY_INTERVAL, }, 'web.archive.org': { 'next_fetch': 0, - 'minimum_interval': 6.0, + 'minimum_interval': IA_MIN_RETRY_INTERVAL, }, } @@ -43,7 +52,7 @@ def dns_fatal(hostname): def get_retries(hostname): if hostname not in retry_info: retry_info[hostname] = retry_info['default'].copy() - LOGGER.debug('initializing retry info for new host '+hostname) + LOGGER.debug('initializing retry info for new host ' + hostname) entry = retry_info[hostname] if not entry['next_fetch']: entry['next_fetch'] = time.time() @@ -55,17 +64,23 @@ def update_next_fetch(hostname, next_fetch): def myrequests_get( - url, - params=None, - headers=None, - cdx=False, - allow404=False, - raise_error_after_n_errors: int = 100, - raise_warning_after_n_errors: int = 10, + url, + params=None, + headers=None, + cdx=False, + allow404=False, + raise_error_after_n_errors: Optional[int] = None, + raise_warning_after_n_errors: Optional[int] = None, retry_max_sec: int = 60, - ): +): t = time.time() + if raise_error_after_n_errors is None: + raise_error_after_n_errors = MAX_ERRORS + + if raise_warning_after_n_errors is None: + raise_warning_after_n_errors = WARNING_AFTER_N_ERRORS + hostname = urlparse(url).hostname next_fetch, minimum_interval = get_retries(hostname) @@ -89,7 +104,7 @@ def myrequests_get( if headers is None: headers = {} if 'user-agent' not in headers: - headers['User-Agent'] = 'pypi_cdx_toolkit/'+__version__ + headers['User-Agent'] = 'pypi_cdx_toolkit/' + __version__ retry = True retry_sec = 2 * minimum_interval @@ -98,8 +113,7 @@ def myrequests_get( while retry: try: LOGGER.debug('getting %s %r', url, params) - resp = requests.get(url, params=params, headers=headers, - timeout=(30., 30.), allow_redirects=False) + resp = requests.get(url, params=params, headers=headers, timeout=(30.0, 30.0), allow_redirects=False) if cdx and resp.status_code in {400, 404}: # 400: ia html error page -- probably page= is too big -- not an error # 404: pywb {'error': 'No Captures found for: www.pbxxxxxxm.com/*'} -- not an error @@ -120,7 +134,7 @@ def myrequests_get( if resp.text: LOGGER.log(level, 'response body is %s', resp.text) time.sleep(retry_sec) - retry_sec = min(retry_sec*2, retry_max_sec) + retry_sec = min(retry_sec * 2, retry_max_sec) continue if resp.status_code in {400, 404}: # pragma: no cover if resp.text: @@ -128,26 +142,31 @@ def myrequests_get( raise RuntimeError('invalid url of some sort, status={} {}'.format(resp.status_code, url)) resp.raise_for_status() retry = False - except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError, 
- requests.exceptions.Timeout) as e: + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ChunkedEncodingError, + requests.exceptions.Timeout, + ) as e: connect_errors += 1 string = '{} failures for url {} {!r}: {}'.format(connect_errors, url, params, str(e)) # Check for DNS errors with different operating systems - if (('Name or service not known' in string) # linux + if ( + ('Name or service not known' in string) # linux or ('nodename nor servname provided, or not known' in string) # macos - or ('getaddrinfo failed' in string)): # windows + or ('getaddrinfo failed' in string) + ): # windows if dns_fatal(url): - raise ValueError('invalid hostname in url '+url) from None + raise ValueError('invalid hostname in url ' + url) from None if connect_errors > raise_error_after_n_errors: LOGGER.error(string) raise ValueError(string) if connect_errors > raise_warning_after_n_errors: LOGGER.warning(string) - LOGGER.info('retrying after {:.2f}s for '.format(retry_max_sec)+str(e)) + LOGGER.info('retrying after {:.2f}s for '.format(retry_max_sec) + str(e)) time.sleep(retry_max_sec) # notice the extra-long sleep - retry_sec = min(retry_sec*2, retry_max_sec) + retry_sec = min(retry_sec * 2, retry_max_sec) except requests.exceptions.RequestException as e: # pragma: no cover LOGGER.warning('something unexpected happened, giving up after %s', str(e)) raise diff --git a/cdx_toolkit/settings.py b/cdx_toolkit/settings.py new file mode 100644 index 0000000..f223f65 --- /dev/null +++ b/cdx_toolkit/settings.py @@ -0,0 +1,15 @@ +import os + +MAX_ERRORS = int(os.environ.get('CDXT_MAX_ERRORS', 100)) +WARNING_AFTER_N_ERRORS = int(os.environ.get('CDXT_WARNING_AFTER_N_ERRORS', 10)) + +DEFAULT_MIN_RETRY_INTERVAL = float(os.environ.get('CDXT_DEFAULT_MIN_RETRY_INTERVAL', 3.0)) +CC_INDEX_MIN_RETRY_INTERVAL = float(os.environ.get('CDXT_CC_INDEX_MIN_RETRY_INTERVAL', 1.0)) +CC_DATA_MIN_RETRY_INTERVAL = float(os.environ.get('CDXT_CC_DATA_MIN_RETRY_INTERVAL', 0.55)) +IA_MIN_RETRY_INTERVAL = float(os.environ.get('CDXT_IA_MIN_RETRY_INTERVAL', 6.0)) + + +def get_mock_time(): + """Get the mock time from environment variable, evaluated dynamically""" + mock_time = os.environ.get('CDXT_MOCK_TIME') + return float(mock_time) if mock_time else None diff --git a/cdx_toolkit/utils.py b/cdx_toolkit/utils.py new file mode 100644 index 0000000..55f0d20 --- /dev/null +++ b/cdx_toolkit/utils.py @@ -0,0 +1,49 @@ +import cdx_toolkit +from cdx_toolkit.commoncrawl import normalize_crawl + +import logging + +LOGGER = logging.getLogger(__name__) + + +def get_version(): + return cdx_toolkit.__version__ + + +def setup(cmd): + kwargs = {} + kwargs['source'] = 'cc' if cmd.crawl else cmd.cc or cmd.ia or cmd.source or None + if kwargs['source'] is None: + raise ValueError('must specify --cc, --ia, or a --source') + if cmd.wb: + kwargs['wb'] = cmd.wb + if cmd.cc_mirror: + kwargs['cc_mirror'] = cmd.cc_mirror + if cmd.crawl: + kwargs['crawl'] = normalize_crawl([cmd.crawl]) # currently a string, not a list + if getattr(cmd, 'warc_download_prefix', None) is not None: + kwargs['warc_download_prefix'] = cmd.warc_download_prefix + + cdx = cdx_toolkit.CDXFetcher(**kwargs) + + kwargs = {} + if cmd.limit: + kwargs['limit'] = cmd.limit + if 'from' in vars(cmd) and vars(cmd)['from']: # python, uh, from is a reserved word + kwargs['from_ts'] = vars(cmd)['from'] + if cmd.to: + kwargs['to'] = cmd.to + if cmd.closest: + if not cmd.get: # pragma: no cover + LOGGER.info('note: --closest works best with --get') + kwargs['closest'] = cmd.closest + if 
cmd.filter: + kwargs['filter'] = cmd.filter + + if cmd.cmd == 'warc' and cmd.size: + kwargs['size'] = cmd.size + + if cmd.cmd == 'size' and cmd.details: + kwargs['details'] = cmd.details + + return cdx, kwargs diff --git a/cdx_toolkit/warc.py b/cdx_toolkit/warc.py index ac8435e..391e755 100644 --- a/cdx_toolkit/warc.py +++ b/cdx_toolkit/warc.py @@ -5,6 +5,7 @@ import logging import sys +import fsspec from warcio import WARCWriter from warcio.recordloader import ArcWarcRecordLoader from warcio.bufferedreaders import DecompressingBufferedReader @@ -131,10 +132,19 @@ def fetch_warc_record(capture, warc_download_prefix): length = int(capture['length']) warc_url = warc_download_prefix + '/' + filename - headers = {'Range': 'bytes={}-{}'.format(offset, offset+length-1)} - resp = myrequests_get(warc_url, headers=headers) - record_bytes = resp.content + if warc_url.startswith("s3:"): + # fetch from S3 + with fsspec.open(warc_url, 'rb') as f: + f.seek(offset) + record_bytes = f.read(length) + else: + # fetch over HTTP + headers = {'Range': 'bytes={}-{}'.format(offset, offset+length-1)} + + resp = myrequests_get(warc_url, headers=headers) + record_bytes = resp.content + stream = DecompressingBufferedReader(BytesIO(record_bytes)) record = ArcWarcRecordLoader().parse_record_stream(stream) @@ -152,6 +162,9 @@ def fetch_warc_record(capture, warc_download_prefix): class CDXToolkitWARCWriter: + """Writer for WARC files. + + The fsspec package is used for writting to local or remote file system, e.g., S3.""" def __init__(self, prefix, subprefix, info, size=1000000000, gzip=True, warc_version=None): self.prefix = prefix self.subprefix = subprefix @@ -161,6 +174,9 @@ def __init__(self, prefix, subprefix, info, size=1000000000, gzip=True, warc_ver self.warc_version = warc_version self.segment = 0 self.writer = None + self.file_handler = None + self.file_system, self.file_system_prefix = fsspec.url_to_fs(self.prefix) + self._file_context = None def write_record(self, *args, **kwargs): if self.writer is None: @@ -175,21 +191,21 @@ def write_record(self, *args, **kwargs): self.writer.write_record(*args, **kwargs) - fsize = os.fstat(self.fd.fileno()).st_size - if fsize > self.size: - self.fd.close() + # Compare file size of current segment with max. 
file size + if self.file_handler and self.file_handler.tell() > self.size: + self._close_current_file() self.writer = None self.segment += 1 def _unique_warc_filename(self): while True: - name = self.prefix + '-' + name = self.file_system_prefix + '-' if self.subprefix is not None: name += self.subprefix + '-' name += '{:06d}'.format(self.segment) + '.extracted.warc' if self.gzip: name += '.gz' - if os.path.exists(name): + if self.file_system.exists(name): self.segment += 1 else: break @@ -197,12 +213,24 @@ def _unique_warc_filename(self): def _start_new_warc(self): self.filename = self._unique_warc_filename() - self.fd = open(self.filename, 'wb') + self._file_context = self.file_system.open(self.filename, 'wb') + self.file_handler = self._file_context.__enter__() LOGGER.info('opening new warc file %s', self.filename) - self.writer = WARCWriter(self.fd, gzip=self.gzip, warc_version=self.warc_version) + self.writer = WARCWriter(self.file_handler, gzip=self.gzip, warc_version=self.warc_version) warcinfo = self.writer.create_warcinfo_record(self.filename, self.info) self.writer.write_record(warcinfo) + def _close_current_file(self): + # Close the handler of the current file (needed for fsspec abstraction) + if self._file_context is not None: + self._file_context.__exit__(None, None, None) + self._file_context = None + self.file_handler = None + + def close(self): + # Close the WARC writer (this must be called at the end) + self._close_current_file() + def get_writer(prefix, subprefix, info, **kwargs): return CDXToolkitWARCWriter(prefix, subprefix, info, **kwargs) diff --git a/examples/iter-and-warc.py b/examples/iter-and-warc.py index 73ea3dd..b346d3f 100755 --- a/examples/iter-and-warc.py +++ b/examples/iter-and-warc.py @@ -32,3 +32,5 @@ writer.write_record(record) print(' wrote', url) + +writer.close() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7d3bb76..5599f31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,22 @@ # Install with "python -m pip install -r requirements.txt". 
# must be kept in sync with setup.py -requests==2.25.1 +requests>=2.25.1 warcio==1.7.4 +fsspec[s3] +surt>=0.3.1 +tqdm>=4.67.1 +url-is-in>=0.1.1 # used by Makefile -pytest==6.2.4 -pytest-cov==2.12.1 -pytest-sugar==0.9.4 -coveralls==3.1.0 -responses==0.25.8 +pytest>=6.2.4 +pytest-cov>=2.12.1 +pytest-sugar>=0.9.4 +coveralls>=3.1.0 +botocore>=1.39.11 +responses>=0.25.8 # packaging -twine==3.4.1 -setuptools==57.0.0 -setuptools-scm==6.0.1 +twine>=3.4.1 +setuptools>=57.0.0 +setuptools-scm>=6.0.1 diff --git a/scripts/cdx_iter b/scripts/cdx_iter index 8b0c5a3..99445c0 100644 --- a/scripts/cdx_iter +++ b/scripts/cdx_iter @@ -143,6 +143,8 @@ elif args.warc: if obj.is_revisit(): LOGGER.warning('revisit record being resolved for url %s %s', url, timestamp) writer.write_record(record) + + writer.close() else: for obj in cdx.iter(args.url, **kwargs): printme = winnow_fields(obj) diff --git a/setup.py b/setup.py index 9f793f8..73f76f2 100755 --- a/setup.py +++ b/setup.py @@ -2,15 +2,13 @@ from os import path -from setuptools import setup +from setuptools import setup, find_packages -packages = [ - 'cdx_toolkit', -] +packages = find_packages(include=['cdx_toolkit*']) # remember: keep requires synchronized with requirements.txt -requires = ['requests', 'warcio'] +requires = ['requests', 'warcio', 'fsspec[s3]', 'aioboto3', 'surt', 'tqdm', 'url-is-in>=0.1.1'] test_requirements = ['pytest', 'pytest-cov', 'responses'] @@ -37,7 +35,7 @@ author_email='lindahl@pbm.com', url='https://github.com/cocrawler/cdx_toolkit', packages=packages, - python_requires=">=3.7", + python_requires=">=3.8", extras_require=extras_require, setup_requires=['setuptools-scm'], install_requires=requires, @@ -60,12 +58,13 @@ 'Programming Language :: Python', #'Programming Language :: Python :: 3.5', # setuptools-scm problem #'Programming Language :: Python :: 3.6', # not offered in github actions - 'Programming Language :: Python :: 3.7', + # 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: Python :: 3 :: Only', ], ) diff --git a/tests/conftest.py b/tests/conftest.py index 2ab1a78..4185719 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,131 @@ -import json import os from pathlib import Path +import pytest +import boto3 +from botocore.config import Config +from botocore.exceptions import NoCredentialsError, ClientError, EndpointConnectionError + +import json import functools from typing import Dict, Optional import requests import responses import base64 +import uuid +import shutil from unittest.mock import patch +TEST_DATA_PATH = Path(__file__).parent / 'data' +TEST_S3_BUCKET = os.environ.get('CDXT_TEST_S3_BUCKET', 'commoncrawl-ci-temp') +DISABLE_S3_TESTS = bool(os.environ.get('CDXT_DISABLE_S3_TESTS', False)) TEST_DATA_BASE_PATH = Path(__file__).parent / 'data' +@pytest.fixture(scope='session', autouse=True) +def cleanup_cache(): + """Delete cache directory before each test to ensure clean state""" + cache_dir = os.path.expanduser('~/.cache/cdx_toolkit/') + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + + +@pytest.fixture(scope='session', autouse=True) +def set_mock_time(): + """Set CDXT_MOCK_TIME environment variable for consistent test results""" + # August 15, 2025 - ensures tests use CC-MAIN-2025-33 which exists in mock data + if 
'CDXT_MOCK_TIME' not in os.environ: + os.environ['CDXT_MOCK_TIME'] = '1755259200' + + +# Cache for AWS S3/Athena access check to avoid repeated network calls +_aws_s3_access_cache = None +_aws_athena_access_cache = None + + +def check_aws_s3_access(): + """Check if AWS S3 access is available (cached result).""" + global _aws_s3_access_cache + + if _aws_s3_access_cache is not None: + return _aws_s3_access_cache + + try: + config = Config(retries={'max_attempts': 1, 'mode': 'standard'}) + s3_client = boto3.client('s3', config=config) + + # Try list objects on test bucket + s3_client.list_objects_v2(Bucket=TEST_S3_BUCKET, MaxKeys=1) + _aws_s3_access_cache = True + except (NoCredentialsError, ClientError, ConnectionError, EndpointConnectionError): + _aws_s3_access_cache = False + + return _aws_s3_access_cache + + +def requires_aws_s3(func): + """Pytest decorator that skips test if AWS S3 access is not available or disabled.""" + return pytest.mark.skipif(DISABLE_S3_TESTS, reason='AWS S3 access is disabled via environment variable.')( + pytest.mark.skipif( + not check_aws_s3_access(), reason='AWS S3 access not available (no credentials or permissions)' + )(func) + ) + + +def check_aws_athena_access(): + """Check if AWS Athena access is available.""" + global _aws_athena_access_cache + + if _aws_athena_access_cache is not None: + return _aws_athena_access_cache + + try: + client = boto3.client('athena') + + # Try list databasets + client.list_databases(CatalogName='AwsDataCatalog') + _aws_athena_access_cache = True + except (NoCredentialsError, ClientError): + _aws_athena_access_cache = False + + return _aws_athena_access_cache + + +def requires_aws_athena(func): + """Pytest decorator that skips test if AWS Athena access is not available.""" + return pytest.mark.skipif( + not check_aws_athena_access(), reason='AWS Athena access not available (no credentials or permissions)' + )(func) + + +@pytest.fixture +def s3_tmpdir(): + """S3 equivalent of tmpdir - provides a temporary S3 path and handles cleanup.""" + bucket_name = TEST_S3_BUCKET + + # Generate unique prefix using UUID to avoid collisions + temp_prefix = f'cdx_toolkit/ci/tmpdirs/{uuid.uuid4().hex}' + + # Yield the S3 path + yield f's3://{bucket_name}/{temp_prefix}' + + try: + # Cleanup: delete all objects with this prefix + s3_client = boto3.client('s3') + + # List all objects with the temp prefix + response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=temp_prefix) + + if 'Contents' in response: + # Delete all objects + objects_to_delete = [{'Key': obj['Key']} for obj in response['Contents']] + s3_client.delete_objects(Bucket=bucket_name, Delete={'Objects': objects_to_delete}) + except (NoCredentialsError, ClientError, ConnectionError, EndpointConnectionError): + # Ignore cleanup errors - test objects will eventually expire + pass + + def flexible_param_matcher(expected_params): """Custom matcher that ignores dynamic 'from' parameter timestamps and casts all values to strings""" @@ -63,8 +176,8 @@ def mock_response_from_jsonl(mock_data_name, mock_data_dir: Optional[str] = None export DISABLE_MOCK_RESPONSES=1 ``` - If the remote APIs change, new mock data can be semi-automatically collected - by setting another environment variable, running corresponding unit tests, + If the remote APIs change, new mock data can be semi-automatically collected + by setting another environment variable, running corresponding unit tests, and overwriting existing mock data in `tests/data/mock_responses`: ```bash @@ -114,28 +227,39 @@ def 
mock_response_from_jsonl(mock_data_name, mock_data_dir: Optional[str] = None ) -def conditional_mock_responses(func): +def conditional_mock_responses(func=None, *, auto_mock_data: bool = True): """Conditionally applies @responses.activate and auto-loads mock data based on DISABLE_MOCK_RESPONSES env var. - The mock data is automatically loaded from JSONL file from the tests/data directory + The mock data is automatically loaded from JSONL file from the tests/data directory and dependinng on the test module and test function. + + Args: + auto_mock_data: If True, auto-loads test-specific mock data. If False, only loads CC endpoints. """ - # If the flag DISABLE_MOCK_RESPONSES is not detected, response mocking remains enabled - if not os.environ.get('DISABLE_MOCK_RESPONSES'): - # Add responses.activate - func = add_mock_responses(func) + def decorator(f): + # If the flag DISABLE_MOCK_RESPONSES is not detected, response mocking remains enabled + if not os.environ.get('DISABLE_MOCK_RESPONSES'): + # Add responses.activate + f = add_mock_responses(f, auto_mock_data=auto_mock_data) + + if os.environ.get('SAVE_MOCK_RESPONSES'): + # Mock data is saved by capturing output from requests.get + @functools.wraps(f) + def wrapper(*args, **kwargs): + with patch('requests.get', side_effect=_custom_behavior_with_original(requests.get)): + return f(*args, **kwargs) - if os.environ.get('SAVE_MOCK_RESPONSES'): - # Mock data is saved by capturing output from requests.get - @functools.wraps(func) - def wrapper(*args, **kwargs): - with patch('requests.get', side_effect=_custom_behavior_with_original(requests.get)): - return func(*args, **kwargs) + return wrapper - return wrapper + return f - return func + if func is None: + # Called with arguments: @conditional_mock_responses(auto_mock_data=False) + return decorator + else: + # Called without arguments: @conditional_mock_responses + return decorator(func) def save_response_as_mock_data(test_info: str, request_url: str, request_params: Dict, resp, output_base_dir: str): @@ -202,14 +326,15 @@ def custom_behavior(*args, **kwargs): return custom_behavior -def add_mock_responses(func): +def add_mock_responses(func, auto_mock_data: bool = True): @functools.wraps(func) def wrapper(*args, **kwargs): # Load mock data for index calls (same for many test functions) mock_response_from_jsonl('test_get_cc_endpoints', 'test_cc') # Auto-load mock data based on function name - mock_response_from_jsonl(func.__name__, func.__module__.split('.')[-1]) + if auto_mock_data: + mock_response_from_jsonl(func.__name__, func.__module__.split('.')[-1]) return func(*args, **kwargs) return responses.activate(wrapper) diff --git a/tests/data/filter_cdx/whitelist_10_surts.txt b/tests/data/filter_cdx/whitelist_10_surts.txt new file mode 100644 index 0000000..0754ed2 --- /dev/null +++ b/tests/data/filter_cdx/whitelist_10_surts.txt @@ -0,0 +1,10 @@ +com,example)/ +edu,si)/ +com,youtube)/ +gov,archives)/ +gov,census)/ +com,741,onlinedegrees)/online_university_degree_program.html +com,72pines,star)/2007/06/25/%e6%8f%90%e5%8f%96%e5%85%ac%e7%a7%af%e9%87%91/trackback +fr,missiondefrance,bibliotheque)/ +fr,mnhn,biodiv)/fr/taxonomy +fr,mobilierpourchr,wip)/produit/t-837 diff --git a/tests/data/filter_cdx/whitelist_10_urls.txt b/tests/data/filter_cdx/whitelist_10_urls.txt new file mode 100644 index 0000000..475f73d --- /dev/null +++ b/tests/data/filter_cdx/whitelist_10_urls.txt @@ -0,0 +1,10 @@ +example.com +si.edu +youtube.com +archive.gov +census.gov 
+onlinedegrees.741.com/online_university_degree_program.html +star.72pines.com/2007/06/25/%e6%8f%90%e5%8f%96%e5%85%ac%e7%a7%af%e9%87%91/trackback +bibliotheque.missiondefrance.fr +biodiv.mnhn.fr/fr/taxonomy +wip.mobilierpourchr.fr/produit/t-837 diff --git a/tests/data/filter_cdx/whitelist_11_surts.txt b/tests/data/filter_cdx/whitelist_11_surts.txt new file mode 100644 index 0000000..a2ee272 --- /dev/null +++ b/tests/data/filter_cdx/whitelist_11_surts.txt @@ -0,0 +1,11 @@ +com,example)/ +edu,si)/ +com,youtube)/ +gov,archives)/ +gov,census)/ +com,741,onlinedegrees)/online_university_degree_program.html +com,72pines,star)/2007/06/25/%e6%8f%90%e5%8f%96%e5%85%ac%e7%a7%af%e9%87%91/trackback +fr,missiondefrance,bibliotheque)/ +fr,mnhn,biodiv)/fr/taxonomy +fr,mobilierpourchr,wip)/produit/t-837 +fr,tie-up)/ \ No newline at end of file diff --git a/tests/data/filter_cdx/whitelist_wildcard_urls.txt b/tests/data/filter_cdx/whitelist_wildcard_urls.txt new file mode 100644 index 0000000..0371ef3 --- /dev/null +++ b/tests/data/filter_cdx/whitelist_wildcard_urls.txt @@ -0,0 +1,2 @@ +*.com +*.fr \ No newline at end of file diff --git a/tests/data/warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz b/tests/data/warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz new file mode 100644 index 0000000..47941ed Binary files /dev/null and b/tests/data/warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz differ diff --git a/tests/data/warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.metadata.json b/tests/data/warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.metadata.json new file mode 100644 index 0000000..6a4103f --- /dev/null +++ b/tests/data/warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.metadata.json @@ -0,0 +1,4 @@ +{ + "uri": "filter_cdx.cdx.gz", + "warc_content_type": "application/cdx" +} \ No newline at end of file diff --git a/tests/filter_cdx/__init__.py b/tests/filter_cdx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/filter_cdx/test_command.py b/tests/filter_cdx/test_command.py new file mode 100644 index 0000000..d854b5d --- /dev/null +++ b/tests/filter_cdx/test_command.py @@ -0,0 +1,126 @@ +import pytest + + +from cdx_toolkit.cli import main +from tests.conftest import TEST_DATA_PATH, requires_aws_s3 + +fixture_path = TEST_DATA_PATH / 'filter_cdx' + + +@requires_aws_s3 +def test_cli_filter_cdx_from_s3_with_parallel_processing(tmpdir, caplog): + """Test that parallel processing works correctly and processes multiple files.""" + index_path = 's3://commoncrawl/cc-index/collections' + index_glob = '/CC-MAIN-2024-30/indexes/cdx-0018[78].gz' # Multiple files pattern + whitelist_path = fixture_path / 'whitelist_11_surts.txt' # Additonal entry for cdx-00188.gz + + # Run with parallel processing (2 workers) + main( + args=[ + '-v', + '--limit=10', + 'filter_cdx', + f'{index_path}', + f'{str(whitelist_path)}', + f'{tmpdir}', + '--filter-type=surt', + f'--input-glob={index_glob}', + '--parallel=2', + ] + ) + + # Check that multiple files were processed in parallel + assert 'Found' in caplog.text and 'files matching pattern' in caplog.text + assert 'File statistics' in caplog.text + assert 'Filter statistics' in caplog.text + + # Should have processed multiple files (pattern matches 2 files: cdx-00187.gz and cdx-00188.gz) + file_stats_count = caplog.text.count('File statistics') + assert file_stats_count == 2, 'Should process exactly 2 files with the glob pattern' + + +def test_filter_cdx_nonexistent_surt_file_exits(tmpdir, caplog): + index_path = TEST_DATA_PATH / 
'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + nonexistent_surt_file_name = 'nonexistent_surts.txt' + nonexistent_surt_file = str(tmpdir / nonexistent_surt_file_name) + + # Test that the command exits when SURT file doesn't exist + with pytest.raises(SystemExit) as exc_info: + main( + args=[ + '-v', + '--limit=1140', + 'filter_cdx', + f'{str(index_path)}', + f'{nonexistent_surt_file}', + f'{tmpdir}', + '--overwrite', + ] + ) + + assert exc_info.value.code == 1 + assert 'Filter file not found: ' in caplog.text + assert nonexistent_surt_file_name in caplog.text + + +def test_cli_filter_cdx_with_wildcard_urls(tmpdir, caplog): + # check if expected number is reached + index_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + whitelist_path = fixture_path / 'whitelist_wildcard_urls.txt' # matches on all .com and .fr host names + + main( + args=[ + '-v', + '--limit=10', + 'filter_cdx', + f'{index_path}', + f'{str(whitelist_path)}', + f'{tmpdir}', + '--filter-type=url', + '--overwrite', + ] + ) + + assert 'Limit reached' in caplog.text + + +def test_cli_filter_cdx_with_urls(tmpdir, caplog): + # check if expected number is reached + index_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + whitelist_path = fixture_path / 'whitelist_10_urls.txt' # matches on first domain and after 100k and 200k lines + + main( + args=[ + '-v', + '--limit=1140', + 'filter_cdx', + f'{index_path}', + f'{str(whitelist_path)}', + f'{tmpdir}', + '--filter-type=url', + '--overwrite', + ] + ) + + assert 'Limit reached' in caplog.text + + +def test_cli_filter_cdx_with_surts(tmpdir, caplog): + # check if expected number is reached + index_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + whitelist_path = fixture_path / 'whitelist_10_surts.txt' # matches on first domain and after 100k and 200k lines + + main( + args=[ + '-v', + '--limit=1140', + 'filter_cdx', + f'{index_path}', + f'{str(whitelist_path)}', + f'{tmpdir}', + '--filter-type=surt', + '--overwrite', + ] + ) + + assert 'Limit reached' in caplog.text diff --git a/tests/filter_cdx/test_filter_cdx.py b/tests/filter_cdx/test_filter_cdx.py new file mode 100644 index 0000000..c2c4ed2 --- /dev/null +++ b/tests/filter_cdx/test_filter_cdx.py @@ -0,0 +1,168 @@ +import multiprocessing +import signal +import time + +from unittest.mock import patch, MagicMock + + +from url_is_in import SURTMatcher + +from cdx_toolkit.filter_cdx.cdx_filter import _filter_single_cdx_file, _filter_single_cdx_file_args, filter_cdx +from tests.conftest import TEST_DATA_PATH + +fixture_path = TEST_DATA_PATH / 'filter_cdx' + + +def test_filter_single_file(tmpdir): + input_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + matcher = SURTMatcher(['fr,']) + args = dict( + input_path=input_path, + output_path=tmpdir + '/filter_cdx', + matcher=matcher, + log_every_n=10, + limit=100, + ) + _, _, lines_n, included_n, errors_n = _filter_single_cdx_file_args(args) + + assert included_n == 100 + assert lines_n == 100 + assert errors_n == 0 + + +def test_filter_single_file_empty(tmpdir): + input_path = tmpdir + '/input' + with open(input_path, 'w') as f: + f.write('') + + _, _, lines_n, included_n, errors_n = _filter_single_cdx_file( + input_path=input_path, + output_path=tmpdir + '/output', + matcher=None, + ) + assert lines_n == 0 + assert included_n == 0 + assert errors_n == 0 + + +def test_filter_single_cdx_file_input_not_found(tmpdir): + + _, _, lines_n, included_n, errors_n = 
_filter_single_cdx_file( + input_path=tmpdir + "/input-not-found", + output_path=tmpdir + '/output', + matcher=SURTMatcher([]), + ) + assert lines_n == 0 + assert included_n == 0 + assert errors_n == 1, 'Invalid error count' + + +def test_filter_single_cdx_file_with_matcher_error(tmpdir): + class MockMatcher(SURTMatcher): + def is_in(self, surt): + raise ValueError() + + mock_matcher = MockMatcher([]) + input_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + + _, _, lines_n, included_n, errors_n = _filter_single_cdx_file( + input_path=input_path, + output_path=tmpdir + '/output', + matcher=mock_matcher, + ) + assert lines_n == 1140 + assert included_n == 0 + assert errors_n == 1140, 'Invalid error count' + + +def test_filter_cdx_error_handling(tmpdir, caplog): + """Test filter_cdx function error handling when exceptions occur during processing.""" + import multiprocessing + + # Store original start method to restore later + original_start_method = multiprocessing.get_start_method() + + try: + # Force fork method for consistent behavior across platforms + multiprocessing.set_start_method('fork', force=True) + + def mock_filter_single_file(*args, **kwargs): + raise ValueError() + + # Create test input and output paths + input_paths = [str(tmpdir / 'input1.cdx'), str(tmpdir / 'input2.cdx')] + output_paths = [str(tmpdir / 'output1.cdx'), str(tmpdir / 'output2.cdx')] + + # Replace the _process_single_file function with our mock + with patch('cdx_toolkit.filter_cdx.cdx_filter._filter_single_cdx_file', side_effect=mock_filter_single_file): + # Test the error handling + total_lines, total_included, total_errors = filter_cdx( + matcher=None, + input_paths=input_paths, + output_paths=output_paths, + ) + + # Verify error handling results + assert total_errors == 1, f'Should have 1 error from the first failed file, got {total_errors}' + assert total_lines == 0, 'Should have lines from the successful file' + assert total_included == 0, 'Should have included lines from the successful file' + + # Check that error was logged correctly + assert 'Error during parallel processing' in caplog.text + finally: + # Restore original start method + multiprocessing.set_start_method(original_start_method, force=True) + + +def test_filter_cdx_keyboard_interrupt_handling(tmpdir, caplog): + """Test that filter_cdx properly handles KeyboardInterrupt and terminates the pool.""" + # Store original start method to restore later + original_start_method = multiprocessing.get_start_method() + + try: + # Force fork method for consistent behavior across platforms + multiprocessing.set_start_method('fork', force=True) + + def slow_filter_single_file(*args, **kwargs): + """Mock function that simulates a slow process that can be interrupted.""" + time.sleep(1) # Simulate slow processing + return args[0], args[1], 10, 5, 0 # Return some dummy stats + + # Create test input and output paths + input_paths = [str(tmpdir / 'input1.cdx'), str(tmpdir / 'input2.cdx')] + output_paths = [str(tmpdir / 'output1.cdx'), str(tmpdir / 'output2.cdx')] + + # Set caplog to capture INFO level messages + caplog.set_level('INFO') + + # Mock the Pool class to allow us to verify terminate() and join() are called + with patch('cdx_toolkit.filter_cdx.cdx_filter.Pool') as mock_pool_class: + mock_pool = MagicMock() + mock_pool_class.return_value = mock_pool + + # Make imap raise KeyboardInterrupt after a short delay + def interrupt_after_delay(*args, **kwargs): + time.sleep(0.1) # Brief delay before interrupt + raise 
KeyboardInterrupt() + + mock_pool.imap.side_effect = interrupt_after_delay + + # Test the keyboard interrupt handling + total_lines, total_included, total_errors = filter_cdx( + matcher=None, input_paths=input_paths, output_paths=output_paths, n_parallel=2 + ) + + # Verify that pool.terminate() and pool.join() were called + mock_pool.terminate.assert_called_once() + mock_pool.join.assert_called() + + # Verify that the interrupt was logged + assert 'Process interrupted by user (Ctrl+C). Terminating running tasks...' in caplog.text + assert 'All tasks terminated.' in caplog.text + + # Verify pool cleanup in finally block + mock_pool.close.assert_called_once() + + finally: + # Restore original start method + multiprocessing.set_start_method(original_start_method, force=True) diff --git a/tests/filter_cdx/test_path_utils.py b/tests/filter_cdx/test_path_utils.py new file mode 100644 index 0000000..f1fddc7 --- /dev/null +++ b/tests/filter_cdx/test_path_utils.py @@ -0,0 +1,84 @@ +import tempfile +from cdx_toolkit.filter_cdx.path_utils import resolve_paths, validate_resolved_paths +from tests.conftest import TEST_S3_BUCKET, requires_aws_s3 + +import pytest +import os + + +def test_resolve_s3_paths_without_glob(): + input_files, output_files = resolve_paths( + input_base_path="s3://commoncraw/cc-index/collections/CC-MAIN-2016-30/indexes/cdx-00001.gz", + input_glob=None, + output_base_path=f"s3://{TEST_S3_BUCKET}/output", + ) + assert len(input_files) == 1 + assert len(output_files) == len(input_files) + + +def test_validate_resolved_paths_with_makedirs(): + with tempfile.TemporaryDirectory() as tmpdir: + validate_resolved_paths( + output_paths=[ + os.path.join(tmpdir, "1"), + os.path.join(tmpdir, "2"), + ], + overwrite=False + ) + + +@requires_aws_s3 +def test_resolve_cdx_paths_from_cc_s3_to_local(tmpdir): + tmpdir = str(tmpdir) + base_path = 's3://commoncrawl/cc-index/collections' + glob_pattern = '/CC-MAIN-2016-30/indexes/*.gz' + + input_files, output_files = resolve_paths(base_path, glob_pattern, output_base_path=tmpdir) + + assert len(input_files) == len(output_files), 'Input and output count must be the same' + assert len(input_files) == 300, 'Invalid input count' + assert input_files[0] == base_path + '/CC-MAIN-2016-30/indexes/cdx-00000.gz', 'Invalid input file' + assert output_files[0] == tmpdir + '/CC-MAIN-2016-30/indexes/cdx-00000.gz'.replace('/', os.sep), ( + 'Invalid output file' + ) + assert input_files[-1] == base_path + '/CC-MAIN-2016-30/indexes/cdx-00299.gz' + + +@requires_aws_s3 +def test_resolve_cdx_paths_from_cc_s3_to_another_s3(): + output_base_path = 's3://some-other-bucket/filter-cdx' + base_path = 's3://commoncrawl/cc-index/collections' + glob_pattern = '/CC-MAIN-2016-30/indexes/cdx-000*.gz' + + input_files, output_files = resolve_paths(base_path, glob_pattern, output_base_path=output_base_path) + + assert len(input_files) == len(output_files), 'Input and output count must be the same' + assert len(input_files) == 100, 'Invalid input count' + assert input_files[0] == base_path + '/CC-MAIN-2016-30/indexes/cdx-00000.gz', 'Invalid input file' + assert output_files[0] == output_base_path + '/CC-MAIN-2016-30/indexes/cdx-00000.gz', 'Invalid output file' + assert input_files[-1] == base_path + '/CC-MAIN-2016-30/indexes/cdx-00099.gz' + + +def test_resolve_paths_no_files_found_exits(tmpdir, caplog): + # Test that resolve_paths exits when no files match the glob pattern + with pytest.raises(SystemExit) as exc_info: + resolve_paths(input_base_path=str(tmpdir), 
input_glob='/nonexistent-pattern-*.gz', output_base_path=str(tmpdir)) + + assert exc_info.value.code == 1 + assert 'No files found matching glob pattern:' in caplog.text + + +def test_validate_resolved_paths_existing_file_exits(tmpdir, caplog): + # Create an existing output file + existing_file = tmpdir / 'existing_output.txt' + existing_file.write_text('existing content', encoding='utf-8') + + output_paths = [str(existing_file)] + + # Test that validate_resolved_paths exits when output file exists and overwrite=False + with pytest.raises(SystemExit) as exc_info: + validate_resolved_paths(output_paths, overwrite=False) + + assert exc_info.value.code == 1 + assert f'Output file already exists: {str(existing_file)}' in caplog.text + assert 'Use --overwrite to overwrite existing files' in caplog.text \ No newline at end of file diff --git a/tests/filter_warc/__init__.py b/tests/filter_warc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/filter_warc/test_athena_job_generator.py b/tests/filter_warc/test_athena_job_generator.py new file mode 100644 index 0000000..1ddb32c --- /dev/null +++ b/tests/filter_warc/test_athena_job_generator.py @@ -0,0 +1,69 @@ +import asyncio +from cdx_toolkit.filter_warc.warc_filter import _STOP +from cdx_toolkit.filter_warc.athena_job_generator import get_databases, get_range_jobs_from_athena +from tests.conftest import requires_aws_athena + +import boto3 + + +@requires_aws_athena +def test_get_databases(): + from botocore.config import Config + import boto3 + + boto_cfg = Config( + region_name='us-east-1', + ) + athena_client = boto3.client('athena', config=boto_cfg) + dbs = get_databases(client=athena_client) + assert 'ccindex' in dbs + + +@requires_aws_athena +def test_get_range_jobs_from_athena(): + async def run_test(): + # Setup test data + warc_download_prefix = 's3://commoncrawl' + + # Create asyncio queues + key_queue = asyncio.Queue() + + # Setup S3 client + from botocore.config import Config + + boto_cfg = Config( + region_name='us-east-1', + retries={'max_attempts': 3, 'mode': 'standard'}, + connect_timeout=10, + read_timeout=120, + ) + + athena_client = boto3.client('athena', config=boto_cfg) + + # Generate range jobs from Athena query + await get_range_jobs_from_athena( + client=athena_client, + database="ccindex", + s3_output_location="s3://commoncrawl-ci-temp/athena-results/", + url_host_names=[ + 'oceancolor.sci.gsfc.nasa.gov', + ], + job_queue=key_queue, + warc_download_prefix=warc_download_prefix, + num_fetchers=1, + limit=10, # Use 10 records to ensure we have enough data + queue_stop_object=_STOP, + ) + + # Collect all range jobs + range_jobs = [] + while not key_queue.empty(): + job = await key_queue.get() + if job is not _STOP: + range_jobs.append(job) + key_queue.task_done() + + assert len(range_jobs) == 10, "Invalid range jobs count" + + # Run the async test + asyncio.run(run_test()) diff --git a/tests/filter_warc/test_cdx_utils.py b/tests/filter_warc/test_cdx_utils.py new file mode 100644 index 0000000..378bf85 --- /dev/null +++ b/tests/filter_warc/test_cdx_utils.py @@ -0,0 +1,75 @@ +import fsspec +import pytest +from cdx_toolkit.filter_warc.cdx_utils import get_index_as_string_from_path, read_cdx_line, iter_cdx_index_from_path +from tests.conftest import TEST_DATA_PATH + +import tempfile +import gzip +import os +from unittest.mock import patch + + +def test_get_index_as_string_from_path(): + cdx_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + + index = 
get_index_as_string_from_path(cdx_path) + + assert len(index) == 568010 + + +def test_get_index_as_string_from_path_with_fs(): + fs, cdx_path = fsspec.url_to_fs(TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz') + + index = get_index_as_string_from_path(cdx_path, fs) + + assert len(index) == 568010 + + +def test_read_cdx_line_error(): + with pytest.raises(ValueError) as ec_info: + read_cdx_line('this_is_a_bad_CDX-line', warc_download_prefix='http://') + + assert ec_info.match('Cannot parse line') + + +def test_iter_cdx_index_from_path_with_error(): + """Test iter_cdx_index_from_path error handling when read_cdx_line raises exception.""" + + # Create a temporary CDX file with mixed valid and invalid lines + test_cdx_content = """ +org,example)/ 20240101120000 {"url": "http://example.org/", "filename": "test.warc.gz", "offset": "100", "length": "500"} +invalid_line_here_that_will_cause_error +org,test)/ 20240102130000 {"url": "http://test.org/", "filename": "test2.warc.gz", "offset": "600", "length": "300"} +another_bad_line +org,valid)/ 20240103140000 {"url": "http://valid.org/", "filename": "test3.warc.gz", "offset": "900", "length": "200"} +""".strip() + + fd, tmp_file_path = tempfile.mkstemp(suffix='.cdx.gz') + try: + os.close(fd) # Close the file descriptor + + # Write gzipped CDX content + with gzip.open(tmp_file_path, 'wt') as f: + f.write(test_cdx_content) + + # Mock read_cdx_line to raise exception for invalid lines + original_read_cdx_line = read_cdx_line + + def mock_read_cdx_line(line, warc_download_prefix): + if 'invalid' in line or 'bad' in line: + raise ValueError(f'Mock error for line: {line}') + return original_read_cdx_line(line, warc_download_prefix) + + with patch('cdx_toolkit.filter_warc.cdx_utils.read_cdx_line', side_effect=mock_read_cdx_line): + # Collect results from iterator + results = list(iter_cdx_index_from_path(tmp_file_path, 'http://warc-prefix')) + + # Should have 3 valid results despite 2 invalid lines being skipped + assert len(results) == 3 + + # Verify the valid results + assert results[0] == ('http://warc-prefix/test.warc.gz', 100, 500) + assert results[1] == ('http://warc-prefix/test2.warc.gz', 600, 300) + assert results[2] == ('http://warc-prefix/test3.warc.gz', 900, 200) + finally: + os.unlink(tmp_file_path) diff --git a/tests/filter_warc/test_command.py b/tests/filter_warc/test_command.py new file mode 100644 index 0000000..10e3b8b --- /dev/null +++ b/tests/filter_warc/test_command.py @@ -0,0 +1,346 @@ +import os +from typing import List, Optional + +import fsspec +from cdx_toolkit.cli import main + +import pytest +from warcio.archiveiterator import ArchiveIterator + +from tests.conftest import requires_aws_s3, TEST_DATA_PATH + + +fixture_path = TEST_DATA_PATH / 'warc_by_cdx' + + +def assert_cli_warc_by_cdx( + warc_download_prefix, + base_prefix, + caplog, + extra_args: Optional[List[str]] = None, + # warc_filename: str = 'TEST_warc_by_index-000000.extracted.warc.gz', + warc_filename: str = 'TEST_warc_by_index-000000-001.extracted.warc.gz', # due to parallel writer +): + # test cli and check output + index_path = fixture_path / 'filtered_CC-MAIN-2024-30_cdx-00187.gz' + resource_record_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt' + + base_prefix = str(base_prefix) + + if extra_args is None: + extra_args = [] + + main( + args=[ + '-v', + '--limit=10', + 'warc_by_cdx', + f'--cdx-path={str(index_path)}', + '--write-paths-as-resource-records', + str(resource_record_path), + f'--prefix={base_prefix}/TEST_warc_by_index', + 
'--creator=foo', + '--operator=bob', + f'--warc-download-prefix={warc_download_prefix}', + ] + + extra_args + ) + + # Check log + assert 'Limit reached' in caplog.text + + # Validate extracted WARC + if 's3:' in base_prefix: + warc_path = base_prefix + '/' + warc_filename + else: + warc_path = os.path.join(base_prefix, warc_filename) + + info_record = None + response_records = [] + response_contents = [] + + resource_record = None + resource_record_content = None + + with fsspec.open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'warcinfo': + info_record = record.content_stream().read().decode('utf-8') + + if record.rec_type == 'response': + response_records.append(record) + response_contents.append(record.content_stream().read().decode('utf-8', errors='ignore')) + + if record.rec_type == 'resource': + resource_record = record + resource_record_content = record.content_stream().read().decode('utf-8') + + assert len(response_records) == 10, 'Invalid record count' + + assert info_record is not None, 'Invalid info record' + assert 'operator: bob' in info_record, 'Invalid info record' + + assert 'Catalogue en ligne Mission de France' in response_contents[0], 'Invalid response content' + assert 'dojo/dijit/themes/tundra/tundra' in response_contents[9], 'Invalid response content' + + assert resource_record is not None, 'Resource record not set' + + assert resource_record_content[:10] == 'example.co', 'Invalid resource record' + + # Disabled due to OS-specific line endings + # assert resource_record_content[-20:-1] == 'hr.fr/produit/t-837', 'Invalid resource record' + + # Calculate expected length based on the actual source file on current OS + with open(resource_record_path, 'rb') as f: + expected_length = len(f.read()) + + assert resource_record.length == expected_length, ( + f'Invalid resource record length {resource_record.length}, expected {expected_length} ' + f'(computed from {resource_record_path} on current OS)' + ) + + +def test_cli_warc_by_cdx_over_http(tmpdir, caplog): + assert_cli_warc_by_cdx('https://data.commoncrawl.org', base_prefix=tmpdir, caplog=caplog) + + +def test_cli_warc_by_cdx_over_http_in_parallel(tmpdir, caplog): + assert_cli_warc_by_cdx( + 'https://data.commoncrawl.org', base_prefix=tmpdir, caplog=caplog, extra_args=['--parallel=3'] + ) + + +@requires_aws_s3 +def test_cli_warc_by_cdx_over_s3(tmpdir, caplog): + assert_cli_warc_by_cdx('s3://commoncrawl', base_prefix=tmpdir, caplog=caplog) + + +@requires_aws_s3 +def test_cli_warc_by_cdx_over_s3_to_s3(s3_tmpdir, caplog): + assert_cli_warc_by_cdx( + 's3://commoncrawl', + base_prefix=s3_tmpdir, + caplog=caplog, + ) + + +@requires_aws_s3 +def test_cli_warc_by_cdx_over_s3_to_s3_in_parallel(s3_tmpdir, caplog): + assert_cli_warc_by_cdx( + 's3://commoncrawl', + base_prefix=s3_tmpdir, + caplog=caplog, + extra_args=['--parallel=3', '--is-part-of=foobar'], + ) + + +@requires_aws_s3 +def test_cli_warc_by_cdx_over_s3_to_s3_warc_filter(s3_tmpdir, caplog): + assert_cli_warc_by_cdx( + 's3://commoncrawl', + base_prefix=s3_tmpdir, + caplog=caplog, + ) + + +@requires_aws_s3 +def test_cli_warc_by_cdx_over_http_to_s3_in_parallel(s3_tmpdir, caplog): + assert_cli_warc_by_cdx( + 'https://data.commoncrawl.org', + base_prefix=s3_tmpdir, + caplog=caplog, + extra_args=[ + '--parallel=3', + ], + ) + + +@requires_aws_s3 +def test_cli_warc_by_cdx_over_s3_to_local_in_parallel(tmpdir, caplog): + assert_cli_warc_by_cdx( + 's3://commoncrawl', + base_prefix=tmpdir, + caplog=caplog, + extra_args=[ + 
'--parallel=3', + ], + ) + + +def test_warc_by_cdx_no_index_files_found_exits(tmpdir, caplog): + # Test that warc_by_cdx exits when no index files match the glob pattern + with pytest.raises(SystemExit) as exc_info: + main( + args=[ + '-v', + 'warc_by_cdx', + f'--cdx-path={str(tmpdir)}', + f'--prefix={str(tmpdir)}/TEST', + '--cdx-glob=/nonexistent-pattern-*.gz', + ] + ) + + assert exc_info.value.code == 1 + assert 'no index files found' in caplog.text + + +def test_warc_by_cdx_subprefix_and_metadata(tmpdir): + # Test subprefix functionality and creator/operator metadata + index_path = fixture_path / 'filtered_CC-MAIN-2024-30_cdx-00187.gz' + + main( + args=[ + '-v', + '--limit=1', + 'warc_by_cdx', + f'--cdx-path={str(index_path)}', + f'--prefix={str(tmpdir)}/TEST', + '--subprefix=SUB', + '--creator=test_creator', + '--operator=test_operator', + ] + ) + + # Check that WARC file was created with subprefix + warc_path = os.path.join(tmpdir, 'TEST-SUB-000000-001.extracted.warc.gz') + assert os.path.exists(warc_path) + + # Validate metadata in warcinfo record + info_record = None + with open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'warcinfo': + info_record = record.content_stream().read().decode('utf-8') + break + + assert info_record is not None + assert 'creator: test_creator' in info_record + assert 'operator: test_operator' in info_record + + +def test_warc_by_cdx_without_creator_operator(tmpdir): + # Test that creator and operator are optional (lines 44-47) + index_path = fixture_path / 'filtered_CC-MAIN-2024-30_cdx-00187.gz' + + main( + args=[ + '-v', + '--limit=1', + 'warc_by_cdx', + f'--cdx-path={str(index_path)}', + f'--prefix={str(tmpdir)}/TEST_NO_META', + ] + ) + + # Check that WARC file was created + warc_path = os.path.join(tmpdir, 'TEST_NO_META-000000-001.extracted.warc.gz') + assert os.path.exists(warc_path) + + # Validate that creator/operator are not in warcinfo record + info_record = None + with open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'warcinfo': + info_record = record.content_stream().read().decode('utf-8') + break + + assert info_record is not None + assert 'creator:' not in info_record + assert 'operator:' not in info_record + + +def test_resource_records_paths_mismatch(): + # Test if mismatch of number of paths for resource records and their metdata is raised. + with pytest.raises(ValueError) as exc_info: + main( + args=[ + '-v', + 'warc_by_cdx', + '--cdx-path=foo/bar', + '--write-paths-as-resource-records', + 'resource1', + 'resource2', + '--write-paths-as-resource-records-metadata', + 'metadata2', + ] + ) + assert exc_info.match('Number of paths to resource records') + + +def test_metadata_paths_without_resource_records_paths(): + # Test if error of missing resource records paths is raised. 
+ with pytest.raises(ValueError) as exc_info: + main( + args=['-v', 'warc_by_cdx', '--cdx-path=foo/bar', '--write-paths-as-resource-records-metadata', 'metadata2'] + ) + assert exc_info.match('Metadata paths are set but') + + +def test_cli_warc_by_athena( + tmpdir, + caplog, +): + base_prefix = tmpdir + warc_download_prefix = 's3://commoncrawl' + extra_args: Optional[List[str]] = None + warc_filename: str = 'TEST_warc_by_index-000000-001.extracted.warc.gz' # due to parallel writer + base_prefix = str(base_prefix) + + if extra_args is None: + extra_args = [] + + main( + args=[ + '-v', + '--limit=10', + 'warc_by_cdx', + '--target-source=athena', + '--athena-database=ccindex', + '--athena-s3-output=s3://commoncrawl-ci-temp/athena-results/', + '--athena-hostnames', + 'oceancolor.sci.gsfc.nasa.gov', + 'example.com', + f'--prefix={base_prefix}/TEST_warc_by_index', + '--creator=foo', + '--operator=bob', + f'--warc-download-prefix={warc_download_prefix}', + ] + + extra_args + ) + + # Check log + assert 'WARC records extracted: 10' in caplog.text + + # Validate extracted WARC + if 's3:' in base_prefix: + warc_path = base_prefix + '/' + warc_filename + else: + warc_path = os.path.join(base_prefix, warc_filename) + + info_record = None + response_records = [] + response_contents = [] + + # resource_record = None + # resource_record_content = None + + with fsspec.open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'warcinfo': + info_record = record.content_stream().read().decode('utf-8') + + if record.rec_type == 'response': + response_records.append(record) + response_contents.append(record.content_stream().read().decode('utf-8', errors='ignore')) + + # if record.rec_type == 'resource': + # resource_record = record + # resource_record_content = record.content_stream().read().decode('utf-8') + + assert len(response_records) == 10, 'Invalid record count' + + assert info_record is not None, 'Invalid info record' + assert 'operator: bob' in info_record, 'Invalid info record' + + assert '
Example Domain' in response_contents[0], 'Invalid response content' + assert 'Example Domain
' in response_contents[9], 'Invalid response content' diff --git a/tests/filter_warc/test_data_classes.py b/tests/filter_warc/test_data_classes.py new file mode 100644 index 0000000..f86d96d --- /dev/null +++ b/tests/filter_warc/test_data_classes.py @@ -0,0 +1,12 @@ +import pytest +from cdx_toolkit.filter_warc.data_classes import RangeJob + + +def test_get_s3_bucket_and_key_from_http_job(): + job = RangeJob( + url='http://foo.com/example', + offset=0, + length=10, + ) + with pytest.raises(ValueError): + job.get_s3_bucket_and_key() diff --git a/tests/filter_warc/test_grouped_range_jobs.py b/tests/filter_warc/test_grouped_range_jobs.py new file mode 100644 index 0000000..adf9a07 --- /dev/null +++ b/tests/filter_warc/test_grouped_range_jobs.py @@ -0,0 +1,51 @@ +import fsspec +import pytest +from cdx_toolkit.filter_warc.cdx_utils import get_index_as_string_from_path, read_cdx_line, iter_cdx_index_from_path +from tests.conftest import TEST_DATA_PATH + +import tempfile +import gzip +import os +from unittest.mock import patch + + + +def test_iter_cdx_index_from_test_data(): + cdx_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' + results = list(iter_cdx_index_from_path(str(cdx_path), 'http://warc-prefix')) + # [(url, offset, length)] + + # sort results by offsets + results.sort(key=lambda x: x[1]) + + # group into neighbor chunks + def group_neighbor_chunks(items): + """Group items into chunks where items have same URL and are contiguous.""" + if not items: + return [] + + chunks = [] + current_chunk = [items[0]] + + for i in range(1, len(items)): + prev_url, prev_offset, prev_length = items[i-1] + curr_url, curr_offset, curr_length = items[i] + + # Check if current item is a neighbor (same URL and contiguous) + if curr_url == prev_url and curr_offset == prev_offset + prev_length + 4: + current_chunk.append(items[i]) + else: + # Start new chunk + chunks.append(current_chunk) + current_chunk = [items[i]] + + # Add the last chunk + chunks.append(current_chunk) + return chunks + + grouped_chunks = group_neighbor_chunks(results) + print(len(results), len(grouped_chunks)) + + +def test_grouped_ranges(): + cdx_path = "" \ No newline at end of file diff --git a/tests/filter_warc/test_local_writer.py b/tests/filter_warc/test_local_writer.py new file mode 100644 index 0000000..36a88c5 --- /dev/null +++ b/tests/filter_warc/test_local_writer.py @@ -0,0 +1,454 @@ +import asyncio +import pytest +from unittest.mock import patch + +from cdx_toolkit.filter_warc.local_writer import LocalFileWriter + + +def test_init_default_values(): + """Test initialization with default values.""" + writer = LocalFileWriter("/tmp/test.txt") + assert writer.file_path == "/tmp/test.txt" + assert writer.buffer_size == 8192 + assert writer.mode == 'wb' + assert writer.file_handle is None + assert isinstance(writer.buffer, bytearray) + assert len(writer.buffer) == 0 + + +def test_init_custom_values(): + """Test initialization with custom values.""" + writer = LocalFileWriter("/tmp/test.txt", buffer_size=4096, mode='ab') + assert writer.file_path == "/tmp/test.txt" + assert writer.buffer_size == 4096 + assert writer.mode == 'ab' + assert writer.file_handle is None + assert isinstance(writer.buffer, bytearray) + assert len(writer.buffer) == 0 + + +def test_start_opens_file(tmp_path): + """Test that start() opens the file correctly.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + assert writer.file_handle is not None + await 
writer.close() + + asyncio.run(run_test()) + + +def test_start_with_different_modes(tmp_path): + """Test start() with different file modes.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + + # Test binary write mode + writer = LocalFileWriter(str(temp_file), mode='wb') + await writer.start() + assert writer.file_handle is not None + await writer.close() + + # Test binary append mode + writer = LocalFileWriter(str(temp_file), mode='ab') + await writer.start() + assert writer.file_handle is not None + await writer.close() + + asyncio.run(run_test()) + + +def test_start_creates_directory_if_needed(tmp_path): + """Test that start() works when parent directory exists.""" + async def run_test(): + subdir = tmp_path / "subdir" + subdir.mkdir() + temp_file = subdir / "test.txt" + + writer = LocalFileWriter(str(temp_file)) + await writer.start() + assert writer.file_handle is not None + await writer.close() + + asyncio.run(run_test()) + + +def test_write_small_data_buffers(tmp_path): + """Test writing data that doesn't exceed buffer size.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file), buffer_size=100) + await writer.start() + + test_data = b"Hello, World!" + await writer.write(test_data) + + # Data should be in buffer, not yet written to file + assert len(writer.buffer) == len(test_data) + assert writer.buffer == test_data + + await writer.close() + + # After close, data should be written to file + assert temp_file.read_bytes() == test_data + + asyncio.run(run_test()) + + +def test_write_large_data_triggers_flush(tmp_path): + """Test writing data that exceeds buffer size triggers flush.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + buffer_size = 50 + writer = LocalFileWriter(str(temp_file), buffer_size=buffer_size) + await writer.start() + + # Write data larger than buffer size + test_data = b"x" * (buffer_size + 10) + await writer.write(test_data) + + # Buffer should be empty after automatic flush + assert len(writer.buffer) == 0 + + await writer.close() + + # Data should be written to file + assert temp_file.read_bytes() == test_data + + asyncio.run(run_test()) + + +def test_write_multiple_small_chunks(tmp_path): + """Test writing multiple small chunks that eventually trigger flush.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + buffer_size = 50 + writer = LocalFileWriter(str(temp_file), buffer_size=buffer_size) + await writer.start() + + chunk1 = b"a" * 30 + chunk2 = b"b" * 25 # Total: 55 bytes, exceeds buffer + + await writer.write(chunk1) + assert len(writer.buffer) == 30 + + await writer.write(chunk2) + # Should have triggered flush, buffer should be empty + assert len(writer.buffer) == 0 + + await writer.close() + + assert temp_file.read_bytes() == chunk1 + chunk2 + + asyncio.run(run_test()) + + +def test_write_empty_data(tmp_path): + """Test writing empty data.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + await writer.write(b"") + assert len(writer.buffer) == 0 + + await writer.close() + + assert temp_file.read_bytes() == b"" + + asyncio.run(run_test()) + + +def test_write_without_start_graceful_handling(tmp_path): + """Test that writing without calling start() is handled gracefully.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file), buffer_size=10) # Small buffer to force flush + + # This should work fine as long as we don't exceed buffer 
size + await writer.write(b"small") + assert len(writer.buffer) == 5 + + # When buffer exceeds size, flush is called but does nothing since file_handle is None + # The data stays in buffer instead of being written + await writer.write(b"data that exceeds buffer size") + + # Buffer should contain all the data since flush did nothing + expected_data = b"small" + b"data that exceeds buffer size" + assert writer.buffer == expected_data + + asyncio.run(run_test()) + + +def test_flush_empty_buffer(tmp_path): + """Test flushing when buffer is empty.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + # Flush empty buffer should not raise error + await writer._flush() + assert len(writer.buffer) == 0 + + await writer.close() + + asyncio.run(run_test()) + + +def test_flush_without_file_handle(tmp_path): + """Test flushing without file handle.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + writer.buffer.extend(b"test data") + + # Should not raise error, just do nothing + await writer._flush() + assert len(writer.buffer) == len(b"test data") # Buffer unchanged + + asyncio.run(run_test()) + + +def test_close_flushes_remaining_data(tmp_path): + """Test that close() flushes any remaining buffered data.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file), buffer_size=100) + await writer.start() + + test_data = b"This data should be flushed on close" + await writer.write(test_data) + + # Data should still be in buffer + assert len(writer.buffer) == len(test_data) + + await writer.close() + + # Data should now be written to file + assert temp_file.read_bytes() == test_data + + asyncio.run(run_test()) + + +def test_close_without_start(tmp_path): + """Test closing without calling start().""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + + # Should not raise error + await writer.close() + + asyncio.run(run_test()) + + +def test_close_twice(tmp_path): + """Test calling close() multiple times.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + await writer.close() + + # Second close should not raise error + await writer.close() + + asyncio.run(run_test()) + + +def test_close_handles_flush_exception(tmp_path): + """Test that close() handles exceptions during flush properly.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + # Add some data to buffer + await writer.write(b"test data") + + # Mock flush to raise an exception + with patch.object(writer, '_flush', side_effect=Exception("Flush error")): + with pytest.raises(Exception, match="Flush error"): + await writer.close() + + asyncio.run(run_test()) + + +def test_close_handles_file_close_exception(tmp_path): + """Test that close() handles exceptions during file close.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + # Mock file handle close to raise an exception + with patch.object(writer.file_handle, 'close', side_effect=Exception("Close error")): + with pytest.raises(Exception, match="Close error"): + await writer.close() + + asyncio.run(run_test()) + + +def test_large_file_write(tmp_path): + """Test writing a large amount of data.""" + async def run_test(): + 
temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file), buffer_size=1024) + await writer.start() + + # Write 1MB of data in chunks + chunk_size = 1024 # Make chunk size same as buffer for exact division + total_size = 1024 * 1024 # 1MB + chunk_data = b"x" * chunk_size + + for _ in range(total_size // chunk_size): + await writer.write(chunk_data) + + await writer.close() + + # Verify file size + assert temp_file.stat().st_size == total_size + + asyncio.run(run_test()) + + +def test_binary_data_integrity(tmp_path): + """Test that binary data is written correctly.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + await writer.start() + + # Create binary data with all byte values + binary_data = bytes(range(256)) + await writer.write(binary_data) + + await writer.close() + + assert temp_file.read_bytes() == binary_data + + asyncio.run(run_test()) + + +def test_concurrent_writes(tmp_path): + """Test concurrent write operations.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file), buffer_size=100) + await writer.start() + + # Create multiple write tasks + async def write_chunk(data): + await writer.write(data) + + tasks = [ + write_chunk(f"chunk{i}".encode() * 10) + for i in range(10) + ] + + await asyncio.gather(*tasks) + await writer.close() + + # Verify file exists and has data + assert temp_file.exists() + assert temp_file.stat().st_size > 0 + + asyncio.run(run_test()) + + +def test_file_permissions_error(tmp_path): + """Test handling of file permission errors.""" + async def run_test(): + # Create a file path in a directory we can't write to + readonly_file = tmp_path / "readonly.txt" + + # Create the file first + readonly_file.write_text("test") + + # Make the file read-only + readonly_file.chmod(0o444) + + writer = LocalFileWriter(str(readonly_file), mode='wb') + + with pytest.raises(PermissionError): + await writer.start() + + asyncio.run(run_test()) + + +def test_nonexistent_directory(): + """Test writing to a file in a nonexistent directory.""" + async def run_test(): + nonexistent_path = "/nonexistent/directory/file.txt" + writer = LocalFileWriter(nonexistent_path) + + with pytest.raises(FileNotFoundError): + await writer.start() + + asyncio.run(run_test()) + + +def test_context_manager_like_usage(tmp_path): + """Test typical usage pattern similar to context manager.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + writer = LocalFileWriter(str(temp_file)) + + try: + await writer.start() + await writer.write(b"Hello, World!") + await writer.write(b" How are you?") + finally: + await writer.close() + + assert temp_file.read_bytes() == b"Hello, World! How are you?" 
+ + asyncio.run(run_test()) + + +def test_buffer_size_edge_cases(tmp_path): + """Test edge cases with different buffer sizes.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + # Test with buffer size of 1 + writer = LocalFileWriter(str(temp_file), buffer_size=1) + await writer.start() + + await writer.write(b"a") # Should trigger flush immediately + assert len(writer.buffer) == 0 + + await writer.write(b"bc") # Should trigger flush after 'b', leaving 'c' + assert len(writer.buffer) == 0 + + await writer.close() + + assert temp_file.read_bytes() == b"abc" + + asyncio.run(run_test()) + + +def test_append_mode(tmp_path): + """Test append mode functionality.""" + async def run_test(): + temp_file = tmp_path / "test.txt" + + # First, write some initial data + temp_file.write_bytes(b"Initial data\n") + + # Now append using LocalFileWriter + writer = LocalFileWriter(str(temp_file), mode='ab') + await writer.start() + + await writer.write(b"Appended data\n") + await writer.close() + + # Verify both pieces of data are present + content = temp_file.read_bytes() + assert content == b"Initial data\nAppended data\n" + + asyncio.run(run_test()) \ No newline at end of file diff --git a/tests/filter_warc/test_s3_utils.py b/tests/filter_warc/test_s3_utils.py new file mode 100644 index 0000000..54c4188 --- /dev/null +++ b/tests/filter_warc/test_s3_utils.py @@ -0,0 +1,177 @@ +import pytest + +import asyncio + + +from cdx_toolkit.filter_warc.s3_utils import ( + _backoff, + parse_s3_uri, + with_retries, +) +from botocore.exceptions import EndpointConnectionError +from botocore.exceptions import ClientError + + +def test_backoff(): + """Test _backoff function with exponential backoff and jitter.""" + base_backoff = 1.0 + tol_float = 0.001 # tolerance for float errors + + # Test attempt 1: should be between 0.8 and 1.2 seconds (with jitter) + result1 = _backoff(1, base_backoff) + assert 0.8 <= result1 <= 1.2 + tol_float + + # Test attempt 2: should be between 1.6 and 2.41 seconds (2^1 * base * jitter) + result2 = _backoff(2, base_backoff) + assert 1.6 <= result2 <= 2.4 + tol_float + + # Test attempt 3: should be between 3.2 and 4.8 seconds (2^2 * base * jitter) + result3 = _backoff(3, base_backoff) + assert 3.2 <= result3 <= 4.8 + tol_float + + # Test with different base backoff + base_backoff_small = 0.1 + result_small = _backoff(1, base_backoff_small) + assert 0.08 <= result_small <= 0.12 + tol_float + + # Test minimum backoff (should never be less than 0.05) + very_small_base = 0.001 + result_min = _backoff(1, very_small_base) + assert result_min >= 0.05 + + # Test that backoff increases with attempts + results = [_backoff(i, 0.5) for i in range(1, 6)] + # Generally should increase, though jitter might cause small variations + # Check that the trend is generally increasing + assert results[1] > results[0] * 0.8 # Allow for jitter variation + assert results[2] > results[1] * 0.8 + assert results[3] > results[2] * 0.8 + + +def test_parse_s3_uri(): + """Test parse_s3_uri function for valid and invalid S3 URIs.""" + + # Test valid S3 URIs + bucket, prefix = parse_s3_uri('s3://my-bucket/path/to/file.txt') + assert bucket == 'my-bucket' + assert prefix == 'path/to/file.txt' + + bucket, prefix = parse_s3_uri('s3://test-bucket/folder/subfolder/data.json') + assert bucket == 'test-bucket' + assert prefix == 'folder/subfolder/data.json' + + bucket, prefix = parse_s3_uri('s3://simple/file') + assert bucket == 'simple' + assert prefix == 'file' + + # Test with deep nested paths + bucket, prefix = 
parse_s3_uri('s3://bucket/a/b/c/d/e/f/file.ext') + assert bucket == 'bucket' + assert prefix == 'a/b/c/d/e/f/file.ext' + + # Test invalid URIs - should raise ValueError + with pytest.raises(ValueError, match='Not an S3 URI'): + parse_s3_uri('http://example.com/path') + + with pytest.raises(ValueError, match='Not an S3 URI'): + parse_s3_uri('ftp://bucket/file') + + with pytest.raises(ValueError, match='Not an S3 URI'): + parse_s3_uri('bucket/file') + + # Test malformed S3 URIs + with pytest.raises(ValueError, match='Malformed S3 URI'): + parse_s3_uri('s3://') + + with pytest.raises(ValueError, match='Malformed S3 URI'): + parse_s3_uri('s3://bucket') + + with pytest.raises(ValueError, match='Malformed S3 URI'): + parse_s3_uri('s3://bucket/') + + with pytest.raises(ValueError, match='Malformed S3 URI'): + parse_s3_uri('s3:///file') + + +def test_with_retries_success(): + """Test with_retries function with successful operation on first attempt.""" + + async def run_test(): + call_count = 0 + + async def successful_coro(): + nonlocal call_count + call_count += 1 + return 'success' + + result = await with_retries(successful_coro, op_name='test_op', max_attempts=3, base_backoff_seconds=0.1) + + assert result == 'success' + assert call_count == 1 + + asyncio.run(run_test()) + + +def test_with_retries_eventual_success(): + """Test with_retries function that succeeds after initial failures.""" + + async def run_test(): + call_count = 0 + + async def eventually_successful_coro(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ClientError({'Error': {'Code': 'Throttling'}}, 'test_op') + return 'success' + + result = await with_retries( + eventually_successful_coro, + op_name='test_op', + max_attempts=3, + base_backoff_seconds=0.01, # Very short for testing + ) + + assert result == 'success' + assert call_count == 3 + + asyncio.run(run_test()) + + +def test_with_retries_max_attempts_exceeded(): + """Test with_retries function when max attempts are exceeded.""" + + async def run_test(): + call_count = 0 + + async def failing_coro(): + nonlocal call_count + call_count += 1 + raise EndpointConnectionError(endpoint_url='test') + + with pytest.raises(EndpointConnectionError): + await with_retries(failing_coro, op_name='test_op', max_attempts=2, base_backoff_seconds=0.01) + + assert call_count == 2 + + asyncio.run(run_test()) + + +def test_with_retries_non_retryable_exception(): + """Test with_retries function with non-retryable exceptions.""" + + async def run_test(): + call_count = 0 + + async def failing_coro(): + nonlocal call_count + call_count += 1 + raise ValueError('Non-retryable error') + + with pytest.raises(ValueError): + await with_retries(failing_coro, op_name='test_op', max_attempts=3, base_backoff_seconds=0.01) + + # Should fail immediately without retries + assert call_count == 1 + + asyncio.run(run_test()) diff --git a/tests/filter_warc/test_s3_writer.py b/tests/filter_warc/test_s3_writer.py new file mode 100644 index 0000000..c6cb722 --- /dev/null +++ b/tests/filter_warc/test_s3_writer.py @@ -0,0 +1,311 @@ +import pytest + +import asyncio +from unittest.mock import AsyncMock +from unittest.mock import patch + +from cdx_toolkit.filter_warc.s3_writer import mpu_abort + + +from cdx_toolkit.filter_warc.s3_writer import S3ShardWriter + + +def test_shard_writer_start(): + """Test ShardWriter start method.""" + + async def run_test(): + with patch('cdx_toolkit.filter_warc.s3_writer.mpu_create') as mock_mpu_create: + mock_mpu_create.return_value = 'test-upload-id' + + mock_s3 = 
AsyncMock() + + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=1024, + max_attempts=3, + base_backoff_seconds=0.1, + ) + + await writer.start() + + assert writer.upload_id == 'test-upload-id' + mock_mpu_create.assert_called_once_with( + mock_s3, + 'test-bucket', + 'test.warc.gz', + max_attempts=3, + base_backoff_seconds=0.1, + ) + + asyncio.run(run_test()) + + +def test_shard_writer_write_small_data(): + """Test ShardWriter write method with small data that stays in buffer.""" + + async def run_test(): + mock_s3 = AsyncMock() + + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=1024, # 1 KiB + max_attempts=3, + base_backoff_seconds=0.1, + ) + + small_data = b'small test data' + + await writer.write(small_data) + + # Data should be in buffer, no parts uploaded yet + assert len(writer.buffer) == len(small_data) + assert bytes(writer.buffer) == small_data + assert writer.part_number == 1 + assert len(writer.parts) == 0 + + asyncio.run(run_test()) + + +def test_shard_writer_write_large_data(): + """Test ShardWriter write method with large data that triggers part uploads.""" + + async def run_test(): + with patch('cdx_toolkit.filter_warc.s3_writer.mpu_upload_part') as mock_upload_part: + mock_upload_part.return_value = 'test-etag-1' + + mock_s3 = AsyncMock() + + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=100, # 100 bytes + max_attempts=3, + base_backoff_seconds=0.1, + ) + writer.upload_id = 'test-upload-id' + + large_data = b'x' * 250 # 250 bytes, should create 2 parts + + await writer.write(large_data) + + # Should have uploaded 2 parts (100 bytes each) with 50 bytes remaining in buffer + assert mock_upload_part.call_count == 2 + assert len(writer.parts) == 2 + assert writer.part_number == 3 # Next part would be #3 + assert len(writer.buffer) == 50 # Remaining bytes + assert bytes(writer.buffer) == b'x' * 50 + + # Verify parts structure + assert writer.parts[0] == {'PartNumber': 1, 'ETag': 'test-etag-1'} + assert writer.parts[1] == {'PartNumber': 2, 'ETag': 'test-etag-1'} + + asyncio.run(run_test()) + + +def test_shard_writer_flush_full_parts(): + """Test ShardWriter _flush_full_parts private method directly.""" + + async def run_test(): + with patch('cdx_toolkit.filter_warc.s3_writer.mpu_upload_part') as mock_upload_part: + mock_upload_part.return_value = 'test-etag-flush' + + mock_s3 = AsyncMock() + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=50, # 50 bytes + max_attempts=3, + base_backoff_seconds=0.1, + ) + writer.upload_id = 'test-upload-id' + + # Pre-fill buffer with 150 bytes (should create 3 parts of 50 bytes each) + writer.buffer.extend(b'a' * 150) + + await writer._flush_full_parts() + + # Should have uploaded 3 full parts, no remainder + assert mock_upload_part.call_count == 3 + assert len(writer.parts) == 3 + assert writer.part_number == 4 # Next part would be #4 + assert len(writer.buffer) == 0 # All data flushed + + # Verify all parts were created correctly + for i in range(3): + assert writer.parts[i] == {'PartNumber': i + 1, 'ETag': 'test-etag-flush'} + + asyncio.run(run_test()) + + +def test_shard_writer_close_with_buffer(): + """Test ShardWriter close 
method with data remaining in buffer.""" + + async def run_test(): + with patch('cdx_toolkit.filter_warc.s3_writer.mpu_upload_part') as mock_upload_part, patch( + 'cdx_toolkit.filter_warc.s3_writer.mpu_complete' + ) as mock_complete: + mock_upload_part.return_value = 'final-etag' + + mock_s3 = AsyncMock() + + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=1000, # Large min size to keep data in buffer + max_attempts=3, + base_backoff_seconds=0.1, + ) + writer.upload_id = 'test-upload-id' + + # Add some data to buffer + remaining_data = b'final chunk data' + writer.buffer.extend(remaining_data) + + await writer.close() + + # Should upload the final part and complete MPU + mock_upload_part.assert_called_once_with( + mock_s3, + 'test-bucket', + 'test.warc.gz', + 'test-upload-id', + 1, # part number + remaining_data, + 3, # max attempts + 0.1, # base backoff + ) + + mock_complete.assert_called_once_with( + mock_s3, + 'test-bucket', + 'test.warc.gz', + 'test-upload-id', + [{'PartNumber': 1, 'ETag': 'final-etag'}], + 3, # max attempts + 0.1, # base backoff + ) + + # Buffer should be cleared + assert len(writer.buffer) == 0 + assert len(writer.parts) == 1 + + asyncio.run(run_test()) + + +def test_shard_writer_close_empty(): + """Test ShardWriter close method with no data (empty buffer, no parts).""" + + async def run_test(): + with patch('cdx_toolkit.filter_warc.s3_writer.mpu_upload_part') as mock_upload_part, patch( + 'cdx_toolkit.filter_warc.s3_writer.mpu_complete' + ) as mock_complete: + mock_s3 = AsyncMock() + + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=1000, + max_attempts=3, + base_backoff_seconds=0.1, + ) + writer.upload_id = 'test-upload-id' + + # No data in buffer, no parts uploaded + await writer.close() + + # Should not upload any parts or complete MPU since there's no data + mock_upload_part.assert_not_called() + mock_complete.assert_not_called() + + # State should remain unchanged + assert len(writer.buffer) == 0 + assert len(writer.parts) == 0 + + asyncio.run(run_test()) + + +def test_shard_writer_close_with_exception(): + """Test ShardWriter close method with exception and abort handling.""" + + async def run_test(): + with patch('cdx_toolkit.filter_warc.s3_writer.mpu_upload_part') as mock_upload_part, patch( + 'cdx_toolkit.filter_warc.s3_writer.mpu_complete' + ) as mock_complete, patch('cdx_toolkit.filter_warc.s3_writer.mpu_abort') as mock_abort: + mock_upload_part.return_value = 'error-etag' + mock_complete.side_effect = Exception('Complete failed') + + mock_s3 = AsyncMock() + + writer = S3ShardWriter( + s3_client=mock_s3, + shard_key='test.warc.gz', + dest_bucket='test-bucket', + content_type='application/gzip', + min_part_size=1000, + max_attempts=3, + base_backoff_seconds=0.1, + ) + writer.upload_id = 'test-upload-id' + + # Add some data to buffer to trigger upload and complete + writer.buffer.extend(b'some data') + + # Should raise the exception after attempting abort + with pytest.raises(Exception, match='Complete failed'): + await writer.close() + + # Should have attempted to upload part and complete, then abort on failure + mock_upload_part.assert_called_once() + mock_complete.assert_called_once() + mock_abort.assert_called_once_with(mock_s3, 'test-bucket', 'test.warc.gz', 'test-upload-id') + + asyncio.run(run_test()) + + +def test_mpu_abort_success(): + 
"""Test mpu_abort function with successful abort.""" + + async def run_test(): + mock_s3 = AsyncMock() + bucket = 'test-bucket' + key = 'test-key' + upload_id = 'test-upload-id' + + await mpu_abort(mock_s3, bucket, key, upload_id) + + mock_s3.abort_multipart_upload.assert_called_once_with(Bucket=bucket, Key=key, UploadId=upload_id) + + asyncio.run(run_test()) + + +def test_mpu_abort_with_exception(): + """Test mpu_abort function when abort fails (should catch exception).""" + + async def run_test(): + mock_s3 = AsyncMock() + mock_s3.abort_multipart_upload.side_effect = Exception('S3 error') + + bucket = 'test-bucket' + key = 'test-key' + upload_id = 'test-upload-id' + + # Should not raise exception, should log it instead + await mpu_abort(mock_s3, bucket, key, upload_id) + + mock_s3.abort_multipart_upload.assert_called_once_with(Bucket=bucket, Key=key, UploadId=upload_id) + + asyncio.run(run_test()) diff --git a/tests/filter_warc/test_warc_filter.py b/tests/filter_warc/test_warc_filter.py new file mode 100644 index 0000000..937776e --- /dev/null +++ b/tests/filter_warc/test_warc_filter.py @@ -0,0 +1,434 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock, patch +from cdx_toolkit.filter_warc.data_classes import ThroughputTracker +from tests.conftest import TEST_DATA_PATH + +from cdx_toolkit.filter_warc.warc_filter import WARCFilter + +fixture_path = TEST_DATA_PATH / 'warc_by_cdx' + + +def test_filter_keyboard_interrupt_handling(caplog): + """Test that KeyboardInterrupt is properly handled in the filter method.""" + import logging + + # Set log level to capture WARNING messages + caplog.set_level(logging.WARNING, logger='cdx_toolkit.filter_warc.warc_filter') + + warc_filter = WARCFilter(cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}) + + # Mock filter_async to raise KeyboardInterrupt + with patch.object(warc_filter, 'filter_async', side_effect=KeyboardInterrupt('Simulated user interrupt')): + # Call the filter method + result = warc_filter.filter() + + # Should return -1 when interrupted + assert result == -1 + + # Should log the warning message + assert 'Interrupted by user.' 
in caplog.text + + +def test_rotate_files_no_rotation_needed(): + """Test rotate_files when no rotation is needed (file size below limit).""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + max_file_size=1000, # 1KB limit + ) + + mock_writer = AsyncMock() + current_file_sequence = 1 + current_file_size = 500 # 500 bytes + added_byte_size = 200 # Adding 200 bytes, total would be 700 (below limit) + + # Call rotate_files + result_writer, result_sequence, result_size = await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=added_byte_size, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Should return original values since no rotation occurred + assert result_writer == mock_writer + assert result_sequence == current_file_sequence + assert result_size == current_file_size + + # Writer should not be closed + mock_writer.close.assert_not_called() + + asyncio.run(run_test()) + + +def test_rotate_files_rotation_needed_without_resource_records(): + """Test rotate_files when rotation is needed and no resource records to write.""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + max_file_size=1000, # 1KB limit + write_paths_as_resource_records=None, # No resource records + ) + + mock_writer = AsyncMock() + mock_new_writer = AsyncMock() + current_file_sequence = 1 + current_file_size = 800 # 800 bytes + added_byte_size = 300 # Adding 300 bytes, total would be 1100 (above limit) + + # Mock create_new_writer_with_header + with patch('cdx_toolkit.filter_warc.warc_filter.create_new_writer_with_header') as mock_create: + mock_create.return_value = (mock_new_writer, 150, 'warcinfo-123') # (writer, header_size, warcinfo_id) + + # Call rotate_files + result_writer, result_sequence, result_size = await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=added_byte_size, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Should have rotated + assert result_writer == mock_new_writer + assert result_sequence == current_file_sequence + 1 # Incremented + assert result_size == 150 # Header size only + + # Old writer should be closed + mock_writer.close.assert_called_once() + + # New writer should be created + mock_create.assert_called_once_with( + sequence=current_file_sequence + 1, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + asyncio.run(run_test()) + + +def test_rotate_files_rotation_needed_with_resource_records(): + """Test rotate_files when rotation is needed and resource records need to be written.""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + max_file_size=1000, # 1KB limit + write_paths_as_resource_records=['/fake/resource1.txt', '/fake/resource2.txt'], + ) + + mock_writer = AsyncMock() + mock_new_writer = AsyncMock() + current_file_sequence = 1 + current_file_size = 800 # 800 bytes + 
added_byte_size = 300 # Adding 300 bytes, total would be 1100 (above limit) + + # Mock create_new_writer_with_header + with patch('cdx_toolkit.filter_warc.warc_filter.create_new_writer_with_header') as mock_create: + mock_create.return_value = (mock_new_writer, 150, 'warcinfo-123') + + # Mock write_resource_records + with patch.object(warc_filter, 'write_resource_records', return_value=75) as mock_write_resources: + # Call rotate_files + result_writer, result_sequence, result_size = await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=added_byte_size, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Should have rotated + assert result_writer == mock_new_writer + assert result_sequence == current_file_sequence + 1 + assert result_size == 150 + 75 # Header size + resource records size + + # Old writer should be closed + mock_writer.close.assert_called_once() + + # New writer should be created + mock_create.assert_called_once() + + # Resource records should be written + mock_write_resources.assert_called_once_with(mock_new_writer, warcinfo_id='warcinfo-123') + + asyncio.run(run_test()) + + +def test_rotate_files_no_max_file_size_set(): + """Test rotate_files when max_file_size is not set (None).""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + max_file_size=None, # No limit + ) + + mock_writer = AsyncMock() + current_file_sequence = 1 + current_file_size = 999999999 # Very large file + added_byte_size = 999999999 # Very large addition + + # Call rotate_files + result_writer, result_sequence, result_size = await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=added_byte_size, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Should not rotate regardless of size + assert result_writer == mock_writer + assert result_sequence == current_file_sequence + assert result_size == current_file_size + + # Writer should not be closed + mock_writer.close.assert_not_called() + + asyncio.run(run_test()) + + +def test_rotate_files_edge_case_exact_limit(): + """Test rotate_files when the total size exactly equals the limit.""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + max_file_size=1000, # 1KB limit + ) + + mock_writer = AsyncMock() + current_file_sequence = 1 + current_file_size = 700 # 700 bytes + added_byte_size = 300 # Adding 300 bytes, total would be exactly 1000 + + # Call rotate_files + result_writer, result_sequence, result_size = await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=added_byte_size, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Should not rotate when exactly at limit (only rotate when > limit) + assert result_writer == mock_writer + assert result_sequence == current_file_sequence + assert result_size == current_file_size + + # Writer should 
not be closed + mock_writer.close.assert_not_called() + + asyncio.run(run_test()) + + +def test_rotate_files_edge_case_just_over_limit(): + """Test rotate_files when the total size is just 1 byte over the limit.""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + max_file_size=1000, # 1KB limit + ) + + mock_writer = AsyncMock() + mock_new_writer = AsyncMock() + current_file_sequence = 1 + current_file_size = 700 # 700 bytes + added_byte_size = 301 # Adding 301 bytes, total would be 1001 (1 byte over) + + # Mock create_new_writer_with_header + with patch('cdx_toolkit.filter_warc.warc_filter.create_new_writer_with_header') as mock_create: + mock_create.return_value = (mock_new_writer, 150, 'warcinfo-123') + + # Call rotate_files + result_writer, result_sequence, result_size = await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=current_file_sequence, + current_file_size=current_file_size, + added_byte_size=added_byte_size, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Should rotate when just over limit + assert result_writer == mock_new_writer + assert result_sequence == current_file_sequence + 1 + assert result_size == 150 + + # Old writer should be closed + mock_writer.close.assert_called_once() + + asyncio.run(run_test()) + + +def test_rotate_files_kwargs_passed_through(): + """Test that all kwargs are properly passed to create_new_writer_with_header.""" + + async def run_test(): + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 + ) + + mock_writer = AsyncMock() + mock_new_writer = AsyncMock() + + # Mock create_new_writer_with_header + with patch('cdx_toolkit.filter_warc.warc_filter.create_new_writer_with_header') as mock_create: + mock_create.return_value = (mock_new_writer, 150, 'warcinfo-123') + + # Call rotate_files with various kwargs + await warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=1, + current_file_size=800, + added_byte_size=300, + writer_id=99, + output_path_prefix='/custom/output', + max_attempts=5, + base_backoff_seconds=2.5, + min_part_size=2048, + writer_info={'custom': 'info'}, + warc_version='1.1', + gzip=True, + custom_param='custom_value', + ) + + # Verify all kwargs are passed through + mock_create.assert_called_once_with( + sequence=2, # incremented from 1 + writer_id=99, + output_path_prefix='/custom/output', + max_attempts=5, + base_backoff_seconds=2.5, + min_part_size=2048, + writer_info={'custom': 'info'}, + warc_version='1.1', + gzip=True, + custom_param='custom_value', + ) + + asyncio.run(run_test()) + + +def test_rotate_files_logging(caplog): + """Test that rotation logs the appropriate message.""" + import logging + + async def run_test(): + # Set log level to capture INFO messages + caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') + + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 + ) + + mock_writer = AsyncMock() + mock_new_writer = AsyncMock() + + # Mock create_new_writer_with_header + with patch('cdx_toolkit.filter_warc.warc_filter.create_new_writer_with_header') as mock_create: + mock_create.return_value = (mock_new_writer, 150, 'warcinfo-123') + + # Call rotate_files to trigger rotation + await 
warc_filter.rotate_files( + writer=mock_writer, + current_file_sequence=5, + current_file_size=800, + added_byte_size=300, + writer_id=1, + output_path_prefix='/fake/output', + max_attempts=3, + base_backoff_seconds=1.0, + min_part_size=1024, + writer_info={'writer_id': 1}, + ) + + # Check that the rotation log message was written + assert 'Rotated to new WARC file sequence 6 due to size limit' in caplog.text + + asyncio.run(run_test()) + + +def test_log_writer(caplog): + """Test log writer.""" + + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + log_every_n=2, + ) + tracker = ThroughputTracker() + warc_filter.log_writer(1, 0, tracker) + warc_filter.log_writer(1, 1, tracker) + warc_filter.log_writer(1, 2, tracker) + + assert caplog.text.count('WARC Writer 1') == 2 + + +def test_log_reader(caplog): + """Test log reader.""" + + warc_filter = WARCFilter( + cdx_paths=['/fake/path'], + prefix_path='/fake/prefix', + writer_info={'writer_id': 1}, + log_every_n=2, + ) + tracker = ThroughputTracker() + warc_filter.log_reader(1, 0, tracker) + warc_filter.log_reader(1, 1, tracker) + warc_filter.log_reader(1, 2, tracker) + + assert caplog.text.count('WARC Reader 1') == 2 diff --git a/tests/filter_warc/test_warc_from_fs.py b/tests/filter_warc/test_warc_from_fs.py new file mode 100644 index 0000000..ccceaf0 --- /dev/null +++ b/tests/filter_warc/test_warc_from_fs.py @@ -0,0 +1,53 @@ +from tests.conftest import requires_aws_s3 +from cdx_toolkit.warc import fetch_warc_record + + +def test_fetch_warc_record_from_http(): + encoding = 'utf-8' + capture = { + 'url': 'https://bibliotheque.missiondefrance.fr/index.php?lvl=bulletin_display&id=319', + 'mime': 'text/html', + 'mime-detected': 'application/xhtml+xml', + 'status': '200', + 'digest': 'D5K3FUWDRAOMMTJC2CTWV7L2ABFIJ5BP', + 'length': '9754', + 'offset': '111440525', + 'filename': 'crawl-data/CC-MAIN-2024-30/segments/1720763514759.37/warc/CC-MAIN-20240716142214-20240716172214-00337.warc.gz', + 'charset': 'UTF-8', + 'languages': 'fra', + 'timestamp': '20240716153155', + } + warc_download_prefix = 'https://data.commoncrawl.org' + + record = fetch_warc_record(capture, warc_download_prefix) + record_content = record.content_stream().read().decode(encoding, errors='ignore') + + assert record.rec_type == 'response' + assert record.length == 75825 + assert 'Catalogue en ligne Mission de France' in record_content + + +@requires_aws_s3 +def test_fetch_warc_record_from_s3(): + encoding = 'utf-8' + capture = { + 'url': 'https://bibliotheque.missiondefrance.fr/index.php?lvl=bulletin_display&id=319', + 'mime': 'text/html', + 'mime-detected': 'application/xhtml+xml', + 'status': '200', + 'digest': 'D5K3FUWDRAOMMTJC2CTWV7L2ABFIJ5BP', + 'length': '9754', + 'offset': '111440525', + 'filename': 'crawl-data/CC-MAIN-2024-30/segments/1720763514759.37/warc/CC-MAIN-20240716142214-20240716172214-00337.warc.gz', + 'charset': 'UTF-8', + 'languages': 'fra', + 'timestamp': '20240716153155', + } + warc_download_prefix = 's3://commoncrawl' + + record = fetch_warc_record(capture, warc_download_prefix) + record_content = record.content_stream().read().decode(encoding, errors='ignore') + + assert record.rec_type == 'response' + assert record.length == 75825 + assert 'Catalogue en ligne Mission de France' in record_content diff --git a/tests/filter_warc/test_warc_utils.py b/tests/filter_warc/test_warc_utils.py new file mode 100644 index 0000000..9cda87b --- /dev/null +++ b/tests/filter_warc/test_warc_utils.py @@ -0,0 
+import pytest
+from cdx_toolkit.filter_warc.warc_utils import get_resource_record_from_path
+from tests.conftest import TEST_DATA_PATH
+
+
+def test_get_resource_record_from_path():
+    resource_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt'
+    record = get_resource_record_from_path(resource_path, warcinfo_id="abc123")
+
+    assert record.content_type == 'text/plain'
+
+    record_headers = dict(record.rec_headers.headers)
+    assert record_headers['WARC-Target-URI'] == str(resource_path)
+    assert record_headers["WARC-Warcinfo-ID"] == "abc123"
+
+
+def test_get_resource_record_from_path_with_metadata():
+    resource_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz'
+    metadata_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.metadata.json'
+
+    record = get_resource_record_from_path(resource_path, metadata_path=metadata_path, warcinfo_id="abc123")
+
+    assert record.content_type == 'application/cdx'
+
+    record_headers = dict(record.rec_headers.headers)
+    assert record_headers['WARC-Target-URI'] == 'filter_cdx.cdx.gz'
+    assert record_headers["WARC-Warcinfo-ID"] == "abc123"
+
+
+def test_get_resource_record_from_path_with_invalid_metadata_path():
+    with pytest.raises(ValueError):
+        resource_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt'
+        get_resource_record_from_path(resource_path, metadata_path='invalid_metadata.xy', warcinfo_id="abc123")
diff --git a/tests/filter_warc/test_warc_writer.py b/tests/filter_warc/test_warc_writer.py
new file mode 100644
index 0000000..19e5f19
--- /dev/null
+++ b/tests/filter_warc/test_warc_writer.py
@@ -0,0 +1,142 @@
+from io import BytesIO
+import fsspec
+import pytest
+import cdx_toolkit
+
+from tests.conftest import TEST_DATA_PATH, requires_aws_s3
+
+from warcio import WARCWriter
+from warcio.archiveiterator import ArchiveIterator
+
+
+fixture_path = TEST_DATA_PATH / 'warc_by_cdx'
+
+
+@pytest.mark.parametrize(
+    'prefix,gzip',
+    [
+        pytest.param('test-prefix', False, id='File name prefix on local'),
+        pytest.param('test-prefix', True, id='File name prefix on local with gzip'),
+        # raises FileNotFoundError (parent dir does not exist)
+        # pytest.param("test-prefix-folder/file-prefix", None, id="Folder as prefix"),
+    ],
+)
+def test_write_to_local(prefix, gzip, tmpdir):
+    info = {
+        'software': 'pypi_cdx_toolkit/test',
+        'description': 'test',
+        'format': 'WARC file version 1.0',
+    }
+    encoding = 'utf-8'
+    full_prefix = str(tmpdir) + '/' + prefix
+    fs, fs_prefix_path = fsspec.url_to_fs(full_prefix)
+
+    writer = cdx_toolkit.warc.get_writer(full_prefix, None, info, gzip=gzip)
+
+    # single record
+    input_resource_record_text = 'foo bar text'
+    writer.write_record(
+        WARCWriter(None).create_warc_record(
+            uri='foo/bar',
+            record_type='resource',
+            payload=BytesIO(input_resource_record_text.encode(encoding)),
+            warc_content_type='text/plain',
+        )
+    )
+    writer.close()
+
+    # Check that WARC file was created
+    warc_path = fs_prefix_path + '-000000.extracted.warc'
+    if gzip:
+        warc_path += '.gz'
+
+    assert fs.exists(warc_path)
+
+    # Validate that creator/operator are not in warcinfo record
+    info_record = None
+    resource_record = None
+    with open(warc_path, 'rb') as stream:
+        for record in ArchiveIterator(stream):
+            if record.rec_type == 'warcinfo':
+                info_record = record.content_stream().read().decode(encoding)
+
+            if record.rec_type == 'resource':
+                resource_record = record.content_stream().read().decode(encoding)
+                break
+
+    assert resource_record is not None
+    assert info_record is not None
+
+    assert 'description: test' in info_record
+    assert resource_record == input_resource_record_text
+
+
+@requires_aws_s3
+def test_write_to_s3(s3_tmpdir):
+    info = {
+        'software': 'pypi_cdx_toolkit/test',
+        'description': 'test',
+        'format': 'WARC file version 1.0',
+    }
+    encoding = 'utf-8'
+
+    fs, fs_prefix_path = fsspec.url_to_fs(s3_tmpdir)
+
+    writer = cdx_toolkit.warc.get_writer(s3_tmpdir, None, info)
+
+    # single record
+    input_resource_record_text = 'foo bar text'
+    writer.write_record(
+        WARCWriter(None).create_warc_record(
+            uri='foo/bar',
+            record_type='resource',
+            payload=BytesIO(input_resource_record_text.encode(encoding)),
+            warc_content_type='text/plain',
+        )
+    )
+    writer.close()
+
+    # Check that WARC file was created
+    warc_path = fs_prefix_path + '-000000.extracted.warc.gz'
+    assert fs.exists(warc_path)
+
+    # Validate that creator/operator are not in warcinfo record
+    info_record = None
+    resource_record = None
+    with fs.open(warc_path, 'rb') as stream:
+        for record in ArchiveIterator(stream):
+            if record.rec_type == 'warcinfo':
+                info_record = record.content_stream().read().decode(encoding)
+
+            if record.rec_type == 'resource':
+                resource_record = record.content_stream().read().decode(encoding)
+                break
+
+    assert resource_record is not None
+    assert info_record is not None
+
+    assert 'description: test' in info_record
+    assert resource_record == input_resource_record_text
+
+
+def test_warc_info():
+    warc_version = '1.0'
+    gzip = False
+    file_handler = BytesIO()
+    filename = 'foo.warc'
+
+    info = {
+        'software': 'pypi_cdx_toolkit/123',
+        'isPartOf': 'bar',
+        'description': 'warc extraction based on CDX generated with: xx',
+        'format': 'WARC file version 1.0',
+    }
+
+    writer = WARCWriter(file_handler, gzip=gzip, warc_version=warc_version)
+    warcinfo = writer.create_warcinfo_record(filename, info)
+
+    writer.write_record(warcinfo)
+
+    file_value = file_handler.getvalue().decode('utf-8')
+
+    assert 'pypi_cdx_toolkit/123' in file_value
diff --git a/tests/unit/test_warc.py b/tests/unit/test_warc.py
index e5df43f..e2474ff 100644
--- a/tests/unit/test_warc.py
+++ b/tests/unit/test_warc.py
@@ -1,6 +1,5 @@
 import cdx_toolkit.warc
 
-
 def test_wb_redir_to_original():
     location = 'https://web.archive.org/web/20110209062054id_/http://commoncrawl.org/'
     ret = 'http://commoncrawl.org/'