Fix catalog identifier matching to exact match #2732
Merged 1 commit on Jun 22, 2025
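This change replaces substring matching of identifiers against the whole file path with an exact match on the identifier directory (the second-to-last path component). For bars only, it falls back to a prefix match so that an instrument ID still selects the bar type directories derived from it. A minimal sketch of the rule both backends now implement (the helper name and example paths are illustrative, not part of the diff):

def filter_by_identifier(
    file_paths: list[str],
    safe_identifiers: list[str],
    is_bar_data: bool,
) -> list[str]:
    # Exact match by default: the identifier directory is the second-to-last
    # path component, e.g. .../quote_ticks/<identifier>/file.parquet
    exact = [
        path for path in file_paths
        if path.split("/")[-2] in safe_identifiers
    ]
    if exact or not is_bar_data:
        return exact
    # Bars only: fall back to a prefix match, since a bar type directory
    # starts with the instrument ID followed by "-",
    # e.g. "EURUSD.SIM" matches "EURUSD.SIM-1-MINUTE-BID-EXTERNAL"
    return [
        path for path in file_paths
        if any(path.split("/")[-2].startswith(f"{safe_id}-") for safe_id in safe_identifiers)
    ]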
76 changes: 58 additions & 18 deletions crates/persistence/src/backend/catalog.rs
@@ -1251,14 +1251,9 @@ impl ParquetDataCatalog {
         let start_u64 = start.map(|s| s.as_u64());
         let end_u64 = end.map(|e| e.as_u64());

-        let safe_ids = instrument_ids.as_ref().map(|ids| {
-            ids.iter()
-                .map(|id| urisafe_instrument_id(id))
-                .collect::<Vec<String>>()
-        });
-
         let base_dir = self.make_path(data_cls, None)?;

-        // Use recursive listing to match Python's glob behavior
         let list_result = self.execute_async(async {
             let prefix = ObjectPath::from(format!("{base_dir}/"));
             let mut stream = self.object_store.list(Some(&prefix));
@@ -1269,23 +1264,68 @@ impl ParquetDataCatalog {
             Ok::<Vec<_>, anyhow::Error>(objects)
         })?;

-        for object in list_result {
-            let path_str = object.location.to_string();
-            if path_str.ends_with(".parquet") {
-                if let Some(ids) = &safe_ids {
-                    let matches_any_id = ids.iter().any(|safe_id| path_str.contains(safe_id));
-                    if !matches_any_id {
-                        continue;
-                    }
-                }
-
-                if query_intersects_filename(&path_str, start_u64, end_u64) {
-                    let full_uri = self.reconstruct_full_uri(&path_str);
-                    files.push(full_uri);
-                }
-            }
-        }
+        let mut file_paths: Vec<String> = list_result
+            .into_iter()
+            .filter_map(|object| {
+                let path_str = object.location.to_string();
+                if path_str.ends_with(".parquet") {
+                    Some(path_str)
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        // Apply identifier filtering if provided
+        if let Some(identifiers) = instrument_ids {
+            let safe_identifiers: Vec<String> = identifiers
+                .iter()
+                .map(|id| urisafe_instrument_id(id))
+                .collect();
+
+            // Exact match by default for instrument_ids or bar_types
+            let exact_match_file_paths: Vec<String> = file_paths
+                .iter()
+                .filter(|file_path| {
+                    // Extract the directory name (second to last path component)
+                    let path_parts: Vec<&str> = file_path.split('/').collect();
+                    if path_parts.len() >= 2 {
+                        let dir_name = path_parts[path_parts.len() - 2];
+                        safe_identifiers.iter().any(|safe_id| safe_id == dir_name)
+                    } else {
+                        false
+                    }
+                })
+                .cloned()
+                .collect();
+
+            if exact_match_file_paths.is_empty() && data_cls == "bars" {
+                // Partial match of instrument_ids in bar_types for bars
+                file_paths.retain(|file_path| {
+                    let path_parts: Vec<&str> = file_path.split('/').collect();
+                    if path_parts.len() >= 2 {
+                        let dir_name = path_parts[path_parts.len() - 2];
+                        safe_identifiers
+                            .iter()
+                            .any(|safe_id| dir_name.starts_with(&format!("{}-", safe_id)))
+                    } else {
+                        false
+                    }
+                });
+            } else {
+                file_paths = exact_match_file_paths;
+            }
+        }
+
+        // Apply timestamp filtering
+        file_paths.retain(|file_path| query_intersects_filename(file_path, start_u64, end_u64));
+
+        // Convert to full URIs
+        for file_path in file_paths {
+            let full_uri = self.reconstruct_full_uri(&file_path);
+            files.push(full_uri);
+        }

         Ok(files)
     }

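For context on why exact matching matters: the old substring check could return files for the wrong instrument whenever one URI-safe identifier was contained in another. A hypothetical example (the tickers and path are illustrative):

path = "data/quote_ticks/AA.NYSE/2024-01-01_2024-01-31.parquet"

# Old check: substring test over the whole path
"A.NYSE" in path  # True, so a query for ticker "A" wrongly picked up "AA" files

# New check: exact comparison against the identifier directory
path.split("/")[-2] == "A.NYSE"  # False, only "AA.NYSE" matches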
60 changes: 41 additions & 19 deletions nautilus_trader/persistence/catalog/parquet.py
@@ -1195,7 +1195,7 @@ def query(
             An additional SQL WHERE clause to filter the data (used in Rust queries).
         files : list[str], optional
             A specific list of files to query from. If provided, these files are used
-            instead of discovering files through the normal process. Forces PyArrow backend.
+            instead of discovering files through the normal process.
         **kwargs : Any
             Additional keyword arguments passed to the underlying query implementation.
@@ -1233,6 +1233,7 @@ def query(
                 start=start,
                 end=end,
                 where=where,
+                files=files,
                 **kwargs,
             )
         else:
@@ -1270,6 +1271,7 @@ def _query_rust(
         start: TimestampLike | None = None,
         end: TimestampLike | None = None,
         where: str | None = None,
+        files: list[str] | None = None,
         **kwargs: Any,
     ) -> list[Data]:
         query_data_cls = OrderBookDelta if data_cls == OrderBookDeltas else data_cls
@@ -1279,6 +1281,7 @@ def _query_rust(
             start=start,
             end=end,
             where=where,
+            files=files,
             **kwargs,
         )
         result = session.to_query_result()
@@ -1304,6 +1307,7 @@ def backend_session(
         end: TimestampLike | None = None,
         where: str | None = None,
         session: DataBackendSession | None = None,
+        files: list[str] | None = None,
         **kwargs: Any,
     ) -> DataBackendSession:
         """
@@ -1327,6 +1331,9 @@ def backend_session(
             An additional SQL WHERE clause to filter the data.
         session : DataBackendSession, optional
             An existing session to update. If None, a new session is created.
+        files : list[str], optional
+            A specific list of files to query from. If provided, these files are used
+            instead of discovering files through the normal process.
         **kwargs : Any
             Additional keyword arguments.
@@ -1351,7 +1358,7 @@ def backend_session(
         """
         data_type: NautilusDataType = ParquetDataCatalog._nautilus_data_cls_to_data_type(data_cls)
-        files = self._query_files(data_cls, identifiers, start, end)
+        file_list = files if files else self._query_files(data_cls, identifiers, start, end)
         file_prefix = class_to_filename(data_cls)

         if session is None:
@@ -1361,7 +1368,7 @@ def backend_session(
         if self.fs_protocol != "file":
             self._register_object_store_with_session(session)

-        for idx, file in enumerate(files):
+        for idx, file in enumerate(file_list):
             table = f"{file_prefix}_{idx}"
             query = self._build_query(
                 table,
@@ -1492,10 +1499,7 @@ def _query_pyarrow(
         **kwargs: Any,
     ) -> list[Data]:
         # Load dataset - use provided files or query for them
-        if files is not None:
-            file_list = files
-        else:
-            file_list = self._query_files(data_cls, identifiers, start, end)
+        file_list = files if files else self._query_files(data_cls, identifiers, start, end)

         if not file_list:
             return []
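With files now plumbed through to the Rust backend session as well, a caller can pin a query to known files and skip discovery on either backend. A hedged usage sketch (the catalog path, file path, and data directory name are illustrative):

from nautilus_trader.model.data import OrderBookDelta
from nautilus_trader.persistence.catalog import ParquetDataCatalog

catalog = ParquetDataCatalog("/path/to/catalog")

# Bypass file discovery and query exactly these files
deltas = catalog.query(
    data_cls=OrderBookDelta,
    files=["/path/to/catalog/data/order_book_delta/ETHUSDT.BINANCE/part-0.parquet"],
)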
@@ -1536,32 +1540,50 @@ def _query_files(
         file_prefix = class_to_filename(data_cls)
         base_path = self.path.rstrip("/")
         glob_path = f"{base_path}/data/{file_prefix}/**/*.parquet"
-        file_names: list[str] = self.fs.glob(glob_path)
+        file_paths: list[str] = self.fs.glob(glob_path)

         if identifiers:
             if not isinstance(identifiers, list):
                 identifiers = [identifiers]

             safe_identifiers = [urisafe_identifier(identifier) for identifier in identifiers]
-            file_names = [
-                file_name
-                for file_name in file_names
-                if any(safe_identifier in file_name for safe_identifier in safe_identifiers)
-            ]
+
+            # Exact match by default for instrument_ids or bar_types
+            exact_match_file_paths = [
+                file_path
+                for file_path in file_paths
+                if any(
+                    safe_identifier == file_path.split("/")[-2]
+                    for safe_identifier in safe_identifiers
+                )
+            ]
+
+            if not exact_match_file_paths and data_cls in [Bar, *Bar.__subclasses__()]:
+                # Partial match of instrument_ids in bar_types for bars
+                file_paths = [
+                    file_path
+                    for file_path in file_paths
+                    if any(
+                        file_path.split("/")[-2].startswith(f"{safe_identifier}-")
+                        for safe_identifier in safe_identifiers
+                    )
+                ]
+            else:
+                file_paths = exact_match_file_paths

         used_start: pd.Timestamp | None = time_object_to_dt(start)
         used_end: pd.Timestamp | None = time_object_to_dt(end)
-        file_names = [
-            file_name
-            for file_name in file_names
-            if _query_intersects_filename(file_name, used_start, used_end)
-        ]
+        file_paths = [
+            file_path
+            for file_path in file_paths
+            if _query_intersects_filename(file_path, used_start, used_end)
+        ]

         if self.show_query_paths:
-            for file_name in file_names:
-                print(file_name)
+            for file_path in file_paths:
+                print(file_path)

-        return file_names
+        return file_paths

     @staticmethod
     def _handle_table_nautilus(
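Taken together, identifier queries behave as follows after this change (a sketch; the instrument ID, bar type, and exact query signature are illustrative):

from nautilus_trader.model.data import Bar, QuoteTick

# Exact match: only files under data/quote_ticks/EURUSD.SIM/ are returned
catalog.query(QuoteTick, identifiers=["EUR/USD.SIM"])

# Bars fallback: when no directory equals the identifier exactly, an
# instrument ID still selects the bar types derived from it, e.g.
# data/bars/EURUSD.SIM-1-MINUTE-BID-EXTERNAL/
catalog.query(Bar, identifiers=["EUR/USD.SIM"])

# A full bar type string is still matched exactly against its directory
catalog.query(Bar, identifiers=["EUR/USD.SIM-1-MINUTE-BID-EXTERNAL"])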