From c72f16abc4361b4a606a1c48bffd487730271fab Mon Sep 17 00:00:00 2001
From: Shrey Pant
Date: Thu, 13 Nov 2025 19:10:44 +0530
Subject: [PATCH 01/48] feat(cli): Add Podman Support

---
 helix-cli/src/commands/add.rs              |   1 +
 helix-cli/src/commands/build.rs            |   7 +-
 helix-cli/src/commands/delete.rs           |   3 +-
 helix-cli/src/commands/integrations/ecr.rs |   2 +-
 helix-cli/src/commands/migrate.rs          |   4 +-
 helix-cli/src/commands/prune.rs            |  42 +--
 helix-cli/src/commands/push.rs             |   2 +-
 helix-cli/src/commands/start.rs            |   2 +-
 helix-cli/src/commands/status.rs           |   3 +-
 helix-cli/src/commands/stop.rs             |   2 +-
 helix-cli/src/config.rs                    |  32 ++
 helix-cli/src/docker.rs                    | 414 +++++++++++++--------
 helix-cli/src/tests/check_tests.rs         |   2 +
 13 files changed, 332 insertions(+), 184 deletions(-)

diff --git a/helix-cli/src/commands/add.rs b/helix-cli/src/commands/add.rs
index a85089e87..830f7f898 100644
--- a/helix-cli/src/commands/add.rs
+++ b/helix-cli/src/commands/add.rs
@@ -123,6 +123,7 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> {
                 port: None, // Let the system assign a port
                 build_mode: BuildMode::Debug,
                 db_config: DbConfig::default(),
+
             };
 
             project_context
diff --git a/helix-cli/src/commands/build.rs b/helix-cli/src/commands/build.rs
index 4eb7149e7..cca93c3d3 100644
--- a/helix-cli/src/commands/build.rs
+++ b/helix-cli/src/commands/build.rs
@@ -86,8 +86,10 @@ pub async fn run(instance_name: String, metrics_sender: &MetricsSender) -> Resul
 
     // For local instances, build Docker image
     if instance_config.should_build_docker_image() {
+        let runtime = project.config.project.container_runtime;
+        DockerManager::check_runtime_available(runtime)?;
         let docker = DockerManager::new(&project);
-        DockerManager::check_docker_available()?;
+
         docker.build_image(&instance_name, instance_config.docker_build_target())?;
     }
 
@@ -290,10 +292,11 @@ async fn generate_docker_files(
         return Ok(());
     }
 
-    print_status("DOCKER", "Generating Docker configuration...");
+    let docker = DockerManager::new(project);
+    print_status(docker.runtime.label(), "Generating configuration...");
 
     // Generate Dockerfile
     let dockerfile_content = docker.generate_dockerfile(instance_name, instance_config.clone())?;
     let dockerfile_path = project.dockerfile_path(instance_name);
diff --git a/helix-cli/src/commands/delete.rs b/helix-cli/src/commands/delete.rs
index 610ccadf1..ab7f4f109 100644
--- a/helix-cli/src/commands/delete.rs
+++ b/helix-cli/src/commands/delete.rs
@@ -37,7 +37,8 @@ pub async fn run(instance_name: String) -> Result<()> {
     print_status("DELETE", &format!("Deleting instance '{instance_name}'"));
 
     // Stop and remove Docker containers and volumes
-    if DockerManager::check_docker_available().is_ok() {
+    let runtime = project.config.project.container_runtime;
+    if DockerManager::check_runtime_available(runtime).is_ok() {
         let docker = DockerManager::new(&project);
 
         // Remove containers and Docker volumes
diff --git a/helix-cli/src/commands/integrations/ecr.rs b/helix-cli/src/commands/integrations/ecr.rs
index 4a12c0769..f31d2e9d7 100644
--- a/helix-cli/src/commands/integrations/ecr.rs
+++ b/helix-cli/src/commands/integrations/ecr.rs
@@ -343,7 +343,7 @@ impl<'a> EcrManager<'a> {
             .to_string();
 
         use tokio::io::AsyncWriteExt;
-        let mut login_cmd = tokio::process::Command::new("docker");
+        let mut login_cmd = tokio::process::Command::new(docker.runtime.binary());
         login_cmd.args([
             "login",
             "--username",
diff --git a/helix-cli/src/commands/migrate.rs b/helix-cli/src/commands/migrate.rs
index 6b7707600..33f9777d5 100644
--- a/helix-cli/src/commands/migrate.rs
+++ b/helix-cli/src/commands/migrate.rs
@@ -1,5 +1,5 @@
 use crate::config::{
-    BuildMode, DbConfig, GraphConfig, HelixConfig, LocalInstanceConfig, ProjectConfig, VectorConfig,
+    BuildMode, DbConfig, GraphConfig, HelixConfig, LocalInstanceConfig, ProjectConfig, VectorConfig, ContainerRuntime
 };
 use crate::errors::{CliError, project_error};
 use crate::utils::{
@@ -415,6 +415,7 @@ fn create_v2_config(ctx: &MigrationContext) -> Result<()> {
         port: Some(ctx.port),
         build_mode: BuildMode::Debug,
         db_config,
+
     };
 
     // Create local instances map
@@ -425,6 +426,7 @@ fn create_v2_config(ctx: &MigrationContext) -> Result<()> {
     let project_config = ProjectConfig {
         name: ctx.project_name.clone(),
         queries: PathBuf::from(&ctx.queries_dir),
+        container_runtime: ContainerRuntime::Docker,
     };
 
     // Create final helix config
diff --git a/helix-cli/src/commands/prune.rs b/helix-cli/src/commands/prune.rs
index ef9610586..8becd2186 100644
--- a/helix-cli/src/commands/prune.rs
+++ b/helix-cli/src/commands/prune.rs
@@ -5,6 +5,7 @@ use crate::utils::{
     print_confirm, print_lines, print_newline, print_status, print_success, print_warning,
 };
 use eyre::Result;
+use crate::config::ContainerRuntime;
 
 pub async fn run(instance: Option<String>, all: bool) -> Result<()> {
     // Try to load project context
@@ -38,7 +39,8 @@ async fn prune_instance(project: &ProjectContext, instance_name: &str) -> Result
     let _instance_config = project.config.get_instance(instance_name)?;
 
     // Check Docker availability
-    if DockerManager::check_docker_available().is_ok() {
+    let runtime = project.config.project.container_runtime;
+    if DockerManager::check_runtime_available(runtime).is_ok() {
         let docker = DockerManager::new(project);
 
         // Remove containers (but not volumes)
@@ -72,8 +74,8 @@ async fn prune_all_instances(project: &ProjectContext) -> Result<()> {
     }
 
     print_status("PRUNE", &format!("Found {} instance(s) to prune", instances.len()));
-
-    if DockerManager::check_docker_available().is_ok() {
+    let runtime = project.config.project.container_runtime;
+    if DockerManager::check_runtime_available(runtime).is_ok() {
         let docker = DockerManager::new(project);
 
         for instance_name in &instances {
@@ -115,10 +117,10 @@ async fn prune_unused_resources(project: &ProjectContext) -> Result<()> {
         "  To clean a specific instance, use 'helix prune '",
     ]);
     print_newline();
-
+    let runtime = project.config.project.container_runtime;
     // Check Docker availability
     print_status("PRUNE", "Checking Docker availability");
-    DockerManager::check_docker_available()?;
+    DockerManager::check_runtime_available(runtime)?;
 
     print_status("PRUNE", "Running Docker system cleanup");
     // Use centralized docker command
@@ -156,22 +158,20 @@ async fn prune_system_wide() -> Result<()> {
     }
 
     print_status("PRUNE", "Pruning all Helix images from system");
-
-    // Check Docker availability
-    DockerManager::check_docker_available()?;
-
-    // Remove all Helix images
-    DockerManager::clean_all_helix_images()?;
-
-    // Also clean unused Docker resources
-    let output = std::process::Command::new("docker")
-        .args(["system", "prune", "-f"])
-        .output()
-        .map_err(|e| eyre::eyre!("Failed to run docker system prune: {e}"))?;
-
-    if !output.status.success() {
-        let stderr = String::from_utf8_lossy(&output.stderr);
-        return Err(eyre::eyre!("Failed to prune Docker resources:\n{stderr}"));
+    for runtime in [ContainerRuntime::Docker, ContainerRuntime::Podman] {
+        if DockerManager::check_runtime_available(runtime).is_ok() {
+            DockerManager::clean_all_helix_images(runtime)?;
+            // Run system prune for this runtime
+            let output = std::process::Command::new(runtime.binary())
+                .args(["system", "prune", "-f"])
+                .output()
+                .map_err(|e| eyre::eyre!("Failed to run {} system prune: {e}", runtime.binary()))?;
+
+            if !output.status.success() {
+                let stderr = String::from_utf8_lossy(&output.stderr);
+                print_warning(&format!("{} system prune failed: {stderr}", runtime.label()));
+            }
+        }
     }
 
     print_success("System-wide Helix prune completed");
diff --git a/helix-cli/src/commands/push.rs b/helix-cli/src/commands/push.rs
index fa0b54df7..0810ad28b 100644
--- a/helix-cli/src/commands/push.rs
+++ b/helix-cli/src/commands/push.rs
@@ -94,7 +94,7 @@ async fn push_local_instance(
     let docker = DockerManager::new(project);
 
     // Check Docker availability
-    DockerManager::check_docker_available()?;
+    DockerManager::check_runtime_available(docker.runtime)?;
 
     // Build the instance first (this ensures it's up to date) and get metrics data
     let metrics_data =
diff --git a/helix-cli/src/commands/start.rs b/helix-cli/src/commands/start.rs
index 1816eaac8..9335f124d 100644
--- a/helix-cli/src/commands/start.rs
+++ b/helix-cli/src/commands/start.rs
@@ -28,7 +28,7 @@ async fn start_local_instance(project: &ProjectContext, instance_name: &str) ->
     let docker = DockerManager::new(project);
 
     // Check Docker availability
-    DockerManager::check_docker_available()?;
+    DockerManager::check_runtime_available(docker.runtime)?;
 
     // Check if instance is built (has docker-compose.yml)
     let workspace = project.instance_workspace(instance_name);
diff --git a/helix-cli/src/commands/status.rs b/helix-cli/src/commands/status.rs
index e740e5a90..fd1b62e33 100644
--- a/helix-cli/src/commands/status.rs
+++ b/helix-cli/src/commands/status.rs
@@ -76,7 +76,8 @@ pub async fn run() -> Result<()> {
 
 async fn show_container_status(project: &ProjectContext) -> Result<()> {
     // Check if Docker is available
-    if DockerManager::check_docker_available().is_err() {
+    let runtime = project.config.project.container_runtime ;
+    if DockerManager::check_runtime_available(runtime).is_err() {
         print_field("Docker Status", "Not available");
         return Ok(());
     }
diff --git a/helix-cli/src/commands/stop.rs b/helix-cli/src/commands/stop.rs
index 289cc1c7d..7d8727de8 100644
--- a/helix-cli/src/commands/stop.rs
+++ b/helix-cli/src/commands/stop.rs
@@ -28,7 +28,7 @@ async fn stop_local_instance(project: &ProjectContext, instance_name: &str) -> R
     let docker = DockerManager::new(project);
 
     // Check Docker availability
-    DockerManager::check_docker_available()?;
+    DockerManager::check_runtime_available(docker.runtime)?;
 
     // Stop the instance
     docker.stop_instance(instance_name)?;
diff --git a/helix-cli/src/config.rs b/helix-cli/src/config.rs
index 5d282a4ea..d628b7457 100644
--- a/helix-cli/src/config.rs
+++ b/helix-cli/src/config.rs
@@ -25,6 +25,8 @@ pub struct ProjectConfig {
         deserialize_with = "deserialize_path"
     )]
     pub queries: PathBuf,
+    #[serde(default = "default_container_runtime")]
+    pub container_runtime: ContainerRuntime,
 }
 
 fn default_queries_path() -> PathBuf {
@@ -46,7 +48,34 @@ where
     // Normalize path separators for cross-platform compatibility
     Ok(PathBuf::from(s.replace('\\', "/")))
 }
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum ContainerRuntime {
+    #[default]
+    Docker,
+    Podman,
+}
+
+impl ContainerRuntime {
+    /// Get the CLI command name for this runtime
+    pub fn binary(&self) -> &'static str {
+        match self {
+            Self::Docker => "docker",
+            Self::Podman => "podman",
+        }
+    }
+    pub const fn label(&self) -> &'static str {
+        match self {
+            Self::Docker => "DOCKER",
+            Self::Podman => "PODMAN",
+        }
+    }
+}
+
+fn default_container_runtime() -> ContainerRuntime {
+    ContainerRuntime::Docker
+}
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub struct VectorConfig {
     #[serde(default = "default_m")]
@@ -85,6 +114,7 @@ pub struct LocalInstanceConfig {
     pub build_mode: BuildMode,
     #[serde(flatten)]
     pub db_config: DbConfig,
+
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -379,6 +409,7 @@ impl HelixConfig {
                 port: Some(6969),
                 build_mode: BuildMode::Debug,
                 db_config: DbConfig::default(),
+
             },
         );
 
@@ -386,6 +417,7 @@ impl HelixConfig {
             project: ProjectConfig {
                 name: project_name.to_string(),
                 queries: default_queries_path(),
+                container_runtime: default_container_runtime(),
             },
             local,
             cloud: HashMap::new(),
diff --git a/helix-cli/src/docker.rs b/helix-cli/src/docker.rs
index 50d7cbdcf..86f03499b 100644
--- a/helix-cli/src/docker.rs
+++ b/helix-cli/src/docker.rs
@@ -1,4 +1,9 @@
-use crate::config::{BuildMode, InstanceInfo};
+//! Container management using Docker-compatible runtimes (Docker/Podman);
+//! the "Docker" in the names here is semantic rather than tool-specific.
+//!
+//! Despite the module name, this works with both Docker and Podman as they
+//! share the same CLI interface and support standard Dockerfile formats.
+
+use crate::config::{BuildMode, InstanceInfo,ContainerRuntime};
 use crate::project::ProjectContext;
 use crate::utils::{print_confirm, print_status, print_warning};
 use eyre::{Result, eyre};
@@ -8,13 +13,18 @@ use std::time::Duration;
 
 pub struct DockerManager<'a> {
     project: &'a ProjectContext,
+    /// The container runtime to use (Docker or Podman)
+    pub(crate) runtime: ContainerRuntime,
 }
 
+
 pub const HELIX_DATA_DIR: &str = "/data";
 
 impl<'a> DockerManager<'a> {
     pub fn new(project: &'a ProjectContext) -> Self {
-        Self { project }
+        Self { project ,
+            runtime : project.config.project.container_runtime,
+        }
     }
 
     // === CENTRALIZED NAMING METHODS ===
@@ -79,18 +89,18 @@ impl<'a> DockerManager<'a> {
         format!("{project_name}_net")
     }
 
-    // === CENTRALIZED DOCKER COMMAND EXECUTION ===
+    // === CENTRALIZED DOCKER/PODMAN COMMAND EXECUTION ===
 
-    /// Run a docker command with consistent error handling
+    /// Run a docker/podman command with consistent error handling
     pub fn run_docker_command(&self, args: &[&str]) -> Result<Output> {
-        let output = Command::new("docker")
+        let output = Command::new(self.runtime.binary())
             .args(args)
             .output()
-            .map_err(|e| eyre!("Failed to run docker {}: {e}", args.join(" ")))?;
+            .map_err(|e| eyre!("Failed to run {} {}: {e}",self.runtime.binary(), args.join(" ")))?;
         Ok(output)
     }
 
-    /// Run a docker compose command with proper project naming
+    /// Run a docker/podman compose command with proper project naming
     fn run_compose_command(&self, instance_name: &str, args: Vec<&str>) -> Result<Output> {
         let workspace = self.project.instance_workspace(instance_name);
         let project_name = self.compose_project_name(instance_name);
 
         let mut full_args = vec!["--project-name", &project_name];
         full_args.extend(args);
 
-        let output = Command::new("docker")
+        let output = Command::new(self.runtime.binary())
             .arg("compose")
             .args(&full_args)
             .current_dir(&workspace)
             .output()
-            .map_err(|e| eyre!("Failed to run docker compose {}: {e}", full_args.join(" ")))?;
+            .map_err(|e| eyre!("Failed to run {} compose {}: {e}", self.runtime.binary() , full_args.join(" ")))?;
         Ok(output)
     }
 
     /// Detect the current operating system platform
     fn detect_platform() -> &'static str {
         #[cfg(target_os = "macos")]
         return "macos";
 
         #[cfg(target_os = "linux")]
         return "linux";
 
         #[cfg(target_os = "windows")]
         return "windows";
 
         #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
         return "unknown";
     }
 
-    /// Start the Docker daemon based on the platform
-    fn start_docker_daemon() -> Result<()> {
-        let platform = Self::detect_platform();
+    /// Start the container runtime daemon based on platform and runtime
+fn start_runtime_daemon(runtime: ContainerRuntime) -> Result<()> {
+    let platform = Self::detect_platform();
 
-        match platform {
-            "macos" => {
-                print_status("DOCKER", "Starting Docker Desktop for macOS...");
-                Command::new("open")
-                    .args(["-a", "Docker"])
+    match (runtime, platform) {
+        // Docker on macOS
+        (ContainerRuntime::Docker, "macos") => {
+            print_status("DOCKER", "Starting Docker Desktop for macOS...");
+            Command::new("open")
+                .args(["-a", "Docker"])
+                .output()
+                .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
+        }
+
+        // Podman on macOS
+        (ContainerRuntime::Podman, "macos") => {
+            print_status("PODMAN", "Starting Podman machine on macOS...");
+
+            // Check if machine exists first
+            let list_output = Command::new("podman")
+                .args(["machine", "list", "--format", "{{.Name}}"])
+                .output()
+                .map_err(|e| eyre!("Failed to list Podman machines: {}", e))?;
+
+            let machines = String::from_utf8_lossy(&list_output.stdout);
+
+            if machines.trim().is_empty() {
+                // No machine exists, initialize one
+                print_status("PODMAN", "Initializing Podman machine (first time)...");
+                let init_output = Command::new("podman")
+                    .args(["machine", "init"])
                     .output()
-                    .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
+                    .map_err(|e| eyre!("Failed to initialize Podman machine: {}", e))?;
+
+                if !init_output.status.success() {
+                    let stderr = String::from_utf8_lossy(&init_output.stderr);
+                    return Err(eyre!("Failed to initialize Podman machine: {}", stderr));
+                }
             }
-            "linux" => {
-                print_status("DOCKER", "Attempting to start Docker daemon on Linux...");
-                // Try systemctl first, then service command as fallback
-                let systemctl_result = Command::new("systemctl").args(["start", "docker"]).output();
-
-                match systemctl_result {
-                    Ok(output) if output.status.success() => {
-                        // systemctl succeeded
-                    }
-                    _ => {
-                        // Try service command as fallback
-                        let service_result = Command::new("service")
-                            .args(["docker", "start"])
-                            .output()
-                            .map_err(|e| eyre!("Failed to start Docker daemon: {}", e))?;
-
-                        if !service_result.status.success() {
-                            let stderr = String::from_utf8_lossy(&service_result.stderr);
-                            return Err(eyre!("Failed to start Docker daemon: {}", stderr));
-                        }
+
+            // Start the machine
+            Command::new("podman")
+                .args(["machine", "start"])
+                .output()
+                .map_err(|e| eyre!("Failed to start Podman machine: {}", e))?;
+        }
+
+        // Docker on Linux
+        (ContainerRuntime::Docker, "linux") => {
+            print_status("DOCKER", "Attempting to start Docker daemon on Linux...");
+            let systemctl_result = Command::new("systemctl").args(["start", "docker"]).output();
+
+            match systemctl_result {
+                Ok(output) if output.status.success() => {}
+                _ => {
+                    let service_result = Command::new("service")
+                        .args(["docker", "start"])
+                        .output()
+                        .map_err(|e| eyre!("Failed to start Docker daemon: {}", e))?;
+
+                    if !service_result.status.success() {
+                        let stderr = String::from_utf8_lossy(&service_result.stderr);
+                        return Err(eyre!("Failed to start Docker daemon: {}", stderr));
                     }
                 }
             }
-            "windows" => {
-                print_status("DOCKER", "Starting Docker Desktop for Windows...");
-                // Try Docker Desktop CLI (4.37+) first
-                let cli_result = Command::new("docker")
-                    .args(["desktop", "start"])
+        }
+
+        // Podman on Linux
+        (ContainerRuntime::Podman, "linux") => {
+            print_status("PODMAN", "Starting Podman service on Linux...");
+
+            // Try to start user service (rootless)
+            let user_service = Command::new("systemctl")
+                .args(["--user", "start", "podman.socket"])
+                .output();
+
+            if user_service.is_err() || !user_service.unwrap().status.success() {
+                // Try system service (rootful) as fallback
+                let system_service = Command::new("systemctl")
+                    .args(["start", "podman.socket"])
                     .output();
+
+                if let Err(e) = system_service {
+                    print_warning(&format!("Could not start Podman service: {}", e));
+                }
+            }
+        }
 
-                match cli_result {
-                    Ok(output) if output.status.success() => {
-                        // Modern Docker Desktop CLI worked
-                    }
-                    _ => {
-                        // Fallback to direct executable path for older versions
-                        // Note: Empty string "" is required as window title parameter
-                        Command::new("cmd")
-                            .args(["/c", "start", "", "\"C:\\Program Files\\Docker\\Docker\\Docker Desktop.exe\""])
-                            .output()
-                            .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
-                    }
+        // Docker on Windows
+        (ContainerRuntime::Docker, "windows") => {
+            print_status("DOCKER", "Starting Docker Desktop for Windows...");
+            let cli_result = Command::new("docker")
+                .args(["desktop", "start"])
+                .output();
+
+            match cli_result {
+                Ok(output) if output.status.success() => {}
+                _ => {
+                    Command::new("cmd")
+                        .args(["/c", "start", "", "\"C:\\Program Files\\Docker\\Docker\\Docker Desktop.exe\""])
+                        .output()
+                        .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
                 }
             }
-            _ => {
-                return Err(eyre!("Unsupported platform for auto-starting Docker"));
+        }
+
+        // Podman on Windows
+        (ContainerRuntime::Podman, "windows") => {
+            print_status("PODMAN", "Starting Podman machine on Windows...");
+
+            // Check if machine exists
+            let list_output = Command::new("podman")
+                .args(["machine", "list", "--format", "{{.Name}}"])
+                .output()
+                .map_err(|e| eyre!("Failed to list Podman machines: {}", e))?;
+
+            let machines = String::from_utf8_lossy(&list_output.stdout);
+
+            if machines.trim().is_empty() {
+                // Initialize machine first
+                print_status("PODMAN", "Initializing Podman machine (first time)...");
+                let init_output = Command::new("podman")
+                    .args(["machine", "init"])
+                    .output()
+                    .map_err(|e| eyre!("Failed to initialize Podman machine: {}", e))?;
+
+                if !init_output.status.success() {
+                    let stderr = String::from_utf8_lossy(&init_output.stderr);
+                    return Err(eyre!("Failed to initialize Podman machine: {}", stderr));
+                }
+            }
+
+            // Start the machine
+            let start_output = Command::new("podman")
+                .args(["machine", "start"])
+                .output()
+                .map_err(|e| eyre!("Failed to start Podman machine: {}", e))?;
+
+            if !start_output.status.success() {
+                let stderr = String::from_utf8_lossy(&start_output.stderr);
+                return Err(eyre!("Failed to start Podman machine: {}", stderr));
             }
         }
 
-        Ok(())
+        (_, platform) => {
+            return Err(eyre!("Unsupported platform '{}' for auto-starting {}", platform, runtime.label()));
+        }
     }
 
-    /// Wait for Docker daemon to be ready
-    fn wait_for_docker(timeout_secs: u64) -> Result<()> {
-        print_status("DOCKER", "Waiting for Docker daemon to start...");
+    Ok(())
+}
 
-        let start = std::time::Instant::now();
-        let timeout = Duration::from_secs(timeout_secs);
+fn wait_for_runtime(runtime: ContainerRuntime, timeout_secs: u64) -> Result<()> {
+    print_status(runtime.label(), "Waiting for daemon to start...");
 
-        while start.elapsed() < timeout {
-            // Check if Docker daemon is responding
-            let output = Command::new("docker").args(["info"]).output();
+    let start = std::time::Instant::now();
+    let timeout = Duration::from_secs(timeout_secs);
 
-            if let Ok(output) = output
-                && output.status.success()
-            {
-                print_status("DOCKER", "Docker daemon is now running");
-                return Ok(());
-            }
+    while start.elapsed() < timeout {
+        let output = Command::new(runtime.binary()).args(["info"]).output();
 
-            // Wait a bit before retrying
-            thread::sleep(Duration::from_millis(500));
+        if let Ok(output) = output && output.status.success() {
+            print_status(runtime.label(), "Daemon is now running");
+            return Ok(());
         }
 
-        Err(eyre!(
-            "Timeout waiting for Docker daemon to start. Please start Docker manually and try again."
-        ))
+        thread::sleep(Duration::from_millis(500));
     }
 
-    /// Check if Docker is installed and running
-    pub fn check_docker_available() -> Result<()> {
-        let output = Command::new("docker")
-            .args(["--version"])
-            .output()
-            .map_err(|_| eyre!("Docker is not installed or not available in PATH"))?;
-
-        if !output.status.success() {
-            return Err(eyre!("Docker is installed but not working properly"));
-        }
-
-        // Check if Docker daemon is running
-        let output = Command::new("docker")
-            .args(["info"])
-            .output()
-            .map_err(|_| eyre!("Failed to check Docker daemon status"))?;
+    Err(eyre!(
+        "Timeout waiting for {} daemon to start. Please start {} manually and try again.",
+        runtime.label(),
+        runtime.binary()
+    ))
+}
 
-        if !output.status.success() {
-            // Docker daemon is not running - prompt user
-            let should_start =
-                print_confirm("Docker daemon is not running. Would you like to start Docker?")
-                    .unwrap_or(false);
+    /// Check if container runtime is installed and running, with auto-start option
+pub fn check_runtime_available(runtime: ContainerRuntime) -> Result<()> {
+    let cmd = runtime.binary();
+
+    let output = Command::new(cmd)
+        .args(["--version"])
+        .output()
+        .map_err(|_| eyre!("{} is not installed or not available in PATH", cmd))?;
+
+    if !output.status.success() {
+        return Err(eyre!("{} is installed but not working properly", cmd));
+    }
 
-            if should_start {
-                // Try to start Docker
-                Self::start_docker_daemon()?;
+    // Check if daemon is running
+    let output = Command::new(cmd)
+        .args(["info"])
+        .output()
+        .map_err(|_| eyre!("Failed to check {} daemon status", cmd))?;
+
+    if !output.status.success() {
+        // Daemon not running - ask user if they want to start it
+        let message = format!(
+            "{} daemon is not running. Would you like to start {}?",
+            runtime.label(),
+            runtime.binary()
+        );
+        let should_start = print_confirm(&message).unwrap_or(false);
 
-                // Wait for Docker to be ready
-                Self::wait_for_docker(15)?;
+        if should_start {
+            // Try to start the runtime
+            Self::start_runtime_daemon(runtime)?;
 
-                // Verify Docker is now running
-                let verify_output = Command::new("docker")
-                    .args(["info"])
-                    .output()
-                    .map_err(|_| eyre!("Failed to verify Docker daemon status"))?;
+            // Wait for it to be ready
+            Self::wait_for_runtime(runtime, 15)?;
 
-                if !verify_output.status.success() {
-                    return Err(eyre!(
-                        "Docker daemon failed to start. Please start Docker manually and try again."
-                    ));
-                }
-            } else {
-                print_warning("Docker daemon must be running to execute this command.");
-                return Err(eyre!("Docker daemon is not running. Please start Docker."));
+            // Verify it's running now
+            let verify_output = Command::new(cmd)
+                .args(["info"])
+                .output()
+                .map_err(|_| eyre!("Failed to verify {} daemon status", cmd))?;
+
+            if !verify_output.status.success() {
+                return Err(eyre!(
+                    "{} daemon failed to start. Please start {} manually and try again.",
+                    runtime.label(),
+                    cmd
+                ));
             }
+        } else {
+            print_warning(&format!(
+                "{} daemon must be running to execute this command.",
+                runtime.label()
+            ));
+            return Err(eyre!("{} daemon is not running. Please start {}.", cmd, cmd));
         }
-
-        Ok(())
     }
+    Ok(())
+}
 
     /// Generate Dockerfile for an instance
     pub fn generate_dockerfile(
         &self,
@@ -391,10 +497,10 @@ networks:
         Ok(compose)
     }
 
-    /// Build Docker image for an instance
+    /// Build Docker/Podman image for an instance
     pub fn build_image(&self, instance_name: &str, _build_target: Option<&str>) -> Result<()> {
         print_status(
-            "DOCKER",
+            self.runtime.label(),
            &format!("Building image for instance '{instance_name}'..."),
         );
 
@@ -402,16 +508,16 @@ networks:
 
         if !output.status.success() {
             let stderr = String::from_utf8_lossy(&output.stderr);
-            return Err(eyre!("Docker build failed:\n{stderr}"));
+            return Err(eyre!("{} build failed:\n{stderr}",self.runtime.binary()));
         }
 
-        print_status("DOCKER", "Image built successfully");
+        print_status(self.runtime.label(), "Image built successfully");
         Ok(())
     }
 
-    /// Start instance using docker compose
+    /// Start instance using docker/podman compose
     pub fn start_instance(&self, instance_name: &str) -> Result<()> {
-        print_status("DOCKER", &format!("Starting instance '{instance_name}'..."));
+        print_status(self.runtime.label(), &format!("Starting instance '{instance_name}'..."));
 
         let output = self.run_compose_command(instance_name, vec!["up", "-d"])?;
 
@@ -421,15 +527,15 @@ networks:
         }
 
         print_status(
-            "DOCKER",
+            self.runtime.label(),
             &format!("Instance '{instance_name}' started successfully"),
         );
         Ok(())
     }
 
-    /// Stop instance using docker compose
+    /// Stop instance using docker/podman compose
     pub fn stop_instance(&self, instance_name: &str) -> Result<()> {
-        print_status("DOCKER", &format!("Stopping instance '{instance_name}'..."));
+        print_status(self.runtime.label(), &format!("Stopping instance '{instance_name}'..."));
 
         let output = self.run_compose_command(instance_name, vec!["down"])?;
 
@@ -439,7 +545,7 @@ networks:
         }
 
         print_status(
-            "DOCKER",
+            self.runtime.label(),
             &format!("Instance '{instance_name}' stopped successfully"),
         );
         Ok(())
@@ -455,7 +561,7 @@ networks:
             .any(|status| status.container_name.contains(&target_container_name)))
     }
 
-    /// Get status of all Docker containers for this project
+    /// Get status of all Docker/Podman containers for this project
     pub fn get_project_status(&self) -> Result> {
         let project_name = &self.project.config.project.name;
         let filter = format!("name=helix-{project_name}-");
@@ -514,13 +620,13 @@ networks:
 
     /// Remove instance containers and optionally volumes
     pub fn prune_instance(&self, instance_name: &str, remove_volumes: bool) -> Result<()> {
-        print_status("DOCKER", &format!("Pruning instance '{instance_name}'..."));
+        print_status(self.runtime.label(), &format!("Pruning instance '{instance_name}'..."));
 
         // Check if workspace exists - if not, there's nothing to prune
         let workspace = self.project.instance_workspace(instance_name);
         if !workspace.exists() {
             print_status(
-                "DOCKER",
+                self.runtime.label(),
                 &format!("No workspace found for instance '{instance_name}', nothing to prune"),
             );
             return Ok(());
@@ -530,7 +636,7 @@ networks:
         let compose_file = workspace.join("docker-compose.yml");
         if !compose_file.exists() {
             print_status(
-                "DOCKER",
+                self.runtime.label(),
                 &format!(
                     "No docker-compose.yml found for instance '{instance_name}', nothing to prune"
                 ),
@@ -552,7 +658,7 @@ networks:
             // Don't fail if containers are already down
             if stderr.contains("No such container") || stderr.contains("not running") {
                 print_status(
-                    "DOCKER",
+                    self.runtime.label(),
                     &format!("Instance '{instance_name}' containers already stopped"),
                 );
             } else {
@@ -560,7 +666,7 @@ networks:
             }
         } else {
             print_status(
-                "DOCKER",
+                self.runtime.label(),
                 &format!("Instance '{instance_name}' pruned successfully"),
             );
         }
@@ -579,13 +685,13 @@ networks:
             let _ = self.run_docker_command(&["volume", "rm", &volume_to_remove]);
         }
 
-        Ok(())
+       Ok(())
     }
 
-    /// Remove Docker images associated with an instance
+    /// Remove Docker/Podman images associated with an instance
     pub fn remove_instance_images(&self, instance_name: &str) -> Result<()> {
         print_status(
-            "DOCKER",
+            self.runtime.label(),
             &format!("Removing images for instance '{instance_name}'..."),
         );
 
@@ -597,16 +703,16 @@ networks:
         for image in [debug_image, release_image] {
             let output = self.run_docker_command(&["rmi", "-f", &image])?;
             if output.status.success() {
-                print_status("DOCKER", &format!("Removed image: {image}"));
+                print_status(self.runtime.label(), &format!("Removed image: {image}"));
             }
         }
 
         Ok(())
     }
 
-    /// Get all Helix-related Docker images from the system
-    pub fn get_helix_images() -> Result<Vec<String>> {
-        let output = Command::new("docker")
+    /// Get all Helix-related images from the system
+    pub fn get_helix_images(runtime: ContainerRuntime) -> Result<Vec<String>> {
+        let output = Command::new(runtime.binary())
             .args([
                 "images",
                 "--format",
@@ -615,7 +721,7 @@ networks:
                 "reference=helix-*",
             ])
             .output()
-            .map_err(|e| eyre!("Failed to list Docker images: {e}"))?;
+            .map_err(|e| eyre!("Failed to list {} images: {e}",runtime.binary()))?;
 
         if !output.status.success() {
             let stderr = String::from_utf8_lossy(&output.stderr);
@@ -632,32 +738,32 @@ networks:
         Ok(images)
     }
 
-    /// Remove all Helix-related Docker images from the system
-    pub fn clean_all_helix_images() -> Result<()> {
-        print_status("DOCKER", "Finding all Helix images on system...");
+    /// Remove all Helix-related images from the system
+    pub fn clean_all_helix_images(runtime: ContainerRuntime) -> Result<()> {
+        print_status(runtime.label(), "Finding all Helix images on system...");
 
-        let images = Self::get_helix_images()?;
+        let images = Self::get_helix_images(runtime)?;
 
         if images.is_empty() {
-            print_status("DOCKER", "No Helix images found to clean");
+            print_status(runtime.label(), "No Helix images found to clean");
             return Ok(());
         }
 
         let count = images.len();
-        print_status("DOCKER", &format!("Found {count} Helix images to remove"));
+        print_status(runtime.label(), &format!("Found {count} Helix images to remove"));
 
         for image in images {
-            let output = Command::new("docker")
+            let output = Command::new(runtime.binary())
                 .args(["rmi", "-f", &image])
                 .output()
                 .map_err(|e| eyre!("Failed to remove image {image}: {e}"))?;
 
             if output.status.success() {
-                print_status("DOCKER", &format!("Removed image: {image}"));
+                print_status(runtime.label(), &format!("Removed image: {image}"));
             } else {
                 let stderr = String::from_utf8_lossy(&output.stderr);
                 print_status(
-                    "DOCKER",
+                    runtime.label(),
                     &format!("Warning: Failed to remove {image}: {stderr}"),
                 );
             }
@@ -668,7 +774,7 @@ networks:
     pub fn tag(&self, image_name: &str, registry_url: &str) -> Result<()> {
         let registry_image = format!("{registry_url}/{image_name}");
 
-        Command::new("docker")
+        Command::new(self.runtime.binary())
             .arg("tag")
             .arg(image_name)
             .arg(&registry_image)
@@ -679,8 +785,8 @@ networks:
 
     pub fn push(&self, image_name: &str, registry_url: &str) -> Result<()> {
        let registry_image = format!("{registry_url}/{image_name}");
-        print_status("DOCKER", &format!("Pushing image: {registry_image}"));
-        let output = Command::new("docker")
+        print_status(self.runtime.label(), &format!("Pushing image: {registry_image}"));
+        let output = Command::new(self.runtime.binary())
             .arg("push")
             .arg(&registry_image)
             .output()?;
diff --git a/helix-cli/src/tests/check_tests.rs b/helix-cli/src/tests/check_tests.rs
index 6972f5027..f87cdb2a7 100644
--- a/helix-cli/src/tests/check_tests.rs
+++ b/helix-cli/src/tests/check_tests.rs
@@ -234,6 +234,7 @@ async fn test_check_with_multiple_instances() {
             port: Some(6970),
             build_mode: crate::config::BuildMode::Debug,
             db_config: DbConfig::default(),
+
         },
     );
     config.local.insert(
@@ -242,6 +243,7 @@ async fn test_check_with_multiple_instances() {
             port: Some(6971),
             build_mode: crate::config::BuildMode::Debug,
             db_config: DbConfig::default(),
+
         },
    );
     let config_path = project_path.join("helix.toml");

From 2841ace0012b1fda53bb32052868a93ee3e19f39 Mon Sep 17 00:00:00 2001
From: xav-db
Date: Fri, 14 Nov 2025 16:39:13 -0800
Subject: [PATCH 02/48] making docker use openai/gemini keys

---
 helix-cli/src/docker.rs | 171 +++++++---------------------------------
 1 file changed, 27 insertions(+), 144 deletions(-)

diff --git a/helix-cli/src/docker.rs b/helix-cli/src/docker.rs
index 50d7cbdcf..6347de851 100644
--- a/helix-cli/src/docker.rs
+++ b/helix-cli/src/docker.rs
@@ -1,10 +1,8 @@
 use crate::config::{BuildMode, InstanceInfo};
 use crate::project::ProjectContext;
-use crate::utils::{print_confirm, print_status, print_warning};
+use crate::utils::print_status;
 use eyre::{Result, eyre};
 use std::process::{Command, Output};
-use std::thread;
-use std::time::Duration;
 
 pub struct DockerManager<'a> {
     project: &'a ProjectContext,
@@ -46,8 +44,12 @@ impl<'a> DockerManager<'a> {
     }
 
     /// Get environment variables for an instance
+    /// Loads from .env file and shell environment
     pub(crate) fn environment_variables(&self, instance_name: &str) -> Vec<String> {
-        vec![
+        // Load .env file (silently ignore if it doesn't exist)
+        let _ = dotenvy::dotenv();
+
+        let mut env_vars = vec![
             {
                 let port = self
                     .project
@@ -64,7 +66,17 @@ impl<'a> DockerManager<'a> {
                 let project_name = &self.project.config.project.name;
                 format!("HELIX_PROJECT={project_name}")
             },
-        ]
+        ];
+
+        // Add API keys from environment (which includes .env after dotenv() call)
+        if let Ok(openai_key) = std::env::var("OPENAI_API_KEY") {
+            env_vars.push(format!("OPENAI_API_KEY={openai_key}"));
+        }
+        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
+            env_vars.push(format!("GEMINI_API_KEY={gemini_key}"));
+        }
+
+        env_vars
     }
 
     /// Get the container name for an instance
@@ -107,112 +119,6 @@ impl<'a> DockerManager<'a> {
         Ok(output)
     }
 
-    /// Detect the current operating system platform
-    fn detect_platform() -> &'static str {
-        #[cfg(target_os = "macos")]
-        return "macos";
-
-        #[cfg(target_os = "linux")]
-        return "linux";
-
-        #[cfg(target_os = "windows")]
-        return "windows";
-
-        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
-        return "unknown";
-    }
-
-    /// Start the Docker daemon based on the platform
-    fn start_docker_daemon() -> Result<()> {
-        let platform = Self::detect_platform();
-
-        match platform {
-            "macos" => {
-                print_status("DOCKER", "Starting Docker Desktop for macOS...");
-                Command::new("open")
-                    .args(["-a", "Docker"])
-                    .output()
-                    .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
-            }
-            "linux" => {
-                print_status("DOCKER", "Attempting to start Docker daemon on Linux...");
-                // Try systemctl first, then service command as fallback
-                let systemctl_result = Command::new("systemctl").args(["start", "docker"]).output();
-
-                match systemctl_result {
-                    Ok(output) if output.status.success() => {
-                        // systemctl succeeded
-                    }
-                    _ => {
-                        // Try service command as fallback
-                        let service_result = Command::new("service")
-                            .args(["docker", "start"])
-                            .output()
-                            .map_err(|e| eyre!("Failed to start Docker daemon: {}", e))?;
-
-                        if !service_result.status.success() {
-                            let stderr = String::from_utf8_lossy(&service_result.stderr);
-                            return Err(eyre!("Failed to start Docker daemon: {}", stderr));
-                        }
-                    }
-                }
-            }
-            "windows" => {
-                print_status("DOCKER", "Starting Docker Desktop for Windows...");
-                // Try Docker Desktop CLI (4.37+) first
-                let cli_result = Command::new("docker")
-                    .args(["desktop", "start"])
-                    .output();
-
-                match cli_result {
-                    Ok(output) if output.status.success() => {
-                        // Modern Docker Desktop CLI worked
-                    }
-                    _ => {
-                        // Fallback to direct executable path for older versions
-                        // Note: Empty string "" is required as window title parameter
-                        Command::new("cmd")
-                            .args(["/c", "start", "", "\"C:\\Program Files\\Docker\\Docker\\Docker Desktop.exe\""])
-                            .output()
-                            .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
-                    }
-                }
-            }
-            _ => {
-                return Err(eyre!("Unsupported platform for auto-starting Docker"));
-            }
-        }
-
-        Ok(())
-    }
-
-    /// Wait for Docker daemon to be ready
-    fn wait_for_docker(timeout_secs: u64) -> Result<()> {
-        print_status("DOCKER", "Waiting for Docker daemon to start...");
-
-        let start = std::time::Instant::now();
-        let timeout = Duration::from_secs(timeout_secs);
-
-        while start.elapsed() < timeout {
-            // Check if Docker daemon is responding
-            let output = Command::new("docker").args(["info"]).output();
-
-            if let Ok(output) = output
-                && output.status.success()
-            {
-                print_status("DOCKER", "Docker daemon is now running");
-                return Ok(());
-            }
-
-            // Wait a bit before retrying
-            thread::sleep(Duration::from_millis(500));
-        }
-
-        Err(eyre!(
-            "Timeout waiting for Docker daemon to start. Please start Docker manually and try again."
-        ))
-    }
-
     /// Check if Docker is installed and running
     pub fn check_docker_available() -> Result<()> {
         let output = Command::new("docker")
             .args(["--version"])
             .output()
             .map_err(|_| eyre!("Docker is not installed or not available in PATH"))?;
 
         if !output.status.success() {
             return Err(eyre!("Docker is installed but not working properly"));
         }
 
         // Check if Docker daemon is running
         let output = Command::new("docker")
             .args(["info"])
             .output()
             .map_err(|_| eyre!("Failed to check Docker daemon status"))?;
 
         if !output.status.success() {
-            // Docker daemon is not running - prompt user
-            let should_start =
-                print_confirm("Docker daemon is not running. Would you like to start Docker?")
-                    .unwrap_or(false);
-
-            if should_start {
-                // Try to start Docker
-                Self::start_docker_daemon()?;
-
-                // Wait for Docker to be ready
-                Self::wait_for_docker(15)?;
-
-                // Verify Docker is now running
-                let verify_output = Command::new("docker")
-                    .args(["info"])
-                    .output()
-                    .map_err(|_| eyre!("Failed to verify Docker daemon status"))?;
-
-                if !verify_output.status.success() {
-                    return Err(eyre!(
-                        "Docker daemon failed to start. Please start Docker manually and try again."
-                    ));
-                }
-            } else {
-                print_warning("Docker daemon must be running to execute this command.");
-                return Err(eyre!("Docker daemon is not running. Please start Docker."));
-            }
+            return Err(eyre!("Docker daemon is not running. Please start Docker."));
         }
 
         Ok(())
@@ -354,6 +234,14 @@ CMD ["helix-container"]
         let container_name = self.container_name(instance_name);
         let network_name = self.network_name(instance_name);
 
+        // Get all environment variables dynamically
+        let env_vars = self.environment_variables(instance_name);
+        let env_section = env_vars
+            .iter()
+            .map(|var| format!("      - {var}"))
+            .collect::<Vec<_>>()
+            .join("\n");
+
         let compose = format!(
             r#"# Generated docker-compose.yml for Helix instance: {instance_name}
 services:
@@ -369,10 +257,7 @@ services:
     volumes:
       - ../.volumes/{instance_name}:/data
     environment:
-      - HELIX_PORT={port}
-      - HELIX_DATA_DIR={data_dir}
-      - HELIX_INSTANCE={instance_name}
-      - HELIX_PROJECT={project_name}
+{env_section}
     restart: unless-stopped
     networks:
       - {network_name}
@@ -384,8 +269,6 @@ networks:
             platform = instance_config
                 .docker_build_target()
                 .map_or("".to_string(), |p| format!("platforms:\n      - {p}")),
-            project_name = self.project.config.project.name,
-            data_dir = HELIX_DATA_DIR,
         );
 
         Ok(compose)

From 4edfa321001af7d15cb6ce1f2dd02788cebabb93 Mon Sep 17 00:00:00 2001
From: xav-db
Date: Fri, 14 Nov 2025 16:43:19 -0800
Subject: [PATCH 03/48] fixing env var stuff to reflect all recent docker
 changes

---
 helix-cli/src/docker.rs  | 146 +++++++++++++++++++++++++++++++++++++--
 helix-db-fix-entry-point |   1 +
 2 files changed, 141 insertions(+), 6 deletions(-)
 create mode 160000 helix-db-fix-entry-point

diff --git a/helix-cli/src/docker.rs b/helix-cli/src/docker.rs
index 6347de851..d9a5b4afd 100644
--- a/helix-cli/src/docker.rs
+++ b/helix-cli/src/docker.rs
@@ -1,8 +1,10 @@
 use crate::config::{BuildMode, InstanceInfo};
 use crate::project::ProjectContext;
-use crate::utils::print_status;
+use crate::utils::{print_confirm, print_status, print_warning};
 use eyre::{Result, eyre};
 use std::process::{Command, Output};
+use std::thread;
+use std::time::Duration;
 
 pub struct DockerManager<'a> {
     project: &'a ProjectContext,
@@ -46,7 +48,6 @@ impl<'a> DockerManager<'a> {
     }
 
     /// Get environment variables for an instance
-    /// Loads from .env file and shell environment
     pub(crate) fn environment_variables(&self, instance_name: &str) -> Vec<String> {
         // Load .env file (silently ignore if it doesn't exist)
         let _ = dotenvy::dotenv();
@@ -119,6 +120,115 @@ impl<'a> DockerManager<'a> {
         Ok(output)
     }
 
+    /// Detect the current operating system platform
+    fn detect_platform() -> &'static str {
+        #[cfg(target_os = "macos")]
+        return "macos";
+
+        #[cfg(target_os = "linux")]
+        return "linux";
+
+        #[cfg(target_os = "windows")]
+        return "windows";
+
+        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
+        return "unknown";
+    }
+
+    /// Start the Docker daemon based on the platform
+    fn start_docker_daemon() -> Result<()> {
+        let platform = Self::detect_platform();
+
+        match platform {
+            "macos" => {
+                print_status("DOCKER", "Starting Docker Desktop for macOS...");
+                Command::new("open")
+                    .args(["-a", "Docker"])
+                    .output()
+                    .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
+            }
+            "linux" => {
+                print_status("DOCKER", "Attempting to start Docker daemon on Linux...");
+                // Try systemctl first, then service command as fallback
+                let systemctl_result = Command::new("systemctl").args(["start", "docker"]).output();
+
+                match systemctl_result {
+                    Ok(output) if output.status.success() => {
+                        // systemctl succeeded
+                    }
+                    _ => {
+                        // Try service command as fallback
+                        let service_result = Command::new("service")
+                            .args(["docker", "start"])
+                            .output()
+                            .map_err(|e| eyre!("Failed to start Docker daemon: {}", e))?;
+
+                        if !service_result.status.success() {
+                            let stderr = String::from_utf8_lossy(&service_result.stderr);
+                            return Err(eyre!("Failed to start Docker daemon: {}", stderr));
+                        }
+                    }
+                }
+            }
+            "windows" => {
+                print_status("DOCKER", "Starting Docker Desktop for Windows...");
+                // Try Docker Desktop CLI (4.37+) first
+                let cli_result = Command::new("docker").args(["desktop", "start"]).output();
+
+                match cli_result {
+                    Ok(output) if output.status.success() => {
+                        // Modern Docker Desktop CLI worked
+                    }
+                    _ => {
+                        // Fallback to direct executable path for older versions
+                        // Note: Empty string "" is required as window title parameter
+                        Command::new("cmd")
+                            .args([
+                                "/c",
+                                "start",
+                                "",
+                                "\"C:\\Program Files\\Docker\\Docker\\Docker Desktop.exe\"",
+                            ])
+                            .output()
+                            .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?;
+                    }
+                }
+            }
+            _ => {
+                return Err(eyre!("Unsupported platform for auto-starting Docker"));
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Wait for Docker daemon to be ready
+    fn wait_for_docker(timeout_secs: u64) -> Result<()> {
+        print_status("DOCKER", "Waiting for Docker daemon to start...");
+
+        let start = std::time::Instant::now();
+        let timeout = Duration::from_secs(timeout_secs);
+
+        while start.elapsed() < timeout {
+            // Check if Docker daemon is responding
+            let output = Command::new("docker").args(["info"]).output();
+
+            if let Ok(output) = output
+                && output.status.success()
+            {
+                print_status("DOCKER", "Docker daemon is now running");
+                return Ok(());
+            }
+
+            // Wait a bit before retrying
+            thread::sleep(Duration::from_millis(500));
+        }
+
+        Err(eyre!(
+            "Timeout waiting for Docker daemon to start. Please start Docker manually and try again."
+        ))
+    }
+
     /// Check if Docker is installed and running
     pub fn check_docker_available() -> Result<()> {
         let output = Command::new("docker")
             .args(["--version"])
             .output()
             .map_err(|_| eyre!("Docker is not installed or not available in PATH"))?;
 
         if !output.status.success() {
             return Err(eyre!("Docker is installed but not working properly"));
         }
 
         // Check if Docker daemon is running
         let output = Command::new("docker")
             .args(["info"])
             .output()
             .map_err(|_| eyre!("Failed to check Docker daemon status"))?;
 
         if !output.status.success() {
-            return Err(eyre!("Docker daemon is not running. Please start Docker."));
+            // Docker daemon is not running - prompt user
+            let should_start =
+                print_confirm("Docker daemon is not running. Would you like to start Docker?")
+                    .unwrap_or(false);
+
+            if should_start {
+                // Try to start Docker
+                Self::start_docker_daemon()?;
+
+                // Wait for Docker to be ready
+                Self::wait_for_docker(15)?;
+
+                // Verify Docker is now running
+                let verify_output = Command::new("docker")
+                    .args(["info"])
+                    .output()
+                    .map_err(|_| eyre!("Failed to verify Docker daemon status"))?;
+
+                if !verify_output.status.success() {
+                    return Err(eyre!(
+                        "Docker daemon failed to start. Please start Docker manually and try again."
+                    ));
+                }
+            } else {
+                print_warning("Docker daemon must be running to execute this command.");
+                return Err(eyre!("Docker daemon is not running. Please start Docker."));
+            }
         }
 
         Ok(())
@@ -368,9 +478,7 @@ CMD ["helix-container"]
         let service_name = Self::service_name();
         let image_name = self.image_name(instance_name, instance_config.build_mode());
         let container_name = self.container_name(instance_name);
-        let network_name = self.network_name(instance_name);
-
-        // Get all environment variables dynamically
+        let network_name = self.network_name(instance_name); // Get all environment variables dynamically
         let env_vars = self.environment_variables(instance_name);
         let env_section = env_vars
             .iter()
diff --git a/helix-db-fix-entry-point b/helix-db-fix-entry-point
new file mode 160000
index 000000000..583e4cbf5
--- /dev/null
+++ b/helix-db-fix-entry-point
@@ -0,0 +1 @@
+Subproject commit 583e4cbf596b7846f33a4f037fc81a4076c8c759

From ac996905ae6438bbd13a1663f3255e2fdab8aa17 Mon Sep 17 00:00:00 2001
From: xav-db
Date: Fri, 14 Nov 2025 22:28:14 -0800
Subject: [PATCH 04/48] fixing issue with embedding config not being persisted

---
 helix-cli/src/commands/migrate.rs |  3 +++
 helix-cli/src/config.rs           | 36 +++++++++++++++++++++++++++++--
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/helix-cli/src/commands/migrate.rs b/helix-cli/src/commands/migrate.rs
index 6b7707600..7a8f7cdda 100644
--- a/helix-cli/src/commands/migrate.rs
+++ b/helix-cli/src/commands/migrate.rs
@@ -408,6 +408,9 @@ fn create_v2_config(ctx: &MigrationContext) -> Result<()> {
         graph_config,
         mcp: ctx.v1_config.mcp,
         bm25: ctx.v1_config.bm25,
+        schema: None,
+        embedding_model: Some("text-embedding-ada-002".to_string()),
+        graphvis_node_label: None,
     };
 
     // Create local instance config
diff --git a/helix-cli/src/config.rs b/helix-cli/src/config.rs
index 5d282a4ea..1752d0271 100644
--- a/helix-cli/src/config.rs
+++ b/helix-cli/src/config.rs
@@ -75,6 +75,12 @@ pub struct DbConfig {
     pub mcp: bool,
     #[serde(default = "default_true", skip_serializing_if = "is_true")]
     pub bm25: bool,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub schema: Option<String>,
+    #[serde(default = "default_embedding_model", skip_serializing_if = "is_default_embedding_model")]
+    pub embedding_model: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub graphvis_node_label: Option<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -159,6 +165,14 @@ fn default_db_max_size_gb() -> u32 {
     20
 }
 
+fn default_embedding_model() -> Option<String> {
+    Some("text-embedding-ada-002".to_string())
+}
+
+fn is_default_embedding_model(value: &Option<String>) -> bool {
+    *value == default_embedding_model()
+}
+
 fn is_true(value: &bool) -> bool {
     *value
 }
@@ -189,6 +203,9 @@ impl Default for DbConfig {
             graph_config: GraphConfig::default(),
             mcp: true,
             bm25: true,
+            schema: None,
+            embedding_model: default_embedding_model(),
+            graphvis_node_label: None,
         }
     }
 }
@@ -259,7 +276,7 @@ impl<'a> InstanceInfo<'a> {
     pub fn to_legacy_json(&self) -> serde_json::Value {
         let db_config = self.db_config();
 
-        serde_json::json!({
+        let mut json = serde_json::json!({
             "vector_config": {
                 "m": db_config.vector_config.m,
                 "ef_construction": db_config.vector_config.ef_construction,
@@ -272,7 +289,22 @@ impl<'a> InstanceInfo<'a> {
             "db_max_size_gb": db_config.vector_config.db_max_size_gb,
             "mcp": db_config.mcp,
             "bm25": db_config.bm25
-        })
+        });
+
+        // Add optional fields if they exist
+        if let Some(schema) = &db_config.schema {
+            json["schema"] = serde_json::Value::String(schema.clone());
+        }
+
+        if let Some(embedding_model) = &db_config.embedding_model {
+            json["embedding_model"] = serde_json::Value::String(embedding_model.clone());
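+            // note: DbConfig::default() sets embedding_model to
+            // Some("text-embedding-ada-002"), so this key is normally present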
+        }
+
+        if let Some(graphvis_node_label) = &db_config.graphvis_node_label {
+            json["graphvis_node_label"] = serde_json::Value::String(graphvis_node_label.clone());
+        }
+
+        json
     }
 }

From 6725d5f64fd0ea7f264f2df44f3ae33b2e99617b Mon Sep 17 00:00:00 2001
From: xav-db
Date: Tue, 18 Nov 2025 22:30:06 -0800
Subject: [PATCH 05/48] removing unused file

---
 helix-db-fix-entry-point | 1 -
 1 file changed, 1 deletion(-)
 delete mode 160000 helix-db-fix-entry-point

diff --git a/helix-db-fix-entry-point b/helix-db-fix-entry-point
deleted file mode 160000
index 583e4cbf5..000000000
--- a/helix-db-fix-entry-point
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 583e4cbf596b7846f33a4f037fc81a4076c8c759

From 2c78f58f41720f361a6f0c5e4afeaf6ce4bd2a8d Mon Sep 17 00:00:00 2001
From: xav-db
Date: Wed, 19 Nov 2025 10:54:33 -0800
Subject: [PATCH 06/48] implementing auto init and add cleanup when running cli

---
 helix-cli/src/cleanup.rs       | 236 +++++++++++++++++++++++++++++++++
 helix-cli/src/commands/add.rs  |  26 ++++
 helix-cli/src/commands/init.rs |  86 ++++++++++--
 helix-cli/src/lib.rs           |   1 +
 helix-cli/src/main.rs          |   1 +
 5 files changed, 337 insertions(+), 13 deletions(-)
 create mode 100644 helix-cli/src/cleanup.rs

diff --git a/helix-cli/src/cleanup.rs b/helix-cli/src/cleanup.rs
new file mode 100644
index 000000000..a7d746f1f
--- /dev/null
+++ b/helix-cli/src/cleanup.rs
@@ -0,0 +1,236 @@
+use std::fs;
+use std::path::PathBuf;
+
+use crate::config::HelixConfig;
+
+/// Tracks resources created during init/add operations for automatic cleanup on failure
+pub struct CleanupTracker {
+    /// Files created during the operation (tracked in creation order)
+    created_files: Vec<PathBuf>,
+    /// Directories created during the operation (tracked in creation order)
+    created_dirs: Vec<PathBuf>,
+    /// In-memory backup of the config before modification
+    original_config: Option<HelixConfig>,
+    /// Path to the config file
+    config_path: Option<PathBuf>,
+}
+
+/// Summary of cleanup operations
+pub struct CleanupSummary {
+    pub files_removed: usize,
+    pub files_failed: usize,
+    pub dirs_removed: usize,
+    pub dirs_failed: usize,
+    pub config_restored: bool,
+    pub errors: Vec<String>,
+}
+
+impl CleanupTracker {
+    /// Create a new cleanup tracker
+    pub fn new() -> Self {
+        Self {
+            created_files: Vec::new(),
+            created_dirs: Vec::new(),
+            original_config: None,
+            config_path: None,
+        }
+    }
+
+    /// Track a file that was created
+    pub fn track_file(&mut self, path: PathBuf) {
+        self.created_files.push(path);
+    }
+
+    /// Track a directory that was created
+    pub fn track_dir(&mut self, path: PathBuf) {
+        self.created_dirs.push(path);
+    }
+
+    /// Backup the config in memory before modification
+    pub fn backup_config(&mut self, config: &HelixConfig, config_path: PathBuf) {
+        self.original_config = Some(config.clone());
+        self.config_path = Some(config_path);
+    }
+
+    /// Execute cleanup in reverse order of creation
+    /// Logs errors but continues cleanup process
+    pub fn cleanup(self) -> CleanupSummary {
+        let mut summary = CleanupSummary {
+            files_removed: 0,
+            files_failed: 0,
+            dirs_removed: 0,
+            dirs_failed: 0,
+            config_restored: false,
+            errors: Vec::new(),
+        };
+
+        // Step 1: Restore config from in-memory backup if modified
+        if let (Some(original_config), Some(config_path)) =
+            (self.original_config, self.config_path)
+        {
+            match original_config.save_to_file(&config_path) {
+                Ok(_) => {
+                    summary.config_restored = true;
+                    eprintln!("Restored config file to original state");
+                }
+                Err(e) => {
+                    let error_msg = format!("Failed to restore config: {}", e);
+                    eprintln!("Error: {}", error_msg);
summary.errors.push(error_msg); + } + } + } + + // Step 2: Delete files in reverse order (newest first) + for file_path in self.created_files.iter().rev() { + match fs::remove_file(file_path) { + Ok(_) => { + summary.files_removed += 1; + eprintln!("Removed file: {}", file_path.display()); + } + Err(e) => { + summary.files_failed += 1; + let error_msg = format!("Failed to remove file {}: {}", file_path.display(), e); + eprintln!("Warning: {}", error_msg); + summary.errors.push(error_msg); + } + } + } + + // Step 3: Delete directories in reverse order (deepest first) + // Sort by path depth (deepest first) to ensure we delete children before parents + let mut sorted_dirs = self.created_dirs.clone(); + sorted_dirs.sort_by(|a, b| { + let a_depth = a.components().count(); + let b_depth = b.components().count(); + b_depth.cmp(&a_depth) // Reverse order (deepest first) + }); + + for dir_path in sorted_dirs.iter() { + // Only try to remove if directory exists + if !dir_path.exists() { + continue; + } + + // Try to remove directory - will only succeed if empty + match fs::remove_dir(dir_path) { + Ok(_) => { + summary.dirs_removed += 1; + eprintln!("Removed directory: {}", dir_path.display()); + } + Err(_e) => { + // This might fail if directory is not empty, which is fine + // We only want to remove directories we created if they're still empty + summary.dirs_failed += 1; + // Don't add to errors since this is expected for non-empty dirs + } + } + } + + summary + } + + /// Check if any resources are being tracked + pub fn has_tracked_resources(&self) -> bool { + !self.created_files.is_empty() + || !self.created_dirs.is_empty() + || self.original_config.is_some() + } +} + +impl CleanupSummary { + /// Log the cleanup summary + pub fn log_summary(&self) { + if self.files_removed > 0 || self.dirs_removed > 0 || self.config_restored { + eprintln!("Cleanup summary:"); + if self.config_restored { + eprintln!(" - Config file restored"); + } + if self.files_removed > 0 { + eprintln!(" - Removed {} file(s)", self.files_removed); + } + if self.dirs_removed > 0 { + eprintln!(" - Removed {} directory(ies)", self.dirs_removed); + } + } + + if self.files_failed > 0 || self.dirs_failed > 0 || !self.errors.is_empty() { + eprintln!("Cleanup encountered {} error(s):", self.errors.len()); + for error in &self.errors { + eprintln!(" - {}", error); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::io::Write; + use tempfile::TempDir; + + #[test] + fn test_track_and_cleanup_files() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + // Create a file + let mut file = fs::File::create(&file_path).unwrap(); + file.write_all(b"test").unwrap(); + + let mut tracker = CleanupTracker::new(); + tracker.track_file(file_path.clone()); + + assert!(file_path.exists()); + + // Cleanup should remove the file + let summary = tracker.cleanup(); + assert_eq!(summary.files_removed, 1); + assert_eq!(summary.files_failed, 0); + assert!(!file_path.exists()); + } + + #[test] + fn test_track_and_cleanup_dirs() { + let temp_dir = TempDir::new().unwrap(); + let dir_path = temp_dir.path().join("test_dir"); + + // Create a directory + fs::create_dir(&dir_path).unwrap(); + + let mut tracker = CleanupTracker::new(); + tracker.track_dir(dir_path.clone()); + + assert!(dir_path.exists()); + + // Cleanup should remove the directory + let summary = tracker.cleanup(); + assert_eq!(summary.dirs_removed, 1); + assert!(!dir_path.exists()); + } + + #[test] + fn 
test_cleanup_order() { + let temp_dir = TempDir::new().unwrap(); + + // Create nested structure + let parent_dir = temp_dir.path().join("parent"); + let child_dir = parent_dir.join("child"); + let file_path = child_dir.join("file.txt"); + + fs::create_dir(&parent_dir).unwrap(); + fs::create_dir(&child_dir).unwrap(); + fs::File::create(&file_path).unwrap(); + + let mut tracker = CleanupTracker::new(); + tracker.track_dir(parent_dir.clone()); + tracker.track_dir(child_dir.clone()); + tracker.track_file(file_path.clone()); + + // Cleanup should handle nested structure + let summary = tracker.cleanup(); + assert_eq!(summary.files_removed, 1); + assert!(summary.dirs_removed >= 1); // At least child dir should be removed + } +} diff --git a/helix-cli/src/commands/add.rs b/helix-cli/src/commands/add.rs index a85089e87..ad13546c6 100644 --- a/helix-cli/src/commands/add.rs +++ b/helix-cli/src/commands/add.rs @@ -1,3 +1,4 @@ +use crate::cleanup::CleanupTracker; use crate::CloudDeploymentTypeCommand; use crate::commands::integrations::ecr::{EcrAuthType, EcrManager}; use crate::commands::integrations::fly::{FlyAuthType, FlyManager, VmSize}; @@ -11,6 +12,27 @@ use eyre::Result; use std::env; pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { + let mut cleanup_tracker = CleanupTracker::new(); + + // Execute the add logic, capturing any errors + let result = run_add_inner(deployment_type, &mut cleanup_tracker).await; + + // If there was an error, perform cleanup + if let Err(ref e) = result { + if cleanup_tracker.has_tracked_resources() { + eprintln!("Add failed, performing cleanup: {}", e); + let summary = cleanup_tracker.cleanup(); + summary.log_summary(); + } + } + + result +} + +async fn run_add_inner( + deployment_type: CloudDeploymentTypeCommand, + cleanup_tracker: &mut CleanupTracker, +) -> Result<()> { let cwd = env::current_dir()?; let mut project_context = ProjectContext::find_and_load(Some(&cwd))?; @@ -34,6 +56,10 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { &format!("Adding instance '{instance_name}' to Helix project"), ); + // Backup the original config before any modifications + let config_path = project_context.root.join("helix.toml"); + cleanup_tracker.backup_config(&project_context.config, config_path.clone()); + // Determine instance type match deployment_type { diff --git a/helix-cli/src/commands/init.rs b/helix-cli/src/commands/init.rs index 8fe1aceef..9208e88cb 100644 --- a/helix-cli/src/commands/init.rs +++ b/helix-cli/src/commands/init.rs @@ -1,3 +1,4 @@ +use crate::cleanup::CleanupTracker; use crate::CloudDeploymentTypeCommand; use crate::commands::integrations::ecr::{EcrAuthType, EcrManager}; use crate::commands::integrations::fly::{FlyAuthType, FlyManager, VmSize}; @@ -17,6 +18,37 @@ pub async fn run( _template: String, queries_path: String, deployment_type: Option, +) -> Result<()> { + let mut cleanup_tracker = CleanupTracker::new(); + + // Execute the init logic, capturing any errors + let result = run_init_inner( + path, + _template, + queries_path, + deployment_type, + &mut cleanup_tracker, + ) + .await; + + // If there was an error, perform cleanup + if let Err(ref e) = result { + if cleanup_tracker.has_tracked_resources() { + eprintln!("Init failed, performing cleanup: {}", e); + let summary = cleanup_tracker.cleanup(); + summary.log_summary(); + } + } + + result +} + +async fn run_init_inner( + path: Option, + _template: String, + queries_path: String, + deployment_type: Option, + cleanup_tracker: &mut 
CleanupTracker, ) -> Result<()> { let project_dir = match path { Some(p) => std::path::PathBuf::from(p), @@ -45,14 +77,22 @@ pub async fn run( ); // Create project directory if it doesn't exist + let project_dir_existed = project_dir.exists(); fs::create_dir_all(&project_dir)?; + if !project_dir_existed { + cleanup_tracker.track_dir(project_dir.clone()); + } // Create default helix.toml with custom queries path let mut config = HelixConfig::default_config(project_name); config.project.queries = std::path::PathBuf::from(&queries_path); + + // Save initial config and track it config.save_to_file(&config_path)?; + cleanup_tracker.track_file(config_path.clone()); + // Create project structure - create_project_structure(&project_dir, &queries_path)?; + create_project_structure(&project_dir, &queries_path, cleanup_tracker)?; // Initialize deployment type based on flags @@ -83,6 +123,9 @@ pub async fn run( CloudConfig::Helix(cloud_config.clone()), ); + // Backup config before saving + cleanup_tracker.backup_config(&config, config_path.clone()); + // save config config.save_to_file(&config_path)?; } @@ -116,6 +159,10 @@ pub async fn run( project_name.to_string(), CloudConfig::Ecr(ecr_config.clone()), ); + + // Backup config before saving + cleanup_tracker.backup_config(&config, config_path.clone()); + config.save_to_file(&config_path)?; print_status("ECR", "AWS ECR repository initialized successfully"); @@ -156,6 +203,10 @@ pub async fn run( project_name.to_string(), CloudConfig::FlyIo(instance_config.clone()), ); + + // Backup config before saving + cleanup_tracker.backup_config(&config, config_path.clone()); + config.save_to_file(&config_path)?; } _ => {} @@ -184,10 +235,19 @@ pub async fn run( Ok(()) } -fn create_project_structure(project_dir: &Path, queries_path: &str) -> Result<()> { +fn create_project_structure( + project_dir: &Path, + queries_path: &str, + cleanup_tracker: &mut CleanupTracker, +) -> Result<()> { // Create directories - fs::create_dir_all(project_dir.join(".helix"))?; - fs::create_dir_all(project_dir.join(queries_path))?; + let helix_dir = project_dir.join(".helix"); + fs::create_dir_all(&helix_dir)?; + cleanup_tracker.track_dir(helix_dir); + + let queries_dir = project_dir.join(queries_path); + fs::create_dir_all(&queries_dir)?; + cleanup_tracker.track_dir(queries_dir); // Create default schema.hx with proper Helix syntax let default_schema = r#"// Start building your schema here. @@ -221,10 +281,9 @@ fn create_project_structure(project_dir: &Path, queries_path: &str) -> Result<() // } // } "#; - fs::write( - project_dir.join(queries_path).join("schema.hx"), - default_schema, - )?; + let schema_path = project_dir.join(queries_path).join("schema.hx"); + fs::write(&schema_path, default_schema)?; + cleanup_tracker.track_file(schema_path); // Create default queries.hx with proper Helix query syntax in the queries directory let default_queries = r#"// Start writing your queries here. 
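The hunks above thread a `CleanupTracker` through project scaffolding so a failed `helix init` or `helix add` can undo partially created files. A minimal, self-contained sketch of that wrap-then-rollback shape follows; the `Tracker` type here is a stand-in with illustrative names and fields, not the crate's actual API:

```rust
// Sketch: public entry point owns the tracker and only rolls back on Err.
use std::fs;
use std::path::PathBuf;

#[derive(Default)]
struct Tracker {
    created_files: Vec<PathBuf>,
}

impl Tracker {
    fn track_file(&mut self, p: PathBuf) {
        self.created_files.push(p);
    }

    // Remove tracked files newest-first; ignore individual failures so one
    // stubborn file does not abort the rest of the rollback.
    fn cleanup(&self) {
        for p in self.created_files.iter().rev() {
            let _ = fs::remove_file(p);
        }
    }
}

fn run() -> Result<(), String> {
    let mut tracker = Tracker::default();
    let result = run_inner(&mut tracker);
    // Roll back only on error; on success the created files are the product.
    if result.is_err() {
        tracker.cleanup();
    }
    result
}

fn run_inner(tracker: &mut Tracker) -> Result<(), String> {
    let path = PathBuf::from("example.txt");
    fs::write(&path, "demo").map_err(|e| e.to_string())?;
    tracker.track_file(path);
    Err("simulated failure after partial work".into())
}

fn main() {
    // Expect: run() fails, and example.txt has been removed again.
    assert!(run().is_err());
    assert!(!PathBuf::from("example.txt").exists());
}
```

The design choice mirrored from the diff is that tracking happens immediately after each successful filesystem operation, so the rollback set always matches exactly what was actually created.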
@@ -246,17 +305,18 @@ fn create_project_structure(project_dir: &Path, queries_path: &str) -> Result<() // see the documentation at https://docs.helix-db.com // or checkout our GitHub at https://github.com/HelixDB/helix-db "#; - fs::write( - project_dir.join(queries_path).join("queries.hx"), - default_queries, - )?; + let queries_path_file = project_dir.join(queries_path).join("queries.hx"); + fs::write(&queries_path_file, default_queries)?; + cleanup_tracker.track_file(queries_path_file); // Create .gitignore let gitignore = r#".helix/ target/ *.log "#; - fs::write(project_dir.join(".gitignore"), gitignore)?; + let gitignore_path = project_dir.join(".gitignore"); + fs::write(&gitignore_path, gitignore)?; + cleanup_tracker.track_file(gitignore_path); Ok(()) } diff --git a/helix-cli/src/lib.rs b/helix-cli/src/lib.rs index 2a0e61531..5a6553a1b 100644 --- a/helix-cli/src/lib.rs +++ b/helix-cli/src/lib.rs @@ -1,6 +1,7 @@ // Library interface for helix-cli to enable testing use clap::Subcommand; +pub mod cleanup; pub mod commands; pub mod config; pub mod docker; diff --git a/helix-cli/src/main.rs b/helix-cli/src/main.rs index 43082630b..f4ad38b9d 100644 --- a/helix-cli/src/main.rs +++ b/helix-cli/src/main.rs @@ -2,6 +2,7 @@ use clap::{Parser, Subcommand}; use eyre::Result; use helix_cli::{AuthAction, CloudDeploymentTypeCommand, MetricsAction}; +mod cleanup; mod commands; mod config; mod docker; From f3c56dc1b0e230e0b573f161b82a7878242cb81f Mon Sep 17 00:00:00 2001 From: Shrey Pant Date: Thu, 20 Nov 2025 01:31:49 +0530 Subject: [PATCH 07/48] run rustfmt --- helix-cli/src/commands/add.rs | 33 +- helix-cli/src/commands/build.rs | 18 +- helix-cli/src/commands/integrations/ecr.rs | 1 - helix-cli/src/commands/migrate.rs | 79 ++-- helix-cli/src/commands/prune.rs | 36 +- helix-cli/src/commands/start.rs | 8 +- helix-cli/src/commands/status.rs | 2 +- helix-cli/src/config.rs | 7 +- helix-cli/src/docker.rs | 422 +++++++++++---------- 9 files changed, 344 insertions(+), 262 deletions(-) diff --git a/helix-cli/src/commands/add.rs b/helix-cli/src/commands/add.rs index 830f7f898..6026be2e7 100644 --- a/helix-cli/src/commands/add.rs +++ b/helix-cli/src/commands/add.rs @@ -42,16 +42,20 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { let helix_manager = HelixManager::new(&project_context); // Create cloud instance configuration - let cloud_config = helix_manager.create_instance_config(&instance_name, region).await?; + let cloud_config = helix_manager + .create_instance_config(&instance_name, region) + .await?; // Initialize the cloud cluster - helix_manager.init_cluster(&instance_name, &cloud_config).await?; + helix_manager + .init_cluster(&instance_name, &cloud_config) + .await?; // Insert into project configuration - project_context - .config - .cloud - .insert(instance_name.clone(), CloudConfig::Helix(cloud_config.clone())); + project_context.config.cloud.insert( + instance_name.clone(), + CloudConfig::Helix(cloud_config.clone()), + ); print_status("CLOUD", "Helix cloud instance configuration added"); } @@ -70,7 +74,9 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { .await?; // Initialize the ECR repository - ecr_manager.init_repository(&instance_name, &ecr_config).await?; + ecr_manager + .init_repository(&instance_name, &ecr_config) + .await?; // Save configuration to ecr.toml ecr_manager.save_config(&instance_name, &ecr_config).await?; @@ -110,12 +116,14 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { ); // 
Initialize the Fly.io app - fly_manager.init_app(&instance_name, &instance_config).await?; + fly_manager + .init_app(&instance_name, &instance_config) + .await?; - project_context - .config - .cloud - .insert(instance_name.clone(), CloudConfig::FlyIo(instance_config.clone())); + project_context.config.cloud.insert( + instance_name.clone(), + CloudConfig::FlyIo(instance_config.clone()), + ); } _ => { // Add local instance with default configuration @@ -123,7 +131,6 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { port: None, // Let the system assign a port build_mode: BuildMode::Debug, db_config: DbConfig::default(), - }; project_context diff --git a/helix-cli/src/commands/build.rs b/helix-cli/src/commands/build.rs index cca93c3d3..a0e887294 100644 --- a/helix-cli/src/commands/build.rs +++ b/helix-cli/src/commands/build.rs @@ -89,7 +89,7 @@ pub async fn run(instance_name: String, metrics_sender: &MetricsSender) -> Resul let runtime = project.config.project.container_runtime; DockerManager::check_runtime_available(runtime)?; let docker = DockerManager::new(&project); - + docker.build_image(&instance_name, instance_config.docker_build_target())?; } @@ -121,13 +121,19 @@ fn needs_cache_recreation(repo_cache: &std::path::Path) -> Result<bool> { match (DEV_MODE, is_git_repo) { (true, true) => { - print_status("CACHE", "Cache is git repo but DEV_MODE requires copy - recreating..."); + print_status( + "CACHE", + "Cache is git repo but DEV_MODE requires copy - recreating...", + ); Ok(true) - }, + } (false, false) => { - print_status("CACHE", "Cache is copy but production mode requires git repo - recreating..."); + print_status( + "CACHE", + "Cache is copy but production mode requires git repo - recreating...", + ); Ok(true) - }, + } _ => Ok(false), } } @@ -292,8 +298,6 @@ async fn generate_docker_files( return Ok(()); } - - let docker = DockerManager::new(project); print_status(docker.runtime.label(), "Generating configuration..."); diff --git a/helix-cli/src/commands/integrations/ecr.rs b/helix-cli/src/commands/integrations/ecr.rs index f31d2e9d7..1377e5a81 100644 --- a/helix-cli/src/commands/integrations/ecr.rs +++ b/helix-cli/src/commands/integrations/ecr.rs @@ -92,7 +92,6 @@ impl<'a> EcrManager<'a> { format!("helix-{}-{instance_name}", self.project.config.project.name) } - fn image_name(&self, repository_name: &str, build_mode: BuildMode) -> String { let tag = match build_mode { BuildMode::Debug => "debug", diff --git a/helix-cli/src/commands/migrate.rs b/helix-cli/src/commands/migrate.rs index 7e2ecb92c..3febdfebc 100644 --- a/helix-cli/src/commands/migrate.rs +++ b/helix-cli/src/commands/migrate.rs @@ -1,5 +1,6 @@ use crate::config::{ - BuildMode, DbConfig, GraphConfig, HelixConfig, LocalInstanceConfig, ProjectConfig, VectorConfig, ContainerRuntime + BuildMode, ContainerRuntime, DbConfig, GraphConfig, HelixConfig, LocalInstanceConfig, + ProjectConfig, VectorConfig, }; use crate::errors::{CliError, project_error}; use crate::utils::{ @@ -231,7 +232,10 @@ fn find_hx_files(project_dir: &Path) -> Result<Vec<PathBuf>> { let entry = entry?; let path = entry.path(); - if let Some(extension) = path.extension() && extension == "hx" && path.file_name() != Some("schema.hx".as_ref()) { + if let Some(extension) = path.extension() + && extension == "hx" + && path.file_name() != Some("schema.hx".as_ref()) + { hx_files.push(path); } } @@ -274,18 +278,28 @@ fn show_migration_plan(ctx: &MigrationContext) -> Result<()> { print_newline(); print_header("🏠 Home Directory Migration:"); - let home_dir =
dirs::home_dir().ok_or_else(|| CliError::new("Could not find home directory"))?; + let home_dir = + dirs::home_dir().ok_or_else(|| CliError::new("Could not find home directory"))?; let v1_helix_dir = home_dir.join(".helix"); if v1_helix_dir.exists() { let v2_marker = v1_helix_dir.join(".v2"); if v2_marker.exists() { - print_field("Already migrated", "~/.helix directory already migrated to v2"); + print_field( + "Already migrated", + "~/.helix directory already migrated to v2", + ); } else { print_field("Create backup", "~/.helix → ~/.helix-v1-backup"); if v1_helix_dir.join("dockerdev").exists() { - print_field("Clean up Docker", "Stop/remove helix-dockerdev containers and images"); + print_field( + "Clean up Docker", + "Stop/remove helix-dockerdev containers and images", + ); } - print_field("Clean directory", "Remove all except ~/.helix/credentials and ~/.helix/repo"); + print_field( + "Clean directory", + "Remove all except ~/.helix/credentials and ~/.helix/repo", + ); if v1_helix_dir.join("credentials").exists() { print_field("Preserve file", "~/.helix/credentials"); } @@ -418,7 +432,6 @@ fn create_v2_config(ctx: &MigrationContext) -> Result<()> { port: Some(ctx.port), build_mode: BuildMode::Debug, db_config, - }; // Create local instances map @@ -502,9 +515,8 @@ fn provide_post_migration_guidance(ctx: &MigrationContext) -> Result<()> { fn migrate_home_directory(_ctx: &MigrationContext) -> Result<()> { print_status("HOME", "Migrating ~/.helix directory"); - let home_dir = dirs::home_dir().ok_or_else(|| { - CliError::new("Could not find home directory") - })?; + let home_dir = + dirs::home_dir().ok_or_else(|| CliError::new("Could not find home directory"))?; let v1_helix_dir = home_dir.join(".helix"); @@ -531,8 +543,7 @@ fn migrate_home_directory(_ctx: &MigrationContext) -> Result<()> { // Use the utility function to copy the directory without exclusions crate::utils::copy_dir_recursively(&v1_helix_dir, &backup_dir).map_err(|e| { - CliError::new("Failed to backup ~/.helix directory") - .with_caused_by(e.to_string()) + CliError::new("Failed to backup ~/.helix directory").with_caused_by(e.to_string()) })?; print_success("Created backup: ~/.helix-v1-backup"); @@ -551,8 +562,7 @@ fn migrate_home_directory(_ctx: &MigrationContext) -> Result<()> { let temp_credentials = if credentials_path.exists() { let temp_path = home_dir.join(".helix-credentials-temp"); fs::rename(&credentials_path, &temp_path).map_err(|e| { - CliError::new("Failed to backup credentials") - .with_caused_by(e.to_string()) + CliError::new("Failed to backup credentials").with_caused_by(e.to_string()) })?; Some(temp_path) } else { @@ -561,10 +571,8 @@ fn migrate_home_directory(_ctx: &MigrationContext) -> Result<()> { let temp_repo = if repo_path.exists() { let temp_path = home_dir.join(".helix-repo-temp"); - fs::rename(&repo_path, &temp_path).map_err(|e| { - CliError::new("Failed to backup repo") - .with_caused_by(e.to_string()) - })?; + fs::rename(&repo_path, &temp_path) + .map_err(|e| CliError::new("Failed to backup repo").with_caused_by(e.to_string()))?; Some(temp_path) } else { None @@ -572,37 +580,31 @@ fn migrate_home_directory(_ctx: &MigrationContext) -> Result<()> { // Remove the entire .helix directory fs::remove_dir_all(&v1_helix_dir).map_err(|e| { - CliError::new("Failed to remove ~/.helix directory") - .with_caused_by(e.to_string()) + CliError::new("Failed to remove ~/.helix directory").with_caused_by(e.to_string()) })?; // Recreate .helix directory fs::create_dir_all(&v1_helix_dir).map_err(|e| { - 
CliError::new("Failed to recreate ~/.helix directory") - .with_caused_by(e.to_string()) + CliError::new("Failed to recreate ~/.helix directory").with_caused_by(e.to_string()) })?; // Restore credentials and repo if let Some(temp_creds) = temp_credentials { fs::rename(&temp_creds, &credentials_path).map_err(|e| { - CliError::new("Failed to restore credentials") - .with_caused_by(e.to_string()) + CliError::new("Failed to restore credentials").with_caused_by(e.to_string()) })?; print_info("Preserved ~/.helix/credentials"); } if let Some(temp_repo) = temp_repo { - fs::rename(&temp_repo, &repo_path).map_err(|e| { - CliError::new("Failed to restore repo") - .with_caused_by(e.to_string()) - })?; + fs::rename(&temp_repo, &repo_path) + .map_err(|e| CliError::new("Failed to restore repo").with_caused_by(e.to_string()))?; print_info("Preserved ~/.helix/repo"); } // Create .v2 marker file to indicate migration is complete fs::write(&v2_marker, "").map_err(|e| { - CliError::new("Failed to create v2 marker file") - .with_caused_by(e.to_string()) + CliError::new("Failed to create v2 marker file").with_caused_by(e.to_string()) })?; print_success("Cleaned up ~/.helix directory, preserving credentials and repo"); @@ -627,7 +629,13 @@ fn cleanup_dockerdev() -> Result<()> { // Try to remove any helix-related images let output = std::process::Command::new("docker") - .args(["images", "--format", "{{.Repository}}:{{.Tag}}", "--filter", "reference=helix*"]) + .args([ + "images", + "--format", + "{{.Repository}}:{{.Tag}}", + "--filter", + "reference=helix*", + ]) .output(); if let Ok(output) = output { @@ -641,7 +649,14 @@ fn cleanup_dockerdev() -> Result<()> { // Try to remove helix volumes let output = std::process::Command::new("docker") - .args(["volume", "ls", "--format", "{{.Name}}", "--filter", "name=helix"]) + .args([ + "volume", + "ls", + "--format", + "{{.Name}}", + "--filter", + "name=helix", + ]) .output(); if let Ok(output) = output { diff --git a/helix-cli/src/commands/prune.rs b/helix-cli/src/commands/prune.rs index 8becd2186..fca1f3b90 100644 --- a/helix-cli/src/commands/prune.rs +++ b/helix-cli/src/commands/prune.rs @@ -1,3 +1,4 @@ +use crate::config::ContainerRuntime; use crate::docker::DockerManager; use crate::errors::project_error; use crate::project::ProjectContext; @@ -5,7 +6,6 @@ use crate::utils::{ print_confirm, print_lines, print_newline, print_status, print_success, print_warning, }; use eyre::Result; -use crate::config::ContainerRuntime; pub async fn run(instance: Option, all: bool) -> Result<()> { // Try to load project context @@ -73,18 +73,27 @@ async fn prune_all_instances(project: &ProjectContext) -> Result<()> { return Ok(()); } - print_status("PRUNE", &format!("Found {} instance(s) to prune", instances.len())); + print_status( + "PRUNE", + &format!("Found {} instance(s) to prune", instances.len()), + ); let runtime = project.config.project.container_runtime; if DockerManager::check_runtime_available(runtime).is_ok() { let docker = DockerManager::new(project); for instance_name in &instances { - print_status("PRUNE", &format!("Removing containers for '{instance_name}'")); + print_status( + "PRUNE", + &format!("Removing containers for '{instance_name}'"), + ); // Remove containers (but not volumes) let _ = docker.prune_instance(instance_name, false); - print_status("PRUNE", &format!("Removing Docker images for '{instance_name}'")); + print_status( + "PRUNE", + &format!("Removing Docker images for '{instance_name}'"), + ); // Remove Docker images let _ = 
docker.remove_instance_images(instance_name); } @@ -95,8 +104,12 @@ async fn prune_all_instances(project: &ProjectContext) -> Result<()> { let workspace = project.instance_workspace(instance_name); if workspace.exists() { match std::fs::remove_dir_all(&workspace) { - Ok(()) => print_status("PRUNE", &format!("Removed workspace for '{instance_name}'")), - Err(e) => print_warning(&format!("Failed to remove workspace for '{instance_name}': {e}")), + Ok(()) => { + print_status("PRUNE", &format!("Removed workspace for '{instance_name}'")) + } + Err(e) => print_warning(&format!( + "Failed to remove workspace for '{instance_name}': {e}" + )), } } } @@ -162,15 +175,18 @@ async fn prune_system_wide() -> Result<()> { if DockerManager::check_runtime_available(runtime).is_ok() { DockerManager::clean_all_helix_images(runtime)?; // Run system prune for this runtime - let output = std::process::Command::new(runtime.binary()) + let output = std::process::Command::new(runtime.binary()) .args(["system", "prune", "-f"]) .output() .map_err(|e| eyre::eyre!("Failed to run {} system prune: {e}", runtime.binary()))?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - print_warning(&format!("{} system prune failed: {stderr}", runtime.label())); - } + print_warning(&format!( + "{} system prune failed: {stderr}", + runtime.label() + )); + } } } diff --git a/helix-cli/src/commands/start.rs b/helix-cli/src/commands/start.rs index 9335f124d..6d5b16e4a 100644 --- a/helix-cli/src/commands/start.rs +++ b/helix-cli/src/commands/start.rs @@ -35,8 +35,12 @@ async fn start_local_instance(project: &ProjectContext, instance_name: &str) -> let compose_file = workspace.join("docker-compose.yml"); if !compose_file.exists() { - let error = crate::errors::CliError::new(format!("instance '{instance_name}' has not been built yet")) - .with_hint(format!("run 'helix build {instance_name}' first to build the instance")); + let error = crate::errors::CliError::new(format!( + "instance '{instance_name}' has not been built yet" + )) + .with_hint(format!( + "run 'helix build {instance_name}' first to build the instance" + )); return Err(eyre::eyre!("{}", error.render())); } diff --git a/helix-cli/src/commands/status.rs b/helix-cli/src/commands/status.rs index fd1b62e33..a6bf65517 100644 --- a/helix-cli/src/commands/status.rs +++ b/helix-cli/src/commands/status.rs @@ -76,7 +76,7 @@ pub async fn run() -> Result<()> { async fn show_container_status(project: &ProjectContext) -> Result<()> { // Check if Docker is available - let runtime = project.config.project.container_runtime ; + let runtime = project.config.project.container_runtime; if DockerManager::check_runtime_available(runtime).is_err() { print_field("Docker Status", "Not available"); return Ok(()); diff --git a/helix-cli/src/config.rs b/helix-cli/src/config.rs index 89c112250..c04681ff2 100644 --- a/helix-cli/src/config.rs +++ b/helix-cli/src/config.rs @@ -106,7 +106,10 @@ pub struct DbConfig { pub bm25: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub schema: Option<String>, - #[serde(default = "default_embedding_model", skip_serializing_if = "is_default_embedding_model")] + #[serde( + default = "default_embedding_model", + skip_serializing_if = "is_default_embedding_model" + )] pub embedding_model: Option<String>, #[serde(default, skip_serializing_if = "Option::is_none")] pub graphvis_node_label: Option<String>, @@ -120,7 +123,6 @@ pub struct LocalInstanceConfig { pub build_mode: BuildMode, #[serde(flatten)] pub db_config: DbConfig, - }
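The `embedding_model` hunk above pairs a named `default` with `skip_serializing_if`, so the generated `helix.toml` omits the field while reads still see a value. A small sketch of that serde round-trip, assuming the `serde` (with derive) and `toml` crates; the model string and field name are placeholders, not the project's actual defaults:

```rust
// Sketch: a field whose default is also skipped on serialization,
// keeping freshly written config files minimal.
use serde::{Deserialize, Serialize};

fn default_model() -> Option<String> {
    Some("example-embedding-model".to_string()) // illustrative default
}

fn is_default_model(v: &Option<String>) -> bool {
    *v == default_model()
}

#[derive(Debug, Serialize, Deserialize)]
struct DbConfig {
    #[serde(default = "default_model", skip_serializing_if = "is_default_model")]
    embedding_model: Option<String>,
}

fn main() {
    // A missing key deserializes to the default...
    let cfg: DbConfig = toml::from_str("").unwrap();
    assert_eq!(cfg.embedding_model, default_model());
    // ...and the default round-trips back to an empty document.
    assert_eq!(toml::to_string(&cfg).unwrap(), "");
}
```

Because the skip predicate compares against the same function that supplies the default, only a user-set value ever lands in the file.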
#[derive(Debug, Clone, Serialize, Deserialize)] @@ -441,7 +443,6 @@ impl HelixConfig { port: Some(6969), build_mode: BuildMode::Debug, db_config: DbConfig::default(), - }, ); diff --git a/helix-cli/src/docker.rs b/helix-cli/src/docker.rs index 54aebc912..a05f5ba0b 100644 --- a/helix-cli/src/docker.rs +++ b/helix-cli/src/docker.rs @@ -3,7 +3,7 @@ //! Despite the module name, this works with both Docker and Podman as they //! share the same CLI interface and support standard Dockerfile formats. -use crate::config::{BuildMode, InstanceInfo,ContainerRuntime}; +use crate::config::{BuildMode, ContainerRuntime, InstanceInfo}; use crate::project::ProjectContext; use crate::utils::{print_confirm, print_status, print_warning}; use eyre::{Result, eyre}; @@ -14,17 +14,17 @@ use std::time::Duration; pub struct DockerManager<'a> { project: &'a ProjectContext, /// The container runtime to use (Docker or Podman) - pub(crate) runtime: ContainerRuntime, + pub(crate) runtime: ContainerRuntime, } - pub const HELIX_DATA_DIR: &str = "/data"; impl<'a> DockerManager<'a> { pub fn new(project: &'a ProjectContext) -> Self { - Self { project , - runtime : project.config.project.container_runtime, - } + Self { + project, + runtime: project.config.project.container_runtime, + } } // === CENTRALIZED NAMING METHODS === @@ -109,7 +109,13 @@ impl<'a> DockerManager<'a> { let output = Command::new(self.runtime.binary()) .args(args) .output() - .map_err(|e| eyre!("Failed to run {} {}: {e}",self.runtime.binary(), args.join(" ")))?; + .map_err(|e| { + eyre!( + "Failed to run {} {}: {e}", + self.runtime.binary(), + args.join(" ") + ) + })?; Ok(output) } @@ -126,7 +132,13 @@ impl<'a> DockerManager<'a> { .args(&full_args) .current_dir(&workspace) .output() - .map_err(|e| eyre!("Failed to run {} compose {}: {e}", self.runtime.binary() , full_args.join(" ")))?; + .map_err(|e| { + eyre!( + "Failed to run {} compose {}: {e}", + self.runtime.binary(), + full_args.join(" ") + ) + })?; Ok(output) } @@ -146,95 +158,95 @@ impl<'a> DockerManager<'a> { } /// Start the container runtime daemon based on platform and runtime -fn start_runtime_daemon(runtime: ContainerRuntime) -> Result<()> { - let platform = Self::detect_platform(); - - match (runtime, platform) { - // Docker on macOS - (ContainerRuntime::Docker, "macos") => { - print_status("DOCKER", "Starting Docker Desktop for macOS..."); - Command::new("open") - .args(["-a", "Docker"]) - .output() - .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?; - } + fn start_runtime_daemon(runtime: ContainerRuntime) -> Result<()> { + let platform = Self::detect_platform(); + + match (runtime, platform) { + // Docker on macOS + (ContainerRuntime::Docker, "macos") => { + print_status("DOCKER", "Starting Docker Desktop for macOS..."); + Command::new("open") + .args(["-a", "Docker"]) + .output() + .map_err(|e| eyre!("Failed to start Docker Desktop: {}", e))?; + } - // Podman on macOS - (ContainerRuntime::Podman, "macos") => { - print_status("PODMAN", "Starting Podman machine on macOS..."); - - // Check if machine exists first - let list_output = Command::new("podman") - .args(["machine", "list", "--format", "{{.Name}}"]) - .output() - .map_err(|e| eyre!("Failed to list Podman machines: {}", e))?; - - let machines = String::from_utf8_lossy(&list_output.stdout); - - if machines.trim().is_empty() { - // No machine exists, initialize one - print_status("PODMAN", "Initializing Podman machine (first time)..."); - let init_output = Command::new("podman") - .args(["machine", "init"]) + // Podman on 
macOS + (ContainerRuntime::Podman, "macos") => { + print_status("PODMAN", "Starting Podman machine on macOS..."); + + // Check if machine exists first + let list_output = Command::new("podman") + .args(["machine", "list", "--format", "{{.Name}}"]) .output() - .map_err(|e| eyre!("Failed to initialize Podman machine: {}", e))?; - - if !init_output.status.success() { - let stderr = String::from_utf8_lossy(&init_output.stderr); - return Err(eyre!("Failed to initialize Podman machine: {}", stderr)); + .map_err(|e| eyre!("Failed to list Podman machines: {}", e))?; + + let machines = String::from_utf8_lossy(&list_output.stdout); + + if machines.trim().is_empty() { + // No machine exists, initialize one + print_status("PODMAN", "Initializing Podman machine (first time)..."); + let init_output = Command::new("podman") + .args(["machine", "init"]) + .output() + .map_err(|e| eyre!("Failed to initialize Podman machine: {}", e))?; + + if !init_output.status.success() { + let stderr = String::from_utf8_lossy(&init_output.stderr); + return Err(eyre!("Failed to initialize Podman machine: {}", stderr)); + } } + + // Start the machine + Command::new("podman") + .args(["machine", "start"]) + .output() + .map_err(|e| eyre!("Failed to start Podman machine: {}", e))?; } - - // Start the machine - Command::new("podman") - .args(["machine", "start"]) - .output() - .map_err(|e| eyre!("Failed to start Podman machine: {}", e))?; - } - // Docker on Linux - (ContainerRuntime::Docker, "linux") => { - print_status("DOCKER", "Attempting to start Docker daemon on Linux..."); - let systemctl_result = Command::new("systemctl").args(["start", "docker"]).output(); + // Docker on Linux + (ContainerRuntime::Docker, "linux") => { + print_status("DOCKER", "Attempting to start Docker daemon on Linux..."); + let systemctl_result = Command::new("systemctl").args(["start", "docker"]).output(); - match systemctl_result { - Ok(output) if output.status.success() => {} - _ => { - let service_result = Command::new("service") - .args(["docker", "start"]) - .output() - .map_err(|e| eyre!("Failed to start Docker daemon: {}", e))?; + match systemctl_result { + Ok(output) if output.status.success() => {} + _ => { + let service_result = Command::new("service") + .args(["docker", "start"]) + .output() + .map_err(|e| eyre!("Failed to start Docker daemon: {}", e))?; - if !service_result.status.success() { - let stderr = String::from_utf8_lossy(&service_result.stderr); - return Err(eyre!("Failed to start Docker daemon: {}", stderr)); + if !service_result.status.success() { + let stderr = String::from_utf8_lossy(&service_result.stderr); + return Err(eyre!("Failed to start Docker daemon: {}", stderr)); + } } } } - } - // Podman on Linux - (ContainerRuntime::Podman, "linux") => { - print_status("PODMAN", "Starting Podman service on Linux..."); - - // Try to start user service (rootless) - let user_service = Command::new("systemctl") - .args(["--user", "start", "podman.socket"]) - .output(); - - if user_service.is_err() || !user_service.unwrap().status.success() { - // Try system service (rootful) as fallback - let system_service = Command::new("systemctl") - .args(["start", "podman.socket"]) + // Podman on Linux + (ContainerRuntime::Podman, "linux") => { + print_status("PODMAN", "Starting Podman service on Linux..."); + + // Try to start user service (rootless) + let user_service = Command::new("systemctl") + .args(["--user", "start", "podman.socket"]) .output(); - - if let Err(e) = system_service { - print_warning(&format!("Could not start Podman 
service: {}", e)); + + if user_service.is_err() || !user_service.unwrap().status.success() { + // Try system service (rootful) as fallback + let system_service = Command::new("systemctl") + .args(["start", "podman.socket"]) + .output(); + + if let Err(e) = system_service { + print_warning(&format!("Could not start Podman service: {}", e)); + } } } - } - // Docker on Windows - (ContainerRuntime::Docker, "windows") => { + // Docker on Windows + (ContainerRuntime::Docker, "windows") => { print_status("DOCKER", "Starting Docker Desktop for Windows..."); // Try Docker Desktop CLI (4.37+) first let cli_result = Command::new("docker").args(["desktop", "start"]).output(); @@ -258,137 +270,146 @@ fn start_runtime_daemon(runtime: ContainerRuntime) -> Result<()> { } } } - - - // Podman on Windows - (ContainerRuntime::Podman, "windows") => { - print_status("PODMAN", "Starting Podman machine on Windows..."); - - // Check if machine exists - let list_output = Command::new("podman") - .args(["machine", "list", "--format", "{{.Name}}"]) - .output() - .map_err(|e| eyre!("Failed to list Podman machines: {}", e))?; - - let machines = String::from_utf8_lossy(&list_output.stdout); - - if machines.trim().is_empty() { - // Initialize machine first - print_status("PODMAN", "Initializing Podman machine (first time)..."); - let init_output = Command::new("podman") - .args(["machine", "init"]) + + // Podman on Windows + (ContainerRuntime::Podman, "windows") => { + print_status("PODMAN", "Starting Podman machine on Windows..."); + + // Check if machine exists + let list_output = Command::new("podman") + .args(["machine", "list", "--format", "{{.Name}}"]) .output() - .map_err(|e| eyre!("Failed to initialize Podman machine: {}", e))?; - - if !init_output.status.success() { - let stderr = String::from_utf8_lossy(&init_output.stderr); - return Err(eyre!("Failed to initialize Podman machine: {}", stderr)); + .map_err(|e| eyre!("Failed to list Podman machines: {}", e))?; + + let machines = String::from_utf8_lossy(&list_output.stdout); + + if machines.trim().is_empty() { + // Initialize machine first + print_status("PODMAN", "Initializing Podman machine (first time)..."); + let init_output = Command::new("podman") + .args(["machine", "init"]) + .output() + .map_err(|e| eyre!("Failed to initialize Podman machine: {}", e))?; + + if !init_output.status.success() { + let stderr = String::from_utf8_lossy(&init_output.stderr); + return Err(eyre!("Failed to initialize Podman machine: {}", stderr)); + } + } + + // Start the machine + let start_output = Command::new("podman") + .args(["machine", "start"]) + .output() + .map_err(|e| eyre!("Failed to start Podman machine: {}", e))?; + + if !start_output.status.success() { + let stderr = String::from_utf8_lossy(&start_output.stderr); + return Err(eyre!("Failed to start Podman machine: {}", stderr)); } } - - // Start the machine - let start_output = Command::new("podman") - .args(["machine", "start"]) - .output() - .map_err(|e| eyre!("Failed to start Podman machine: {}", e))?; - - if !start_output.status.success() { - let stderr = String::from_utf8_lossy(&start_output.stderr); - return Err(eyre!("Failed to start Podman machine: {}", stderr)); + + (_, platform) => { + return Err(eyre!( + "Unsupported platform '{}' for auto-starting {}", + platform, + runtime.label() + )); } } - (_, platform) => { - return Err(eyre!("Unsupported platform '{}' for auto-starting {}", platform, runtime.label())); - } + Ok(()) } - Ok(()) -} + fn wait_for_runtime(runtime: ContainerRuntime, timeout_secs: u64) 
-> Result<()> { + print_status(runtime.label(), "Waiting for daemon to start..."); -fn wait_for_runtime(runtime: ContainerRuntime, timeout_secs: u64) -> Result<()> { - print_status(runtime.label(), "Waiting for daemon to start..."); + let start = std::time::Instant::now(); + let timeout = Duration::from_secs(timeout_secs); - let start = std::time::Instant::now(); - let timeout = Duration::from_secs(timeout_secs); + while start.elapsed() < timeout { + let output = Command::new(runtime.binary()).args(["info"]).output(); - while start.elapsed() < timeout { - let output = Command::new(runtime.binary()).args(["info"]).output(); + if let Ok(output) = output + && output.status.success() + { + print_status(runtime.label(), "Daemon is now running"); + return Ok(()); + } - if let Ok(output) = output && output.status.success() { - print_status(runtime.label(), "Daemon is now running"); - return Ok(()); + thread::sleep(Duration::from_millis(500)); } - thread::sleep(Duration::from_millis(500)); + Err(eyre!( + "Timeout waiting for {} daemon to start. Please start {} manually and try again.", + runtime.label(), + runtime.binary() + )) } - Err(eyre!( - "Timeout waiting for {} daemon to start. Please start {} manually and try again.", - runtime.label(), - runtime.binary() - )) -} - /// Check if container runtime is installed and running, with auto-start option -pub fn check_runtime_available(runtime: ContainerRuntime) -> Result<()> { - let cmd = runtime.binary(); - - let output = Command::new(cmd) - .args(["--version"]) - .output() - .map_err(|_| eyre!("{} is not installed or not available in PATH", cmd))?; - - if !output.status.success() { - return Err(eyre!("{} is installed but not working properly", cmd)); - } + pub fn check_runtime_available(runtime: ContainerRuntime) -> Result<()> { + let cmd = runtime.binary(); - // Check if daemon is running - let output = Command::new(cmd) - .args(["info"]) - .output() - .map_err(|_| eyre!("Failed to check {} daemon status", cmd))?; + let output = Command::new(cmd) + .args(["--version"]) + .output() + .map_err(|_| eyre!("{} is not installed or not available in PATH", cmd))?; - if !output.status.success() { - // Daemon not running - ask user if they want to start it - let message = format!( - "{} daemon is not running. Would you like to start {}?", - runtime.label(), - runtime.binary() - ); - let should_start = print_confirm(&message).unwrap_or(false); + if !output.status.success() { + return Err(eyre!("{} is installed but not working properly", cmd)); + } - if should_start { - // Try to start the runtime - Self::start_runtime_daemon(runtime)?; + // Check if daemon is running + let output = Command::new(cmd) + .args(["info"]) + .output() + .map_err(|_| eyre!("Failed to check {} daemon status", cmd))?; - // Wait for it to be ready - Self::wait_for_runtime(runtime, 15)?; + if !output.status.success() { + // Daemon not running - ask user if they want to start it + let message = format!( + "{} daemon is not running. 
Would you like to start {}?", + runtime.label(), + runtime.binary() + ); + let should_start = print_confirm(&message).unwrap_or(false); - // Verify it's running now - let verify_output = Command::new(cmd) - .args(["info"]) - .output() - .map_err(|_| eyre!("Failed to verify {} daemon status", cmd))?; + if should_start { + // Try to start the runtime + Self::start_runtime_daemon(runtime)?; - if !verify_output.status.success() { + // Wait for it to be ready + Self::wait_for_runtime(runtime, 15)?; + + // Verify it's running now + let verify_output = Command::new(cmd) + .args(["info"]) + .output() + .map_err(|_| eyre!("Failed to verify {} daemon status", cmd))?; + + if !verify_output.status.success() { + return Err(eyre!( + "{} daemon failed to start. Please start {} manually and try again.", + runtime.label(), + cmd + )); + } + } else { + print_warning(&format!( + "{} daemon must be running to execute this command.", + runtime.label() + )); return Err(eyre!( - "{} daemon failed to start. Please start {} manually and try again.", - runtime.label(), + "{} daemon is not running. Please start {}.", + cmd, cmd )); } - } else { - print_warning(&format!( - "{} daemon must be running to execute this command.", - runtime.label() - )); - return Err(eyre!("{} daemon is not running. Please start {}.", cmd, cmd)); } - } - Ok(()) -} + Ok(()) + } /// Generate Dockerfile for an instance pub fn generate_dockerfile( @@ -530,7 +551,7 @@ networks: if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - return Err(eyre!("{} build failed:\n{stderr}",self.runtime.binary())); + return Err(eyre!("{} build failed:\n{stderr}", self.runtime.binary())); } print_status(self.runtime.label(), "Image built successfully"); @@ -539,7 +560,10 @@ networks: /// Start instance using docker/podman compose pub fn start_instance(&self, instance_name: &str) -> Result<()> { - print_status(self.runtime.label(), &format!("Starting instance '{instance_name}'...")); + print_status( + self.runtime.label(), + &format!("Starting instance '{instance_name}'..."), + ); let output = self.run_compose_command(instance_name, vec!["up", "-d"])?; @@ -557,7 +581,10 @@ networks: /// Stop instance using docker/podman compose pub fn stop_instance(&self, instance_name: &str) -> Result<()> { - print_status(self.runtime.label(), &format!("Stopping instance '{instance_name}'...")); + print_status( + self.runtime.label(), + &format!("Stopping instance '{instance_name}'..."), + ); let output = self.run_compose_command(instance_name, vec!["down"])?; @@ -642,7 +669,10 @@ networks: /// Remove instance containers and optionally volumes pub fn prune_instance(&self, instance_name: &str, remove_volumes: bool) -> Result<()> { - print_status(self.runtime.label(), &format!("Pruning instance '{instance_name}'...")); + print_status( + self.runtime.label(), + &format!("Pruning instance '{instance_name}'..."), + ); // Check if workspace exists - if not, there's nothing to prune let workspace = self.project.instance_workspace(instance_name); @@ -707,7 +737,7 @@ networks: let _ = self.run_docker_command(&["volume", "rm", &volume_to_remove]); } - Ok(()) + Ok(()) } /// Remove Docker/Podman images associated with an instance @@ -743,7 +773,7 @@ networks: "reference=helix-*", ]) .output() - .map_err(|e| eyre!("Failed to list {} images: {e}",runtime.binary()))?; + .map_err(|e| eyre!("Failed to list {} images: {e}", runtime.binary()))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); @@ -772,7 +802,10 @@ networks: } 
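All of the process spawns in this file now route through `runtime.binary()` and report via `runtime.label()` instead of hard-coding "docker". A compact sketch of that indirection; the enum body below is inferred from how those two methods are called in the diff, not copied from the crate:

```rust
// Sketch: one enum supplies the binary name and display label, and every
// spawn goes through it, so Docker and Podman share the same code paths.
use std::process::Command;

#[derive(Clone, Copy)]
enum ContainerRuntime {
    Docker,
    Podman,
}

impl ContainerRuntime {
    fn binary(self) -> &'static str {
        match self {
            ContainerRuntime::Docker => "docker",
            ContainerRuntime::Podman => "podman",
        }
    }

    fn label(self) -> &'static str {
        match self {
            ContainerRuntime::Docker => "DOCKER",
            ContainerRuntime::Podman => "PODMAN",
        }
    }
}

// Rough equivalent of check_runtime_available's first step: does the
// binary exist and respond to --version?
fn version_check(rt: ContainerRuntime) -> Result<(), String> {
    let out = Command::new(rt.binary())
        .arg("--version")
        .output()
        .map_err(|_| format!("{} not found in PATH", rt.binary()))?;
    if out.status.success() {
        Ok(())
    } else {
        Err(format!("{} installed but not working", rt.label()))
    }
}

fn main() {
    for rt in [ContainerRuntime::Docker, ContainerRuntime::Podman] {
        match version_check(rt) {
            Ok(()) => println!("{} available", rt.label()),
            Err(e) => println!("{e}"),
        }
    }
}
```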
let count = images.len(); - print_status(runtime.label(), &format!("Found {count} Helix images to remove")); + print_status( + runtime.label(), + &format!("Found {count} Helix images to remove"), + ); for image in images { let output = Command::new(runtime.binary()) @@ -807,7 +840,10 @@ networks: pub fn push(&self, image_name: &str, registry_url: &str) -> Result<()> { let registry_image = format!("{registry_url}/{image_name}"); - print_status(self.runtime.label(), &format!("Pushing image: {registry_image}")); + print_status( + self.runtime.label(), + &format!("Pushing image: {registry_image}"), + ); let output = Command::new(self.runtime.binary()) .arg("push") .arg(&registry_image) From f8e4442c8f5e501cae13657796ba8595a06cfdec Mon Sep 17 00:00:00 2001 From: pantShrey <121197985+pantShrey@users.noreply.github.com> Date: Thu, 20 Nov 2025 01:54:05 +0530 Subject: [PATCH 08/48] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- helix-cli/src/commands/prune.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helix-cli/src/commands/prune.rs b/helix-cli/src/commands/prune.rs index fca1f3b90..1dfcb890a 100644 --- a/helix-cli/src/commands/prune.rs +++ b/helix-cli/src/commands/prune.rs @@ -132,7 +132,7 @@ async fn prune_unused_resources(project: &ProjectContext) -> Result<()> { print_newline(); let runtime = project.config.project.container_runtime; // Check Docker availability - print_status("PRUNE", "Checking Docker availability"); + print_status("PRUNE", "Checking container runtime availability"); DockerManager::check_runtime_available(runtime)?; print_status("PRUNE", "Running Docker system cleanup"); From 2dae43f1303af7cb8fdf2afbc07b61d5dbc4d586 Mon Sep 17 00:00:00 2001 From: Shrey Pant Date: Thu, 20 Nov 2025 02:13:58 +0530 Subject: [PATCH 09/48] remove unwrap and status fix --- helix-cli/src/commands/status.rs | 2 +- helix-cli/src/docker.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/helix-cli/src/commands/status.rs b/helix-cli/src/commands/status.rs index a6bf65517..d1ad8759c 100644 --- a/helix-cli/src/commands/status.rs +++ b/helix-cli/src/commands/status.rs @@ -78,7 +78,7 @@ async fn show_container_status(project: &ProjectContext) -> Result<()> { // Check if Docker is available let runtime = project.config.project.container_runtime; if DockerManager::check_runtime_available(runtime).is_err() { - print_field("Docker Status", "Not available"); + print_field(&format!("{} Status", runtime.label()), "Not available"); return Ok(()); } diff --git a/helix-cli/src/docker.rs b/helix-cli/src/docker.rs index a05f5ba0b..007a6f902 100644 --- a/helix-cli/src/docker.rs +++ b/helix-cli/src/docker.rs @@ -234,7 +234,8 @@ impl<'a> DockerManager<'a> { .args(["--user", "start", "podman.socket"]) .output(); - if user_service.is_err() || !user_service.unwrap().status.success() { + // Only skip fallback if command succeeded AND status is success + if !user_service.is_ok_and(|output| output.status.success()) { // Try system service (rootful) as fallback let system_service = Command::new("systemctl") .args(["start", "podman.socket"]) .output(); From eb58bba6358fd1f6df633037b2239eb906e0a2db Mon Sep 17 00:00:00 2001 From: xav-db Date: Wed, 19 Nov 2025 13:22:16 -0800 Subject: [PATCH 10/48] fixing issues with exists --- .../analyzer/methods/infer_expr_type.rs | 38 +++++++++------ helix-db/src/helixc/generator/bool_ops.rs | 17 ++++++- hql-tests/test.sh | 4 +- .../tests/cloud_queries_2/config.hx.json | 14 ++++++
hql-tests/tests/cloud_queries_2/helix.toml | 9 ++++ hql-tests/tests/cloud_queries_2/queries.hx | 5 ++ hql-tests/tests/cloud_queries_2/schema.hx | 46 +++++++++++++++++++ 7 files changed, 114 insertions(+), 19 deletions(-) create mode 100644 hql-tests/tests/cloud_queries_2/config.hx.json create mode 100644 hql-tests/tests/cloud_queries_2/helix.toml create mode 100644 hql-tests/tests/cloud_queries_2/queries.hx create mode 100644 hql-tests/tests/cloud_queries_2/schema.hx diff --git a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs index 93ea9fe20..312b531cf 100644 --- a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs +++ b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs @@ -1274,21 +1274,29 @@ pub(crate) fn infer_expr_type<'a>( assert!(matches!(stmt, Some(GeneratedStatement::Traversal(_)))); let traversal = match stmt.unwrap() { GeneratedStatement::Traversal(mut tr) => { - let source_variable = match tr.source_step.inner() { - SourceStep::Identifier(id) => id.inner().clone(), - _ => DEFAULT_VAR_NAME.to_string(), - }; - // Check if the variable is single or plural to determine traversal type - let is_single = scope - .get(source_variable.as_str()) - .map(|var_info| var_info.is_single) - .unwrap_or(false); - - tr.traversal_type = if is_single { - TraversalType::FromSingle(GenRef::Std(source_variable)) - } else { - TraversalType::FromIter(GenRef::Std(source_variable)) - }; + // Only modify traversal_type if source is Identifier or Anonymous + match tr.source_step.inner() { + SourceStep::Identifier(id) => { + let source_variable = id.inner().clone(); + // Check if the variable is single or plural to determine traversal type + let is_single = scope + .get(source_variable.as_str()) + .map(|var_info| var_info.is_single) + .unwrap_or(false); + + tr.traversal_type = if is_single { + TraversalType::FromSingle(GenRef::Std(source_variable)) + } else { + TraversalType::FromIter(GenRef::Std(source_variable)) + }; + } + SourceStep::Anonymous => { + tr.traversal_type = TraversalType::FromSingle(GenRef::Std(DEFAULT_VAR_NAME.to_string())); + } + _ => { + // For AddN, AddV, AddE, SearchVector, etc., leave traversal_type unchanged (Ref) + } + } tr.should_collect = ShouldCollect::No; tr } diff --git a/helix-db/src/helixc/generator/bool_ops.rs b/helix-db/src/helixc/generator/bool_ops.rs index c5b61d469..62f84370f 100644 --- a/helix-db/src/helixc/generator/bool_ops.rs +++ b/helix-db/src/helixc/generator/bool_ops.rs @@ -1,7 +1,10 @@ use core::fmt; use std::fmt::Display; -use crate::helixc::generator::traversal_steps::{Step, Traversal, TraversalType}; +use crate::helixc::generator::{ + source_steps::SourceStep, + traversal_steps::{Step, Traversal, TraversalType}, +}; use super::utils::{GenRef, GeneratedValue, Separator}; @@ -155,13 +158,23 @@ impl Display for BoExp { } BoExp::Exists(traversal) => { // Optimize Exists expressions in filter context to use std::iter::once for single values + println!("Optimizing Exists expression"); + println!("{:?}", traversal.traversal_type); + println!("{:?}", traversal.source_step); let is_val_traversal = match &traversal.traversal_type { TraversalType::FromIter(var) | TraversalType::FromSingle(var) => match var { - GenRef::Std(s) | GenRef::Literal(s) => s == "val", + GenRef::Std(s) | GenRef::Literal(s) => { + s == "val" + && matches!( + traversal.source_step.inner(), + SourceStep::Identifier(_) | SourceStep::Anonymous + ) + } _ => false, }, _ => false, }; + println!("is_val_traversal: {}", 
is_val_traversal); if is_val_traversal { // Create a modified traversal that uses FromSingle instead of FromIter diff --git a/hql-tests/test.sh b/hql-tests/test.sh index 1f4d150cd..588a41d2e 100644 --- a/hql-tests/test.sh +++ b/hql-tests/test.sh @@ -11,8 +11,8 @@ fi file_name=$1 -helix compile --path "/Users/xav/GitHub/helix-db-core/hql-tests/tests/$file_name" --output "/Users/xav/GitHub/helix-db-core/helix-container/src" -output=$(cargo check --manifest-path "/Users/xav/GitHub/helix-db-core/helix-container/Cargo.toml") +helix compile --path "/Users/xav/GitHub/helix-db/hql-tests/tests/$file_name" --output "/Users/xav/GitHub/helix-db/helix-container/src" +output=$(cargo check --manifest-path "/Users/xav/GitHub/helix-db/helix-container/Cargo.toml") if [ $? -ne 0 ]; then echo "Error: Cargo check failed" echo "Cargo check output: $output" diff --git a/hql-tests/tests/cloud_queries_2/config.hx.json b/hql-tests/tests/cloud_queries_2/config.hx.json new file mode 100644 index 000000000..117795e49 --- /dev/null +++ b/hql-tests/tests/cloud_queries_2/config.hx.json @@ -0,0 +1,14 @@ +{ + "vector_config": { + "m": 16, + "ef_construction": 128, + "ef_search": 768, + "db_max_size": 20 + }, + "graph_config": { + "secondary_indices": [] + }, + "db_max_size_gb": 20, + "mcp": true, + "bm25": true +} diff --git a/hql-tests/tests/cloud_queries_2/helix.toml b/hql-tests/tests/cloud_queries_2/helix.toml new file mode 100644 index 000000000..321cf1fff --- /dev/null +++ b/hql-tests/tests/cloud_queries_2/helix.toml @@ -0,0 +1,9 @@ +[project] +name = "cloud_queries_2" +queries = "." + +[local.dev] +port = 6969 +build_mode = "debug" + +[cloud] diff --git a/hql-tests/tests/cloud_queries_2/queries.hx b/hql-tests/tests/cloud_queries_2/queries.hx new file mode 100644 index 000000000..ae60266bb --- /dev/null +++ b/hql-tests/tests/cloud_queries_2/queries.hx @@ -0,0 +1,5 @@ + + +QUERY ExistsUserByGithubId(github_id: U64) => + user_exists <- EXISTS(N({ github_id: github_id })) + RETURN user_exists diff --git a/hql-tests/tests/cloud_queries_2/schema.hx b/hql-tests/tests/cloud_queries_2/schema.hx new file mode 100644 index 000000000..8b53fea99 --- /dev/null +++ b/hql-tests/tests/cloud_queries_2/schema.hx @@ -0,0 +1,46 @@ +N::User { + INDEX github_id: U64, + github_login: String, + github_name: String DEFAULT "", + github_email: String DEFAULT "", + created_at: Date DEFAULT NOW, + updated_at: Date DEFAULT NOW, +} + +N::Cluster { + INDEX railway_project_id: String, + project_name: String, + railway_region: String DEFAULT "us-east4-eqdc4a", + db_url: String DEFAULT "", + created_at: Date DEFAULT NOW, + updated_at: Date DEFAULT NOW, +} + +N::Instance { + railway_service_id: String, + railway_environment_id: String, + instance_type: String, + storage_gb: U64, + ram_gb: U64, + created_at: Date DEFAULT NOW, + updated_at: Date DEFAULT NOW, +} + +E::CreatedCluster { + From: User, + To: Cluster, +} + +E::HasInstance { + From: Cluster, + To: Instance, +} + +N::ApiKey { + unkey_key_id: String, +} + +E::CreatedApiKey { + From: User, + To: ApiKey, +} \ No newline at end of file From 2ed62b2d4ad670d933abeeaff8aed0b4d2628e67 Mon Sep 17 00:00:00 2001 From: xav-db Date: Wed, 19 Nov 2025 14:40:19 -0800 Subject: [PATCH 11/48] Invalid variable name not being errored by ::To ::From when adding edge BUG-1 --- .../analyzer/methods/infer_expr_type.rs | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs index 
312b531cf..3dd8182da 100644 --- a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs +++ b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs @@ -588,6 +588,16 @@ pub(crate) fn infer_expr_type<'a>( Some(id) => match id { IdType::Identifier { value, loc } => { is_valid_identifier(ctx, original_query, loc.clone(), value.as_str()); + // Validate that the identifier exists in scope or is a parameter + if !scope.contains_key(value.as_str()) && is_param(original_query, value.as_str()).is_none() { + generate_error!( + ctx, + original_query, + loc.clone(), + E301, + value.as_str() + ); + } // Check if this variable is plural let is_plural = scope .get(value.as_str()) @@ -626,6 +636,16 @@ pub(crate) fn infer_expr_type<'a>( Some(id) => match id { IdType::Identifier { value, loc } => { is_valid_identifier(ctx, original_query, loc.clone(), value.as_str()); + // Validate that the identifier exists in scope or is a parameter + if !scope.contains_key(value.as_str()) && is_param(original_query, value.as_str()).is_none() { + generate_error!( + ctx, + original_query, + loc.clone(), + E301, + value.as_str() + ); + } // Check if this variable is plural let is_plural = scope .get(value.as_str()) From 66f8e5ba97e21379699aa085006799a282582430 Mon Sep 17 00:00:00 2001 From: xav-db Date: Wed, 19 Nov 2025 14:55:25 -0800 Subject: [PATCH 12/48] fixing panics when fields don't match schema when adding --- .../analyzer/methods/infer_expr_type.rs | 606 ++++++++++++------ 1 file changed, 395 insertions(+), 211 deletions(-) diff --git a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs index 3dd8182da..278c20be7 100644 --- a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs +++ b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs @@ -212,57 +212,68 @@ pub(crate) fn infer_expr_type<'a>( value.as_str() ); } else { - - let variable_type = - &scope.get(value.as_str()).unwrap().ty; - if variable_type - != &Type::from( - field_set - .get(field_name.as_str()) - .unwrap() - .field_type - .clone(), - ) - { - generate_error!( - ctx, - original_query, - loc.clone(), - E205, - value.as_str(), - &variable_type.to_string(), - &field_set - .get(field_name.as_str()) - .unwrap() - .field_type - .to_string(), - "node", - ty.as_str() - ); + // Variable is in scope, now validate field type + match scope.get(value.as_str()) { + Some(var_info) => { + match field_set.get(field_name.as_str()) { + Some(field) => { + let variable_type = &var_info.ty; + if variable_type + != &Type::from( + field.field_type.clone(), + ) + { + generate_error!( + ctx, + original_query, + loc.clone(), + E205, + value.as_str(), + &variable_type.to_string(), + &field.field_type.to_string(), + "node", + ty.as_str() + ); + } + } + None => { + // Field doesn't exist - error already generated above + // Skip further validation to prevent panic + } + } + } + None => { + // Variable not in scope - error already generated above + } } } } ValueType::Literal { value, loc } => { - let field_type = ctx - .node_fields - .get(ty.as_str()) - .unwrap() - .get(field_name.as_str()) - .unwrap() - .field_type - .clone(); - if field_type != *value { - generate_error!( - ctx, - original_query, - loc.clone(), - E205, - &value.inner_stringify(), - value.to_variant_string(), - &field_type.to_string(), - "node", - ty.as_str() - ); + match ctx.node_fields.get(ty.as_str()) { + Some(fields) => match fields.get(field_name.as_str()) { + Some(field) => { + let field_type = field.field_type.clone(); + if field_type != 
*value { + generate_error!( + ctx, + original_query, + loc.clone(), + E205, + &value.inner_stringify(), + value.to_variant_string(), + &field_type.to_string(), + "node", + ty.as_str() + ); + } + } + None => { + // Field doesn't exist - error already generated above + } + }, + None => { + // Type doesn't exist - error already generated above + } } } _ => {} @@ -276,33 +287,42 @@ pub(crate) fn infer_expr_type<'a>( field_name.clone(), match value { ValueType::Literal { value, loc } => { - match ctx - .node_fields - .get(ty.as_str()) - .unwrap() - .get(field_name.as_str()) - .unwrap() - .field_type - == FieldType::Date - { - true => match Date::new(value) { - Ok(date) => GeneratedValue::Literal( - GenRef::Literal(date.to_rfc3339()), - ), - Err(_) => { - generate_error!( - ctx, - original_query, - loc.clone(), - E501, - value.as_str() - ); + match ctx.node_fields.get(ty.as_str()) { + Some(fields) => match fields.get(field_name.as_str()) + { + Some(field) => { + match field.field_type == FieldType::Date { + true => match Date::new(value) { + Ok(date) => GeneratedValue::Literal( + GenRef::Literal( + date.to_rfc3339(), + ), + ), + Err(_) => { + generate_error!( + ctx, + original_query, + loc.clone(), + E501, + value.as_str() + ); + GeneratedValue::Unknown + } + }, + false => GeneratedValue::Literal( + GenRef::from(value.clone()), + ), + } + } + None => { + // Field doesn't exist - error already generated GeneratedValue::Unknown } }, - false => GeneratedValue::Literal(GenRef::from( - value.clone(), - )), + None => { + // Type doesn't exist - error already generated + GeneratedValue::Unknown + } } } ValueType::Identifier { value, .. } => { @@ -446,57 +466,68 @@ pub(crate) fn infer_expr_type<'a>( value.as_str() ); } else { - let variable_type = - &scope.get(value.as_str()).unwrap().ty; - if variable_type - != &Type::from( - field_set - .get(field_name.as_str()) - .unwrap() - .field_type - .clone(), - ) - { - generate_error!( - ctx, - original_query, - loc.clone(), - E205, - value.as_str(), - &variable_type.to_string(), - &field_set - .get(field_name.as_str()) - .unwrap() - .field_type - .to_string(), - "edge", - ty.as_str() - ); + // Variable is in scope, now validate field type + match scope.get(value.as_str()) { + Some(var_info) => { + match field_set.get(field_name.as_str()) { + Some(field) => { + let variable_type = &var_info.ty; + if variable_type + != &Type::from( + field.field_type.clone(), + ) + { + generate_error!( + ctx, + original_query, + loc.clone(), + E205, + value.as_str(), + &variable_type.to_string(), + &field.field_type.to_string(), + "edge", + ty.as_str() + ); + } + } + None => { + // Field doesn't exist - error already generated above + } + } + } + None => { + // Variable not in scope - error already generated above + } } } } ValueType::Literal { value, loc } => { // check against type - let field_type = ctx - .edge_fields - .get(ty.as_str()) - .unwrap() - .get(field_name.as_str()) - .unwrap() - .field_type - .clone(); - if field_type != *value { - generate_error!( - ctx, - original_query, - loc.clone(), - E205, - &value.inner_stringify(), - value.to_variant_string(), - &field_type.to_string(), - "edge", - ty.as_str() - ); + match ctx.edge_fields.get(ty.as_str()) { + Some(fields) => match fields.get(field_name.as_str()) { + Some(field) => { + let field_type = field.field_type.clone(); + if field_type != *value { + generate_error!( + ctx, + original_query, + loc.clone(), + E205, + &value.inner_stringify(), + value.to_variant_string(), + &field_type.to_string(), + "edge", + 
ty.as_str() + ); + } + } + None => { + // Field doesn't exist - error already generated above + } + }, + None => { + // Type doesn't exist - error already generated above + } } } _ => {} @@ -510,33 +541,42 @@ pub(crate) fn infer_expr_type<'a>( field_name.clone(), match value { ValueType::Literal { value, loc } => { - match ctx - .edge_fields - .get(ty.as_str()) - .unwrap() - .get(field_name.as_str()) - .unwrap() - .field_type - == FieldType::Date - { - true => match Date::new(value) { - Ok(date) => GeneratedValue::Literal( - GenRef::Literal(date.to_rfc3339()), - ), - Err(_) => { - generate_error!( - ctx, - original_query, - loc.clone(), - E501, - value.as_str() - ); + match ctx.edge_fields.get(ty.as_str()) { + Some(fields) => match fields.get(field_name.as_str()) + { + Some(field) => { + match field.field_type == FieldType::Date { + true => match Date::new(value) { + Ok(date) => GeneratedValue::Literal( + GenRef::Literal( + date.to_rfc3339(), + ), + ), + Err(_) => { + generate_error!( + ctx, + original_query, + loc.clone(), + E501, + value.as_str() + ); + GeneratedValue::Unknown + } + }, + false => GeneratedValue::Literal( + GenRef::from(value.clone()), + ), + } + } + None => { + // Field doesn't exist - error already generated GeneratedValue::Unknown } }, - false => GeneratedValue::Literal(GenRef::from( - value.clone(), - )), + None => { + // Type doesn't exist - error already generated + GeneratedValue::Unknown + } } } ValueType::Identifier { value, loc } => { @@ -781,57 +821,68 @@ pub(crate) fn infer_expr_type<'a>( value.as_str() ); } else { - let variable_type = - &scope.get(value.as_str()).unwrap().ty; - if variable_type - != &Type::from( - field_set - .get(field_name.as_str()) - .unwrap() - .field_type - .clone(), - ) - { - generate_error!( - ctx, - original_query, - loc.clone(), - E205, - value.as_str(), - &variable_type.to_string(), - &field_set - .get(field_name.as_str()) - .unwrap() - .field_type - .to_string(), - "vector", - ty.as_str() - ); + // Variable is in scope, now validate field type + match scope.get(value.as_str()) { + Some(var_info) => { + match field_set.get(field_name.as_str()) { + Some(field) => { + let variable_type = &var_info.ty; + if variable_type + != &Type::from( + field.field_type.clone(), + ) + { + generate_error!( + ctx, + original_query, + loc.clone(), + E205, + value.as_str(), + &variable_type.to_string(), + &field.field_type.to_string(), + "vector", + ty.as_str() + ); + } + } + None => { + // Field doesn't exist - error already generated above + } + } + } + None => { + // Variable not in scope - error already generated above + } } } } ValueType::Literal { value, loc } => { // check against type - let field_type = ctx - .vector_fields - .get(ty.as_str()) - .unwrap() - .get(field_name.as_str()) - .unwrap() - .field_type - .clone(); - if field_type != *value { - generate_error!( - ctx, - original_query, - loc.clone(), - E205, - value.as_str(), - &value.to_variant_string(), - &field_type.to_string(), - "vector", - ty.as_str() - ); + match ctx.vector_fields.get(ty.as_str()) { + Some(fields) => match fields.get(field_name.as_str()) { + Some(field) => { + let field_type = field.field_type.clone(); + if field_type != *value { + generate_error!( + ctx, + original_query, + loc.clone(), + E205, + value.as_str(), + &value.to_variant_string(), + &field_type.to_string(), + "vector", + ty.as_str() + ); + } + } + None => { + // Field doesn't exist - error already generated above + } + }, + None => { + // Type doesn't exist - error already generated above + } } } _ => 
{} @@ -845,33 +896,42 @@ pub(crate) fn infer_expr_type<'a>( field_name.clone(), match value { ValueType::Literal { value, loc } => { - match ctx - .vector_fields - .get(ty.as_str()) - .unwrap() - .get(field_name.as_str()) - .unwrap() - .field_type - == FieldType::Date - { - true => match Date::new(value) { - Ok(date) => GeneratedValue::Literal( - GenRef::Literal(date.to_rfc3339()), - ), - Err(_) => { - generate_error!( - ctx, - original_query, - loc.clone(), - E501, - value.as_str() - ); + match ctx.vector_fields.get(ty.as_str()) { + Some(fields) => match fields.get(field_name.as_str()) + { + Some(field) => { + match field.field_type == FieldType::Date { + true => match Date::new(value) { + Ok(date) => GeneratedValue::Literal( + GenRef::Literal( + date.to_rfc3339(), + ), + ), + Err(_) => { + generate_error!( + ctx, + original_query, + loc.clone(), + E501, + value.as_str() + ); + GeneratedValue::Unknown + } + }, + false => GeneratedValue::Literal( + GenRef::from(value.clone()), + ), + } + } + None => { + // Field doesn't exist - error already generated GeneratedValue::Unknown } }, - false => GeneratedValue::Literal(GenRef::from( - value.clone(), - )), + None => { + // Type doesn't exist - error already generated + GeneratedValue::Unknown + } } } ValueType::Identifier { value, loc } => { @@ -1661,4 +1721,128 @@ mod tests { let (diagnostics, _) = result.unwrap(); assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E205)); } + + // ============================================================================ + // Invalid Field Name Tests (E202) + // ============================================================================ + + #[test] + fn test_add_node_invalid_field_name() { + let source = r#" + N::Person { name: String, age: U32 } + + QUERY test() => + person <- AddN({name: "Alice", invalidField: 42}) + RETURN person + "#; + + let content = write_to_temp_file(vec![source]); + let parsed = HelixParser::parse_source(&content).unwrap(); + let result = crate::helixc::analyzer::analyze(&parsed); + + assert!(result.is_ok()); + let (diagnostics, _) = result.unwrap(); + assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E202)); + } + + #[test] + fn test_add_node_invalid_field_with_identifier() { + let source = r#" + N::Person { name: String, age: U32 } + + QUERY test(value: U32) => + person <- AddN({name: "Alice", wrongField: value}) + RETURN person + "#; + + let content = write_to_temp_file(vec![source]); + let parsed = HelixParser::parse_source(&content).unwrap(); + let result = crate::helixc::analyzer::analyze(&parsed); + + assert!(result.is_ok()); + let (diagnostics, _) = result.unwrap(); + assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E202)); + } + + #[test] + fn test_add_edge_invalid_field_name() { + let source = r#" + N::Person { name: String } + E::Knows { From: Person, To: Person, Properties: { since: U32 } } + + QUERY test(id1: ID, id2: ID) => + person1 <- N(id1) + person2 <- N(id2) + edge <- AddE({since: 2020, badField: 123})::From(person1)::To(person2) + RETURN edge + "#; + + let content = write_to_temp_file(vec![source]); + let parsed = HelixParser::parse_source(&content).unwrap(); + let result = crate::helixc::analyzer::analyze(&parsed); + + assert!(result.is_ok()); + let (diagnostics, _) = result.unwrap(); + assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E202)); + } + + #[test] + fn test_add_edge_invalid_field_with_identifier() { + let source = r#" + N::Person { name: String } + E::Knows { From: Person, To: Person, Properties: { 
since: U32 } } + + QUERY test(id1: ID, id2: ID, year: U32) => + person1 <- N(id1) + person2 <- N(id2) + edge <- AddE({invalidField: year})::From(person1)::To(person2) + RETURN edge + "#; + + let content = write_to_temp_file(vec![source]); + let parsed = HelixParser::parse_source(&content).unwrap(); + let result = crate::helixc::analyzer::analyze(&parsed); + + assert!(result.is_ok()); + let (diagnostics, _) = result.unwrap(); + assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E202)); + } + + #[test] + fn test_add_vector_invalid_field_name() { + let source = r#" + V::Document { content: String, embedding: [F32] } + + QUERY test(vec: [F32]) => + doc <- AddV(vec, {content: "test", wrongField: "bad"}) + RETURN doc + "#; + + let content = write_to_temp_file(vec![source]); + let parsed = HelixParser::parse_source(&content).unwrap(); + let result = crate::helixc::analyzer::analyze(&parsed); + + assert!(result.is_ok()); + let (diagnostics, _) = result.unwrap(); + assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E202)); + } + + #[test] + fn test_add_vector_invalid_field_with_identifier() { + let source = r#" + V::Document { content: String, embedding: [F32] } + + QUERY test(vec: [F32], text: String) => + doc <- AddV(vec, {content: text, badField: "invalid"}) + RETURN doc + "#; + + let content = write_to_temp_file(vec![source]); + let parsed = HelixParser::parse_source(&content).unwrap(); + let result = crate::helixc::analyzer::analyze(&parsed); + + assert!(result.is_ok()); + let (diagnostics, _) = result.unwrap(); + assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E202)); + } } From 2a7af84314cce50ef66da6f26ab903094f9d2928 Mon Sep 17 00:00:00 2001 From: xav-db Date: Wed, 19 Nov 2025 14:55:33 -0800 Subject: [PATCH 13/48] removing prints --- helix-db/src/helixc/generator/bool_ops.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/helix-db/src/helixc/generator/bool_ops.rs b/helix-db/src/helixc/generator/bool_ops.rs index 62f84370f..e3cd7dce3 100644 --- a/helix-db/src/helixc/generator/bool_ops.rs +++ b/helix-db/src/helixc/generator/bool_ops.rs @@ -158,9 +158,6 @@ impl Display for BoExp { } BoExp::Exists(traversal) => { // Optimize Exists expressions in filter context to use std::iter::once for single values - println!("Optimizing Exists expression"); - println!("{:?}", traversal.traversal_type); - println!("{:?}", traversal.source_step); let is_val_traversal = match &traversal.traversal_type { TraversalType::FromIter(var) | TraversalType::FromSingle(var) => match var { GenRef::Std(s) | GenRef::Literal(s) => { @@ -174,7 +171,6 @@ impl Display for BoExp { }, _ => false, }; - println!("is_val_traversal: {}", is_val_traversal); if is_val_traversal { // Create a modified traversal that uses FromSingle instead of FromIter From 1f65ba1974c01da1dc495d43b0cd2f732384518d Mon Sep 17 00:00:00 2001 From: xav-db Date: Fri, 21 Nov 2025 11:07:56 -0800 Subject: [PATCH 14/48] fixing clippy --- helix-cli/src/cleanup.rs | 6 ++++++ helix-cli/src/commands/add.rs | 5 ++--- helix-cli/src/commands/init.rs | 5 ++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/helix-cli/src/cleanup.rs b/helix-cli/src/cleanup.rs index a7d746f1f..7c6ef7a22 100644 --- a/helix-cli/src/cleanup.rs +++ b/helix-cli/src/cleanup.rs @@ -25,6 +25,12 @@ pub struct CleanupSummary { pub errors: Vec, } +impl Default for CleanupTracker { + fn default() -> Self { + Self::new() + } +} + impl CleanupTracker { /// Create a new cleanup tracker pub fn new() -> Self { diff --git 
a/helix-cli/src/commands/add.rs b/helix-cli/src/commands/add.rs index ad13546c6..f27c84065 100644 --- a/helix-cli/src/commands/add.rs +++ b/helix-cli/src/commands/add.rs @@ -18,13 +18,12 @@ pub async fn run(deployment_type: CloudDeploymentTypeCommand) -> Result<()> { let result = run_add_inner(deployment_type, &mut cleanup_tracker).await; // If there was an error, perform cleanup - if let Err(ref e) = result { - if cleanup_tracker.has_tracked_resources() { + if let Err(ref e) = result + && cleanup_tracker.has_tracked_resources() { eprintln!("Add failed, performing cleanup: {}", e); let summary = cleanup_tracker.cleanup(); summary.log_summary(); } - } result } diff --git a/helix-cli/src/commands/init.rs b/helix-cli/src/commands/init.rs index 9208e88cb..76b283afa 100644 --- a/helix-cli/src/commands/init.rs +++ b/helix-cli/src/commands/init.rs @@ -32,13 +32,12 @@ pub async fn run( .await; // If there was an error, perform cleanup - if let Err(ref e) = result { - if cleanup_tracker.has_tracked_resources() { + if let Err(ref e) = result + && cleanup_tracker.has_tracked_resources() { eprintln!("Init failed, performing cleanup: {}", e); let summary = cleanup_tracker.cleanup(); summary.log_summary(); } - } result } From 39af3d10496c42e1230db9614f84456ad71a1867 Mon Sep 17 00:00:00 2001 From: xav-db Date: Fri, 21 Nov 2025 11:12:12 -0800 Subject: [PATCH 15/48] updating versions --- Cargo.lock | 4 ++-- helix-cli/Cargo.toml | 2 +- helix-db/Cargo.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 22f353c27..96e0d93c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1314,7 +1314,7 @@ dependencies = [ [[package]] name = "helix-cli" -version = "2.1.2" +version = "2.1.3" dependencies = [ "async-trait", "chrono", @@ -1365,7 +1365,7 @@ dependencies = [ [[package]] name = "helix-db" -version = "1.1.2" +version = "1.1.3" dependencies = [ "async-trait", "axum", diff --git a/helix-cli/Cargo.toml b/helix-cli/Cargo.toml index be42d5199..786773bcf 100644 --- a/helix-cli/Cargo.toml +++ b/helix-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "helix-cli" -version = "2.1.2" +version = "2.1.3" edition = "2024" [dependencies] diff --git a/helix-db/Cargo.toml b/helix-db/Cargo.toml index e9a85b013..4c108acf7 100644 --- a/helix-db/Cargo.toml +++ b/helix-db/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "helix-db" -version = "1.1.2" +version = "1.1.3" edition = "2024" description = "HelixDB is a powerful, open-source, graph-vector database built in Rust for intelligent data storage for RAG and AI." 
license = "AGPL-3.0" From f7b046a651d2951a963657ac4032d1ced23a415d Mon Sep 17 00:00:00 2001 From: xav-db Date: Fri, 21 Nov 2025 12:08:11 -0800 Subject: [PATCH 16/48] add s3 push --- .github/workflows/s3_push.yml | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 .github/workflows/s3_push.yml diff --git a/.github/workflows/s3_push.yml b/.github/workflows/s3_push.yml new file mode 100644 index 000000000..22f818d22 --- /dev/null +++ b/.github/workflows/s3_push.yml @@ -0,0 +1,46 @@ +name: Push to S3 + +on: + release: + types: [published, created] + create: + tags: + - 'v*' + workflow_dispatch: + +permissions: + id-token: write + contents: read + +jobs: + upload-to-s3: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Configure AWS credentials using OIDC + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::${{ vars.AWS_ACCOUNT_ID }}:role/GitHubActionsS3Role + aws-region: us-east-1 + + - name: Upload specified files and directories to S3 + run: | + # Sync directories + aws s3 sync helix-cli/ s3://helix-repo/template/helix-cli/ --exclude "target/*" + aws s3 sync helix-container/ s3://helix-repo/template/helix-container/ --exclude "target/*" + aws s3 sync helix-macros/ s3://helix-repo/template/helix-macros/ --exclude "target/*" + aws s3 sync metrics/ s3://helix-repo/template/metrics/ --exclude "target/*" + + # Upload root-level Cargo files + aws s3 cp Cargo.lock s3://helix-repo/template/Cargo.lock + aws s3 cp Cargo.toml s3://helix-repo/template/Cargo.toml + + - name: Upload completion notification + if: success() + run: | + echo "Successfully uploaded all files to S3 bucket: helix-repo" + echo "Upload triggered by: ${{ github.event_name }}" + echo "Reference: ${{ github.ref }}" \ No newline at end of file From 95b574c599fd2ae9fedb4113297c193def40e2ed Mon Sep 17 00:00:00 2001 From: xav-db Date: Fri, 21 Nov 2025 20:51:03 -0800 Subject: [PATCH 17/48] fixing issue with mcp brute force v search --- helix-db/src/helix_gateway/mcp/tools.rs | 51 +++++++++++++++++-------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/helix-db/src/helix_gateway/mcp/tools.rs b/helix-db/src/helix_gateway/mcp/tools.rs index 590ec13f8..d130a685c 100644 --- a/helix-db/src/helix_gateway/mcp/tools.rs +++ b/helix-db/src/helix_gateway/mcp/tools.rs @@ -78,6 +78,7 @@ pub enum ToolArgs { vector: Vec, k: usize, min_score: Option, + cutoff: Option, }, } @@ -380,30 +381,50 @@ where // SearchVecText requires embedding model initialization // It should be called via the dedicated search_vec_text MCP handler // not through the generic query chain execution - Err(GraphError::New( - format!("SearchVecText (query: {}, label: {}, k: {}) is not supported in generic query chains. Use the search_vec_text endpoint directly.", query, label, k) - )) + Err(GraphError::New(format!( + "SearchVecText (query: {}, label: {}, k: {}) is not supported in generic query chains. 
Use the search_vec_text endpoint directly.", + query, label, k + ))) } - ToolArgs::SearchVec { vector, k, min_score } => { + ToolArgs::SearchVec { + vector, + k, + min_score, + cutoff, + } => { use crate::helix_engine::traversal_core::ops::vectors::brute_force_search::BruteForceSearchVAdapter; let query_vec = arena.alloc_slice_copy(vector); - let mut results = stream.map(|iter| iter.range(0, *k*3).brute_force_search_v(query_vec, *k)); + let mut results = match cutoff { + Some(cutoff_val) => stream.map(|iter| { + iter.range(0, *cutoff_val) + .brute_force_search_v(query_vec, *k) + }), + None => stream.map(|iter| iter.brute_force_search_v(query_vec, *k)), + }; // Apply min_score filter if specified if let Some(min_score_val) = min_score { let min_score_copy = *min_score_val; results = results.map(|iter| { - let RoTraversalIterator { storage, arena, txn, inner } = iter; - let filtered: DynIter<'arena, 'txn> = Box::new( - inner.filter(move |item_res| { - match item_res { - Ok(TraversalValue::Vector(v)) => v.get_distance() > min_score_copy, - _ => true, // Keep non-vector items - } - }) - ); - RoTraversalIterator { storage, arena, txn, inner: filtered } + let RoTraversalIterator { + storage, + arena, + txn, + inner, + } = iter; + let filtered: DynIter<'arena, 'txn> = Box::new(inner.filter(move |item_res| { + match item_res { + Ok(TraversalValue::Vector(v)) => v.get_distance() > min_score_copy, + _ => true, // Keep non-vector items + } + })); + RoTraversalIterator { + storage, + arena, + txn, + inner: filtered, + } }); } From 124313ac925ee4f1c10928b07c1d65f22933b667 Mon Sep 17 00:00:00 2001 From: xav-db Date: Sat, 22 Nov 2025 12:37:21 -0800 Subject: [PATCH 18/48] fixing issue with seaerch_vector --- helix-db/src/helix_gateway/mcp/mcp.rs | 184 ++++++++++++++++++-------- 1 file changed, 130 insertions(+), 54 deletions(-) diff --git a/helix-db/src/helix_gateway/mcp/mcp.rs b/helix-db/src/helix_gateway/mcp/mcp.rs index a62b91a15..239a8c18a 100644 --- a/helix-db/src/helix_gateway/mcp/mcp.rs +++ b/helix-db/src/helix_gateway/mcp/mcp.rs @@ -7,7 +7,7 @@ use crate::{ }, types::GraphError, }, - helix_gateway::mcp::tools::{execute_query_chain, EdgeType, FilterTraversal, Order, ToolArgs}, + helix_gateway::mcp::tools::{EdgeType, FilterTraversal, Order, ToolArgs, execute_query_chain}, protocol::{Format, Request, Response}, utils::id::v6_uuid, }; @@ -169,36 +169,52 @@ fn execute_tool_step( connection_id: &str, tool: ToolArgs, ) -> Result { - tracing::debug!("[EXECUTE_TOOL_STEP] Starting with connection_id: {}", connection_id); + tracing::debug!( + "[EXECUTE_TOOL_STEP] Starting with connection_id: {}", + connection_id + ); // Clone necessary data while holding the lock let query_chain = { tracing::debug!("[EXECUTE_TOOL_STEP] Acquiring connection lock"); let mut connections = input.mcp_connections.lock().unwrap(); - tracing::debug!("[EXECUTE_TOOL_STEP] Available connections: {:?}", - connections.connections.keys().collect::>()); + tracing::debug!( + "[EXECUTE_TOOL_STEP] Available connections: {:?}", + connections.connections.keys().collect::>() + ); let connection = connections .get_connection_mut(connection_id) .ok_or_else(|| { - tracing::error!("[EXECUTE_TOOL_STEP] Connection not found: {}", connection_id); + tracing::error!( + "[EXECUTE_TOOL_STEP] Connection not found: {}", + connection_id + ); GraphError::StorageError(format!("Connection not found: {}", connection_id)) })?; - tracing::debug!("[EXECUTE_TOOL_STEP] Adding query step, current chain length: {}", - connection.query_chain.len()); + tracing::debug!( + 
"[EXECUTE_TOOL_STEP] Adding query step, current chain length: {}", + connection.query_chain.len() + ); connection.add_query_step(tool); connection.query_chain.clone() }; - tracing::debug!("[EXECUTE_TOOL_STEP] Executing query chain with {} steps", query_chain.len()); + tracing::debug!( + "[EXECUTE_TOOL_STEP] Executing query chain with {} steps", + query_chain.len() + ); // Execute long-running operation without holding the lock let arena = Bump::new(); let storage = input.mcp_backend.db.as_ref(); let txn = storage.graph_env.read_txn().map_err(|e| { - tracing::error!("[EXECUTE_TOOL_STEP] Failed to create read transaction: {:?}", e); + tracing::error!( + "[EXECUTE_TOOL_STEP] Failed to create read transaction: {:?}", + e + ); e })?; @@ -210,17 +226,20 @@ fn execute_tool_step( let mut iter = stream.into_inner_iter(); let (first, consumed_one) = match iter.next() { - Some(value) => { + Some(value) => { let val = value.map_err(|e| { tracing::error!("[EXECUTE_TOOL_STEP] Error getting first value: {:?}", e); e })?; (val, true) - } + } None => (TraversalValue::Empty, false), }; - tracing::debug!("[EXECUTE_TOOL_STEP] Got first result, consumed: {}", consumed_one); + tracing::debug!( + "[EXECUTE_TOOL_STEP] Got first result, consumed: {}", + consumed_one + ); // Update connection state { @@ -228,7 +247,10 @@ fn execute_tool_step( let connection = connections .get_connection_mut(connection_id) .ok_or_else(|| { - tracing::error!("[EXECUTE_TOOL_STEP] Connection not found when updating state: {}", connection_id); + tracing::error!( + "[EXECUTE_TOOL_STEP] Connection not found when updating state: {}", + connection_id + ); GraphError::StorageError(format!("Connection not found: {}", connection_id)) })?; connection.current_position = if consumed_one { 1 } else { 0 }; @@ -283,8 +305,10 @@ pub fn next(input: &mut MCPToolInput) -> Result { // Clone necessary data while holding the lock let (query_chain, current_position) = { let connections = input.mcp_connections.lock().unwrap(); - tracing::debug!("[NEXT] Available connections: {:?}", - connections.connections.keys().collect::>()); + tracing::debug!( + "[NEXT] Available connections: {:?}", + connections.connections.keys().collect::>() + ); let connection = connections .get_connection(&data.connection_id) @@ -295,7 +319,11 @@ pub fn next(input: &mut MCPToolInput) -> Result { (connection.query_chain.clone(), connection.current_position) }; - tracing::debug!("[NEXT] Current position: {}, chain length: {}", current_position, query_chain.len()); + tracing::debug!( + "[NEXT] Current position: {}, chain length: {}", + current_position, + query_chain.len() + ); // Execute long-running operation without holding the lock let arena = Bump::new(); @@ -311,7 +339,11 @@ pub fn next(input: &mut MCPToolInput) -> Result { })?; let next_value = match stream.nth(current_position).map_err(|e| { - tracing::error!("[NEXT] Error iterating to position {}: {:?}", current_position, e); + tracing::error!( + "[NEXT] Error iterating to position {}: {:?}", + current_position, + e + ); e })? 
{ Some(value) => { @@ -320,11 +352,20 @@ pub fn next(input: &mut MCPToolInput) -> Result { let connection = connections .get_connection_mut(&data.connection_id) .ok_or_else(|| { - tracing::error!("[NEXT] Connection not found when updating position: {}", data.connection_id); - GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + tracing::error!( + "[NEXT] Connection not found when updating position: {}", + data.connection_id + ); + GraphError::StorageError(format!( + "Connection not found: {}", + data.connection_id + )) })?; connection.current_position += 1; - tracing::debug!("[NEXT] Updated position to: {}", connection.current_position); + tracing::debug!( + "[NEXT] Updated position to: {}", + connection.current_position + ); value } None => { @@ -361,7 +402,9 @@ pub fn collect(input: &mut MCPToolInput) -> Result { let connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection(&data.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; connection.query_chain.clone() }; @@ -381,9 +424,10 @@ pub fn collect(input: &mut MCPToolInput) -> Result { let item = item?; if index >= start { if let Some(end) = end - && index >= end { - break; - } + && index >= end + { + break; + } values.push(item); } } @@ -393,7 +437,9 @@ pub fn collect(input: &mut MCPToolInput) -> Result { let mut connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection_mut(&data.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; if data.drop.unwrap_or(true) { connection.clear_chain(); @@ -422,7 +468,9 @@ pub fn aggregate_by(input: &mut MCPToolInput) -> Result { let connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection(&data.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; connection.query_chain.clone() }; @@ -442,7 +490,9 @@ pub fn aggregate_by(input: &mut MCPToolInput) -> Result { let mut connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection_mut(&data.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; if data.drop.unwrap_or(true) { connection.clear_chain(); @@ -464,7 +514,9 @@ pub fn group_by(input: &mut MCPToolInput) -> Result { let connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection(&data.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; connection.query_chain.clone() }; @@ -484,7 +536,9 @@ pub fn group_by(input: &mut MCPToolInput) -> Result { let mut connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection_mut(&data.connection_id) - .ok_or_else(|| 
GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; if data.drop.unwrap_or(true) { connection.clear_chain(); @@ -509,7 +563,9 @@ pub fn reset(input: &mut MCPToolInput) -> Result { let mut connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection_mut(&data.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", data.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", data.connection_id)) + })?; connection.clear_chain(); let connection_id = connection.connection_id.clone(); @@ -770,10 +826,7 @@ pub struct SearchKeywordInput { #[mcp_handler] pub fn search_keyword(input: &mut MCPToolInput) -> Result { - use crate::helix_engine::traversal_core::ops::{ - bm25::search_bm25::SearchBM25Adapter, - g::G, - }; + use crate::helix_engine::traversal_core::ops::{bm25::search_bm25::SearchBM25Adapter, g::G}; let req: SearchKeywordInput = match sonic_rs::from_slice(&input.request.body) { Ok(data) => data, @@ -785,7 +838,9 @@ pub fn search_keyword(input: &mut MCPToolInput) -> Result let connections = input.mcp_connections.lock().unwrap(); connections .get_connection(&req.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", req.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", req.connection_id)) + })?; } // Execute long-running operation without holding the lock @@ -796,7 +851,7 @@ pub fn search_keyword(input: &mut MCPToolInput) -> Result // Perform BM25 search using the existing index let results = G::new(storage, &txn, &arena) .search_bm25(&req.data.label, &req.data.query, req.data.limit)? 
- .collect::,_>>()?; + .collect::, _>>()?; let (first, consumed_one) = match results.first() { Some(value) => (value.clone(), true), @@ -808,7 +863,9 @@ pub fn search_keyword(input: &mut MCPToolInput) -> Result let mut connections = input.mcp_connections.lock().unwrap(); let connection = connections .get_connection_mut(&req.connection_id) - .ok_or_else(|| GraphError::StorageError(format!("Connection not found: {}", req.connection_id)))?; + .ok_or_else(|| { + GraphError::StorageError(format!("Connection not found: {}", req.connection_id)) + })?; // Store remaining results for pagination connection.current_position = if consumed_one { 1 } else { 0 }; @@ -833,11 +890,8 @@ pub struct SearchVectorTextInput { #[mcp_handler] pub fn search_vector_text(input: &mut MCPToolInput) -> Result { - use crate::helix_engine::traversal_core::ops::{ - g::G, - vectors::search::SearchVAdapter, - }; - use crate::helix_gateway::embedding_providers::{get_embedding_model, EmbeddingModel}; + use crate::helix_engine::traversal_core::ops::{g::G, vectors::search::SearchVAdapter}; + use crate::helix_gateway::embedding_providers::{EmbeddingModel, get_embedding_model}; let req: SearchVectorTextInput = match sonic_rs::from_slice(&input.request.body) { Ok(data) => data, @@ -847,20 +901,30 @@ pub fn search_vector_text(input: &mut MCPToolInput) -> Result>()); + tracing::debug!( + "[VECTOR_SEARCH] Available connections: {:?}", + connections.connections.keys().collect::>() + ); connections .get_connection(&req.connection_id) .ok_or_else(|| { - tracing::error!("[VECTOR_SEARCH] Connection not found: {}", req.connection_id); + tracing::error!( + "[VECTOR_SEARCH] Connection not found: {}", + req.connection_id + ); GraphError::StorageError(format!("Connection not found: {}", req.connection_id)) })?; } @@ -883,17 +947,22 @@ pub fn search_vector_text(input: &mut MCPToolInput) -> Result bool, _>( query_vec_arena, @@ -923,12 +992,18 @@ pub fn search_vector_text(input: &mut MCPToolInput) -> Result Result { vector: req.data.vector, k: req.data.k, min_score: req.data.min_score, + cutoff: None, }; execute_tool_step(input, &req.connection_id, tool) From 5717828e4700df7c9f0e8a09b2029e2ac8c0ad55 Mon Sep 17 00:00:00 2001 From: xav-db Date: Sat, 22 Nov 2025 12:39:56 -0800 Subject: [PATCH 19/48] fixing outdated metrics info in installation --- helix-cli/install.sh | 82 ++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/helix-cli/install.sh b/helix-cli/install.sh index 6ecd112ed..fcca231aa 100755 --- a/helix-cli/install.sh +++ b/helix-cli/install.sh @@ -33,7 +33,7 @@ Helix CLI Installer USAGE: curl -fsSL https://raw.githubusercontent.com/$REPO/main/install.sh | bash - + # Or with options: curl -fsSL https://raw.githubusercontent.com/$REPO/main/install.sh | bash -s -- [OPTIONS] @@ -46,10 +46,10 @@ OPTIONS: EXAMPLES: # User install (default) curl -fsSL https://raw.githubusercontent.com/$REPO/main/install.sh | bash - + # System install curl -fsSL https://raw.githubusercontent.com/$REPO/main/install.sh | bash -s -- --system - + # Custom directory curl -fsSL https://raw.githubusercontent.com/$REPO/main/install.sh | bash -s -- --dir ~/bin EOF @@ -170,23 +170,23 @@ get_latest_version() { version=$(curl -fsSL "https://api.github.com/repos/$REPO/releases/latest" | \ grep '"tag_name"' | \ sed -E 's/.*"tag_name": "([^"]+)".*/\1/') - + if [[ -z "$version" ]]; then log_error "Failed to fetch latest version from GitHub API" exit 1 fi - + echo "$version" } # Get version of installed binary 
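# Runs `$BINARY_NAME --version` and extracts the first semver-looking token;
# prints an empty string when the binary is missing or not executable, which
# should_install() treats as "no existing installation".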
get_installed_version() { local binary_path="$INSTALL_DIR/$BINARY_NAME" - + if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then binary_path="$binary_path.exe" fi - + if [[ -x "$binary_path" ]]; then "$binary_path" --version 2>/dev/null | grep -oE '[0-9]+\.[0-9]+\.[0-9]+' | head -1 || echo "" else @@ -198,31 +198,31 @@ get_installed_version() { should_install() { local latest_version="$1" local installed_version - + if [[ "$FORCE_INSTALL" == true ]]; then log_info "Force install requested" return 0 fi - + installed_version=$(get_installed_version) - + if [[ -z "$installed_version" ]]; then log_info "No existing installation found" return 0 fi - + # Remove 'v' prefix for comparison latest_version="${latest_version#v}" - + log_info "Installed version: $installed_version" log_info "Latest version: $latest_version" - + if [[ "$installed_version" == "$latest_version" ]]; then log_success "Already up to date!" log_info "Use --force to reinstall" return 1 fi - + log_info "Update available: $installed_version -> $latest_version" return 0 } @@ -233,30 +233,30 @@ install_binary() { local binary_filename="$2" local download_url="https://github.com/$REPO/releases/download/$version/$binary_filename" local binary_path="$INSTALL_DIR/$BINARY_NAME" - + # Add .exe extension on Windows if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then binary_path="$binary_path.exe" fi - + log_info "Downloading: $download_url" - + # Create install directory mkdir -p "$INSTALL_DIR" - + # Download to temporary file local temp_file temp_file=$(mktemp) - + if ! curl -fsSL "$download_url" -o "$temp_file"; then log_error "Failed to download binary" rm -f "$temp_file" exit 1 fi - + # Make executable chmod +x "$temp_file" - + # Atomic move to final location if ! mv "$temp_file" "$binary_path"; then log_error "Failed to install binary to $binary_path" @@ -264,7 +264,7 @@ install_binary() { rm -f "$temp_file" exit 1 fi - + log_success "Installed to: $binary_path" } @@ -274,17 +274,17 @@ setup_path() { # System installs should already be in PATH return 0 fi - + # Only setup PATH for default user installs if [[ "$INSTALL_DIR" != "$DEFAULT_INSTALL_DIR" ]]; then log_info "Custom install directory. Add to PATH manually:" log_info " export PATH=\"$INSTALL_DIR:\$PATH\"" return 0 fi - + local shell_config="" local path_line="" - + # Determine shell config file case "$SHELL" in */bash) @@ -305,7 +305,7 @@ setup_path() { return 0 ;; esac - + # Add to shell config if not already present if [[ -f "$shell_config" ]] && ! grep -Fq "$path_line" "$shell_config"; then echo "" >> "$shell_config" @@ -319,25 +319,25 @@ setup_path() { # Verify installation verify_installation() { local binary_path="$INSTALL_DIR/$BINARY_NAME" - + if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "cygwin" ]]; then binary_path="$binary_path.exe" fi - + if [[ ! -x "$binary_path" ]]; then log_error "Installation verification failed: binary not executable" return 1 fi - + local installed_version installed_version=$(get_installed_version) - + if [[ -n "$installed_version" ]]; then log_success "Installation verified: v$installed_version" else log_warn "Could not verify version, but binary is installed" fi - + # Test basic functionality if "$binary_path" --help >/dev/null 2>&1; then log_success "Basic functionality test passed" @@ -412,35 +412,35 @@ check_docker_desktop() { main() { log_info "Helix CLI Installer" log_info "Repository: $REPO" - + parse_args "$@" set_install_dir - + # Check for required tools if ! 
command -v curl >/dev/null; then log_error "curl is required but not installed" exit 1 fi - + local binary_filename latest_version - + binary_filename=$(detect_platform) log_info "Platform: $binary_filename" - + latest_version=$(get_latest_version) log_info "Latest version: $latest_version" - + if ! should_install "$latest_version"; then exit 0 fi - + # Check permissions for system install if [[ "$SYSTEM_INSTALL" == true ]] && [[ ! -w "$INSTALL_DIR" ]] && [[ $EUID -ne 0 ]]; then log_error "System install requires sudo permissions" log_info "Run: curl -fsSL https://raw.githubusercontent.com/$REPO/main/install.sh | sudo bash -s -- --system" exit 1 fi - + install_binary "$latest_version" "$binary_filename" setup_path verify_installation @@ -463,7 +463,7 @@ main() { log_info "" log_info "To disable metrics, run: helix metrics --off" log_info "" - log_info "To show metrics status, run: helix metrics --status" + log_info "To show metrics status, run: helix metrics status" } -main "$@" \ No newline at end of file +main "$@" From 9b9e4c37e0b7b759115170950015df1151a730dd Mon Sep 17 00:00:00 2001 From: xav-db Date: Sat, 22 Nov 2025 17:30:59 -0800 Subject: [PATCH 20/48] updating workflow issue? --- .github/workflows/clippy_check.yml | 2 +- .github/workflows/dashboard_check.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/clippy_check.yml b/.github/workflows/clippy_check.yml index ef158e255..dd26d736c 100644 --- a/.github/workflows/clippy_check.yml +++ b/.github/workflows/clippy_check.yml @@ -25,7 +25,7 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} restore-keys: | ${{ runner.os }}-cargo- diff --git a/.github/workflows/dashboard_check.yml b/.github/workflows/dashboard_check.yml index 9fa41b36e..75245307b 100644 --- a/.github/workflows/dashboard_check.yml +++ b/.github/workflows/dashboard_check.yml @@ -25,7 +25,7 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} restore-keys: | ${{ runner.os }}-cargo- From c6758d311317d0886f1aebfe403e80274572b5de Mon Sep 17 00:00:00 2001 From: xav-db Date: Sat, 22 Nov 2025 17:33:00 -0800 Subject: [PATCH 21/48] reverting --- .github/workflows/clippy_check.yml | 2 +- .github/workflows/dashboard_check.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/clippy_check.yml b/.github/workflows/clippy_check.yml index dd26d736c..ef158e255 100644 --- a/.github/workflows/clippy_check.yml +++ b/.github/workflows/clippy_check.yml @@ -25,7 +25,7 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} restore-keys: | ${{ runner.os }}-cargo- diff --git a/.github/workflows/dashboard_check.yml b/.github/workflows/dashboard_check.yml index 75245307b..9fa41b36e 100644 --- a/.github/workflows/dashboard_check.yml +++ b/.github/workflows/dashboard_check.yml @@ -25,7 +25,7 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} restore-keys: | ${{ runner.os }}-cargo- From dba705ae09a1d49d55f55589378e1da7bf90bb86 Mon Sep 17 00:00:00 2001 From: xav-db Date: Sat, 22 Nov 2025 17:45:50 -0800 Subject: [PATCH 22/48] fixing weird problem with workflows 
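Make the cargo cache step non-fatal by adding `continue-on-error: true`, and
give the cache key a literal 'fallback' component so it stays valid on runs
where `hashFiles('**/Cargo.lock')` evaluates to an empty string.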
--- .github/workflows/clippy_check.yml | 3 ++- .github/workflows/dashboard_check.yml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/clippy_check.yml b/.github/workflows/clippy_check.yml index ef158e255..afd281ee9 100644 --- a/.github/workflows/clippy_check.yml +++ b/.github/workflows/clippy_check.yml @@ -20,12 +20,13 @@ jobs: - name: Cache cargo dependencies uses: actions/cache@v4 + continue-on-error: true with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') || 'fallback' }} restore-keys: | ${{ runner.os }}-cargo- diff --git a/.github/workflows/dashboard_check.yml b/.github/workflows/dashboard_check.yml index 9fa41b36e..6d83a6700 100644 --- a/.github/workflows/dashboard_check.yml +++ b/.github/workflows/dashboard_check.yml @@ -20,12 +20,13 @@ jobs: - name: Cache cargo dependencies uses: actions/cache@v4 + continue-on-error: true with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') || 'fallback' }} restore-keys: | ${{ runner.os }}-cargo- From 48a55cec59985c63eedacf9b6c4e8d599d198769 Mon Sep 17 00:00:00 2001 From: xav-db Date: Sat, 22 Nov 2025 19:10:03 -0800 Subject: [PATCH 23/48] fixing other tests --- .github/workflows/db_tests.yml | 3 ++- .github/workflows/hql_tests.yml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/db_tests.yml b/.github/workflows/db_tests.yml index de7ea16cd..67ce6ee0a 100644 --- a/.github/workflows/db_tests.yml +++ b/.github/workflows/db_tests.yml @@ -22,12 +22,13 @@ jobs: - name: Cache cargo dependencies uses: actions/cache@v4 + continue-on-error: true with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') || 'fallback' }} restore-keys: | ${{ runner.os }}-cargo- diff --git a/.github/workflows/hql_tests.yml b/.github/workflows/hql_tests.yml index cfd72116f..a922719e5 100644 --- a/.github/workflows/hql_tests.yml +++ b/.github/workflows/hql_tests.yml @@ -29,12 +29,13 @@ jobs: - name: Cache cargo registry uses: actions/cache@v3 + continue-on-error: true with: path: | ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') || 'fallback' }} restore-keys: | ${{ runner.os }}-cargo- From 34ba409e9543c55cdf43b9737c3411be7a283774 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Tue, 18 Nov 2025 23:07:18 -0300 Subject: [PATCH 24/48] First pass on the vector_core rewrite --- Cargo.lock | 77 +- helix-db/Cargo.toml | 11 +- helix-db/benches/hnsw_benches.rs | 27 +- helix-db/src/helix_engine/bm25/bm25.rs | 27 +- helix-db/src/helix_engine/bm25/bm25_tests.rs | 21 +- .../src/helix_engine/reranker/adapters/mod.rs | 32 +- .../src/helix_engine/reranker/fusion/mmr.rs | 54 +- .../src/helix_engine/reranker/fusion/rrf.rs | 65 +- .../reranker/models/cross_encoder.rs | 16 +- .../storage_core/graph_visualization.rs | 39 +- helix-db/src/helix_engine/storage_core/mod.rs | 12 +- .../storage_core/storage_migration.rs | 473 -------- .../storage_core/storage_migration_tests.rs | 1037 ----------------- .../hnsw_concurrent_tests.rs | 113 +- helix-db/src/helix_engine/tests/hnsw_tests.rs | 15 +- .../tests/traversal_tests/drop_tests.rs | 17 +- 
.../traversal_tests/edge_traversal_tests.rs | 80 +- .../tests/traversal_tests/util_tests.rs | 165 ++- .../traversal_tests/vector_traversal_tests.rs | 269 +---- .../src/helix_engine/tests/vector_tests.rs | 11 +- .../traversal_core/ops/in_/in_.rs | 2 +- .../traversal_core/ops/in_/to_v.rs | 4 +- .../traversal_core/ops/out/from_v.rs | 2 +- .../traversal_core/ops/out/out.rs | 2 +- .../traversal_core/ops/source/v_from_id.rs | 2 +- .../traversal_core/ops/source/v_from_type.rs | 68 +- .../traversal_core/ops/util/drop.rs | 6 - .../ops/vectors/brute_force_search.rs | 18 +- .../traversal_core/ops/vectors/insert.rs | 18 +- .../traversal_core/ops/vectors/search.rs | 4 +- .../traversal_core/traversal_value.rs | 29 +- helix-db/src/helix_engine/types.rs | 83 +- .../helix_engine/vector_core/binary_heap.rs | 567 --------- .../vector_core/distance/cosine.rs | 65 ++ .../helix_engine/vector_core/distance/mod.rs | 42 + helix-db/src/helix_engine/vector_core/hnsw.rs | 621 +++++++++- .../src/helix_engine/vector_core/item_iter.rs | 56 + helix-db/src/helix_engine/vector_core/key.rs | 174 +++ .../src/helix_engine/vector_core/metadata.rs | 75 ++ helix-db/src/helix_engine/vector_core/mod.rs | 328 +++++- helix-db/src/helix_engine/vector_core/node.rs | 286 +++++ .../src/helix_engine/vector_core/node_id.rs | 160 +++ .../helix_engine/vector_core/ordered_float.rs | 47 + .../src/helix_engine/vector_core/parallel.rs | 172 +++ .../src/helix_engine/vector_core/reader.rs | 754 ++++++++++++ .../helix_engine/vector_core/spaces/mod.rs | 10 + .../helix_engine/vector_core/spaces/simple.rs | 84 ++ .../vector_core/spaces/simple_avx.rs | 163 +++ .../vector_core/spaces/simple_neon.rs | 154 +++ .../vector_core/spaces/simple_sse.rs | 158 +++ .../src/helix_engine/vector_core/stats.rs | 84 ++ .../vector_core/unaligned_vector/f32.rs | 70 ++ .../vector_core/unaligned_vector/mod.rs | 182 +++ .../src/helix_engine/vector_core/utils.rs | 167 --- .../src/helix_engine/vector_core/vector.rs | 305 ----- .../helix_engine/vector_core/vector_core.rs | 664 ----------- .../vector_core/vector_distance.rs | 157 --- .../vector_core/vector_without_data.rs | 153 --- .../src/helix_engine/vector_core/version.rs | 90 ++ .../src/helix_engine/vector_core/writer.rs | 430 +++++++ helix-db/src/helix_gateway/mcp/mcp.rs | 6 +- .../custom_serde/compatibility_tests.rs | 74 +- .../protocol/custom_serde/edge_case_tests.rs | 64 +- .../custom_serde/error_handling_tests.rs | 21 +- .../custom_serde/integration_tests.rs | 52 +- .../custom_serde/property_based_tests.rs | 17 +- .../src/protocol/custom_serde/test_utils.rs | 84 +- .../src/protocol/custom_serde/vector_serde.rs | 12 +- .../custom_serde/vector_serde_tests.rs | 161 +-- helix-db/src/utils/properties.rs | 2 +- 70 files changed, 4970 insertions(+), 4540 deletions(-) delete mode 100644 helix-db/src/helix_engine/storage_core/storage_migration.rs delete mode 100644 helix-db/src/helix_engine/storage_core/storage_migration_tests.rs delete mode 100644 helix-db/src/helix_engine/vector_core/binary_heap.rs create mode 100644 helix-db/src/helix_engine/vector_core/distance/cosine.rs create mode 100644 helix-db/src/helix_engine/vector_core/distance/mod.rs create mode 100644 helix-db/src/helix_engine/vector_core/item_iter.rs create mode 100644 helix-db/src/helix_engine/vector_core/key.rs create mode 100644 helix-db/src/helix_engine/vector_core/metadata.rs create mode 100644 helix-db/src/helix_engine/vector_core/node.rs create mode 100644 helix-db/src/helix_engine/vector_core/node_id.rs create mode 100644 
helix-db/src/helix_engine/vector_core/ordered_float.rs create mode 100644 helix-db/src/helix_engine/vector_core/parallel.rs create mode 100644 helix-db/src/helix_engine/vector_core/reader.rs create mode 100644 helix-db/src/helix_engine/vector_core/spaces/mod.rs create mode 100644 helix-db/src/helix_engine/vector_core/spaces/simple.rs create mode 100644 helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs create mode 100644 helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs create mode 100644 helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs create mode 100644 helix-db/src/helix_engine/vector_core/stats.rs create mode 100644 helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs create mode 100644 helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs delete mode 100644 helix-db/src/helix_engine/vector_core/utils.rs delete mode 100644 helix-db/src/helix_engine/vector_core/vector.rs delete mode 100644 helix-db/src/helix_engine/vector_core/vector_core.rs delete mode 100644 helix-db/src/helix_engine/vector_core/vector_distance.rs delete mode 100644 helix-db/src/helix_engine/vector_core/vector_without_data.rs create mode 100644 helix-db/src/helix_engine/vector_core/version.rs create mode 100644 helix-db/src/helix_engine/vector_core/writer.rs diff --git a/Cargo.lock b/Cargo.lock index 96e0d93c9..910e7dbdd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1024,6 +1024,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1261,11 +1267,22 @@ checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", "rayon", "serde", ] +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", +] + [[package]] name = "heck" version = "0.5.0" @@ -1372,10 +1389,12 @@ dependencies = [ "bincode", "bumpalo", "bytemuck", + "byteorder", "chrono", "core_affinity", "criterion", "flume", + "hashbrown 0.16.0", "heed3", "helix-macros", "helix-metrics", @@ -1383,8 +1402,12 @@ dependencies = [ "itertools 0.14.0", "lazy_static", "loom", + "madvise", "mimalloc", + "min-max-heap", "num_cpus", + "page_size", + "papaya", "paste", "pest", "pest_derive", @@ -1393,12 +1416,15 @@ dependencies = [ "rand 0.9.1", "rayon", "reqwest", + "roaring", + "rustc-hash", "serde", "sha_256", "sonic-rs", "subtle", "tempfile", "thiserror 2.0.12", + "tinyvec", "tokio", "tokio-test", "tokio-util", @@ -2060,6 +2086,15 @@ dependencies = [ "libc", ] +[[package]] +name = "madvise" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e1e75c3c34c2b34cec9f127418cb35240c7ebee5de36a51437e6b382c161b86" +dependencies = [ + "libc", +] + [[package]] name = "matchers" version = "0.2.0" @@ -2105,6 +2140,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "min-max-heap" +version = "1.3.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2687e6cf9c00f48e9284cf9fd15f2ef341d03cc7743abf9df4c5f07fdee50b18" + [[package]] name = "miniz_oxide" version = "0.8.5" @@ -2376,6 +2417,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "papaya" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f92dd0b07c53a0a0c764db2ace8c541dc47320dad97c2200c2a637ab9dd2328f" +dependencies = [ + "equivalent", + "seize", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -3593,6 +3644,16 @@ dependencies = [ "syn", ] +[[package]] +name = "roaring" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f08d6a905edb32d74a5d5737a0c9d7e950c312f3c46cb0ca0a2ca09ea11878a0" +dependencies = [ + "bytemuck", + "byteorder", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -3777,6 +3838,16 @@ dependencies = [ "libc", ] +[[package]] +name = "seize" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b55fb86dfd3a2f5f76ea78310a88f96c4ea21a3031f8d212443d56123fd0521" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "self-replace" version = "1.5.0" @@ -4358,9 +4429,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] diff --git a/helix-db/Cargo.toml b/helix-db/Cargo.toml index 4c108acf7..0957d8e53 100644 --- a/helix-db/Cargo.toml +++ b/helix-db/Cargo.toml @@ -34,7 +34,7 @@ paste = "1.0.15" rayon = "1.11.0" mimalloc = "0.1.48" bumpalo = { version = "3.19.0", features = ["collections", "boxed", "serde"] } -bytemuck = "1.24.0" +bytemuck = { version = "1.24.0", features = ["derive", "extern_crate_alloc"] } # compiler dependencies pest = { version = "2.7", optional = true } @@ -59,6 +59,15 @@ polars = { version = "0.46.0", features = [ ], optional = true } subtle = "2.6.1" sha_256 = "=0.1.1" +byteorder = "1.5.0" +roaring = "0.11.2" +tinyvec = "1.10.0" +papaya = "0.2.3" +hashbrown = "0.16.0" +min-max-heap = "1.3.0" +madvise = "0.1.0" +page_size = "0.6.0" +rustc-hash = "2.1.1" [dev-dependencies] rand = "0.9.0" diff --git a/helix-db/benches/hnsw_benches.rs b/helix-db/benches/hnsw_benches.rs index 7d8f30d3f..62667b7a4 100644 --- a/helix-db/benches/hnsw_benches.rs +++ b/helix-db/benches/hnsw_benches.rs @@ -5,18 +5,15 @@ mod tests { use helix_db::{ helix_engine::vector_core::{ hnsw::HNSW, - vector::HVector, + unaligned_vector::HVector, vector_core::{HNSWConfig, VectorCore}, }, utils::tqdm::tqdm, }; use polars::prelude::*; - use rand::{ - prelude::SliceRandom, - Rng, - }; + use rand::{Rng, prelude::SliceRandom}; use std::{ - collections::{HashSet, HashMap}, + collections::{HashMap, HashSet}, fs::{self, File}, sync::{Arc, Mutex}, thread, @@ -88,26 +85,23 @@ mod tests { .map(|dist| (base_vec.id.clone(), dist)) .ok() }) - .collect(); + .collect(); distances.sort_by(|a, b| { a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) }); - let top_k_ids: Vec = distances - .into_iter() - .take(k) - .map(|(id, _)| id) - .collect(); + let top_k_ids: Vec = + distances.into_iter().take(k).map(|(id, _)| id).collect(); (query_id, top_k_ids) }) - .collect(); + .collect(); results.lock().unwrap().extend(local_results); }) }) - .collect(); + 
.collect(); for handle in handles { handle.join().unwrap(); @@ -354,7 +348,9 @@ mod tests { let mut total_search_time = std::time::Duration::from_secs(0); for (qid, query) in query_vectors.iter() { let start_time = Instant::now(); - let results = index.search::(&txn, query, k, "vector", None, false).unwrap(); + let results = index + .search::(&txn, query, k, "vector", None, false) + .unwrap(); let search_duration = start_time.elapsed(); total_search_time += search_duration; @@ -400,4 +396,3 @@ mod tests { } // TODO: memory benchmark (only the hnsw index ofc) - diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs index 512ffe97d..5709e9975 100644 --- a/helix-db/src/helix_engine/bm25/bm25.rs +++ b/helix-db/src/helix_engine/bm25/bm25.rs @@ -1,10 +1,6 @@ use crate::{ debug_println, - helix_engine::{ - storage_core::HelixGraphStorage, - types::GraphError, - vector_core::{hnsw::HNSW, vector::HVector}, - }, + helix_engine::{storage_core::HelixGraphStorage, types::GraphError}, utils::properties::ImmutablePropertiesMap, }; @@ -418,27 +414,20 @@ impl HybridSearch for HelixGraphStorage { } }); - let vector_handle = task::spawn_blocking( - move || -> Result>, GraphError> { + let vector_handle = + task::spawn_blocking(move || -> Result>, GraphError> { let txn = graph_env_vector.read_txn()?; - let arena = Bump::new(); // MOVE + let arena = Bump::new(); // MOVE let query_slice = arena.alloc_slice_copy(query_vector_owned.as_slice()); - let results = self.vectors.search:: bool>( - &txn, - query_slice, - limit * 2, - "vector", - None, - false, - &arena, - )?; + let results = + self.vectors + .search(&txn, query_slice, limit * 2, "vector", false, &arena)?; let scores = results .into_iter() .map(|vec| (vec.id, vec.distance.unwrap_or(0.0))) .collect::>(); Ok(Some(scores)) - }, - ); + }); let (bm25_results, vector_results) = match tokio::try_join!(bm25_handle, vector_handle) { Ok((a, b)) => (a, b), diff --git a/helix-db/src/helix_engine/bm25/bm25_tests.rs b/helix-db/src/helix_engine/bm25/bm25_tests.rs index 48004eca7..97c8e8535 100644 --- a/helix-db/src/helix_engine/bm25/bm25_tests.rs +++ b/helix-db/src/helix_engine/bm25/bm25_tests.rs @@ -7,14 +7,13 @@ mod tests { }, storage_core::{HelixGraphStorage, version_info::VersionInfo}, traversal_core::config::Config, - vector_core::{hnsw::HNSW, vector::HVector}, }, protocol::value::Value, utils::properties::ImmutablePropertiesMap, }; use bumpalo::Bump; - use heed3::{Env, EnvOpenOptions, RoTxn}; + use heed3::{Env, EnvOpenOptions}; use rand::Rng; use std::collections::HashMap; use tempfile::tempdir; @@ -203,7 +202,9 @@ mod tests { for (i, props) in nodes.iter().enumerate() { let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); let data = props_map.flatten_bm25(); @@ -271,7 +272,9 @@ mod tests { for (i, props) in nodes.iter().enumerate() { let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); let data = props_map.flatten_bm25(); @@ -1258,7 +1261,9 @@ mod tests { for (i, props) in nodes.iter().enumerate() { let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, 
); let data = props_map.flatten_bm25(); @@ -1448,7 +1453,7 @@ mod tests { let slice = arena.alloc_slice_copy(vec.as_slice()); let _ = storage .vectors - .insert:: bool>(&mut wtxn, "vector", slice, None, &arena); + .insert(&mut wtxn, "vector", slice, None, &arena); arena.reset(); } wtxn.commit().unwrap(); @@ -1493,7 +1498,7 @@ mod tests { let slice = arena.alloc_slice_copy(vec.as_slice()); let _ = storage .vectors - .insert:: bool>(&mut wtxn, "vector", slice, None, &arena); + .insert(&mut wtxn, "vector", slice, None, &arena); arena.reset(); } wtxn.commit().unwrap(); @@ -1539,7 +1544,7 @@ mod tests { let slice = arena.alloc_slice_copy(vec.as_slice()); let _ = storage .vectors - .insert:: bool>(&mut wtxn, "vector", slice, None, &arena); + .insert(&mut wtxn, "vector", slice, None, &arena); arena.reset(); } wtxn.commit().unwrap(); diff --git a/helix-db/src/helix_engine/reranker/adapters/mod.rs b/helix-db/src/helix_engine/reranker/adapters/mod.rs index 49cf5abb6..145ad5e68 100644 --- a/helix-db/src/helix_engine/reranker/adapters/mod.rs +++ b/helix-db/src/helix_engine/reranker/adapters/mod.rs @@ -12,7 +12,6 @@ //! .collect_to::>() //! ``` - use crate::helix_engine::{ reranker::reranker::Reranker, traversal_core::{traversal_iter::RoTraversalIterator, traversal_value::TraversalValue}, @@ -25,7 +24,9 @@ pub struct RerankIterator<'arena, I: Iterator, GraphError>>> Iterator for RerankIterator<'arena, I> { +impl<'arena, I: Iterator, GraphError>>> Iterator + for RerankIterator<'arena, I> +{ type Item = Result, GraphError>; fn next(&mut self) -> Option { @@ -34,7 +35,8 @@ impl<'arena, I: Iterator, GraphError>>> Ite } /// Trait that adds reranking capability to traversal iterators. -pub trait RerankAdapter<'arena, 'db, 'txn>: Iterator, GraphError>> +pub trait RerankAdapter<'arena, 'db, 'txn>: + Iterator, GraphError>> where 'db: 'arena, 'arena: 'txn, @@ -61,7 +63,12 @@ where self, reranker: R, query: Option<&str>, - ) -> RoTraversalIterator<'db, 'arena, 'txn, impl Iterator, GraphError>>>; + ) -> RoTraversalIterator< + 'db, + 'arena, + 'txn, + impl Iterator, GraphError>>, + >; } impl<'db, 'arena, 'txn, I> RerankAdapter<'arena, 'db, 'txn> @@ -75,7 +82,12 @@ where self, reranker: R, query: Option<&str>, - ) -> RoTraversalIterator<'db, 'arena, 'txn, impl Iterator, GraphError>>> { + ) -> RoTraversalIterator< + 'db, + 'arena, + 'txn, + impl Iterator, GraphError>>, + > { // Collect all items from the iterator let items = self.inner.filter_map(|item| item.ok()); @@ -106,7 +118,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::helix_engine::{reranker::fusion::RRFReranker, vector_core::vector::HVector}; + use crate::helix_engine::{reranker::fusion::RRFReranker, vector_core::HVector}; #[test] fn test_rerank_adapter_trait() { @@ -122,8 +134,12 @@ mod tests { let data1 = arena.alloc_slice_copy(&[1.0]); let data2 = arena.alloc_slice_copy(&[2.0]); let items = vec![ - Ok(TraversalValue::Vector(HVector::from_slice("test", 0, data1))), - Ok(TraversalValue::Vector(HVector::from_slice("test", 0, data2))), + Ok(TraversalValue::Vector(HVector::from_slice( + "test", 0, data1, &arena, + ))), + Ok(TraversalValue::Vector(HVector::from_slice( + "test", 0, data2, &arena, + ))), ]; let mut iter = RerankIterator { diff --git a/helix-db/src/helix_engine/reranker/fusion/mmr.rs b/helix-db/src/helix_engine/reranker/fusion/mmr.rs index 85623443b..15316802f 100644 --- a/helix-db/src/helix_engine/reranker/fusion/mmr.rs +++ b/helix-db/src/helix_engine/reranker/fusion/mmr.rs @@ -10,14 +10,12 @@ //! 
- Sim2: similarity to already selected documents (diversity) //! - λ: trade-off parameter (typically 0.5-0.8) -use crate::{ - helix_engine::{ - reranker::{ - errors::{RerankerError, RerankerResult}, - reranker::{extract_score, update_score, Reranker}, - }, - traversal_core::traversal_value::TraversalValue, +use crate::helix_engine::{ + reranker::{ + errors::{RerankerError, RerankerResult}, + reranker::{Reranker, extract_score, update_score}, }, + traversal_core::traversal_value::TraversalValue, }; use std::collections::HashMap; @@ -85,11 +83,16 @@ impl MMRReranker { /// Extract vector data from a TraversalValue. /// Note: This requires an arena to convert VectorPrecisionData to f64 slice - fn extract_vector_data<'a>(&self, item: &'a TraversalValue<'a>, _arena: &'a bumpalo::Bump) -> RerankerResult<&'a [f64]> { + fn extract_vector_data<'a>( + &self, + item: &'a TraversalValue<'a>, + arena: &'a bumpalo::Bump, + ) -> RerankerResult> { match item { - TraversalValue::Vector(v) => Ok(v.data), + TraversalValue::Vector(v) => Ok(v.data(arena).to_vec()), _ => Err(RerankerError::TextExtractionError( - "Cannot extract vector from this item type (only Vector supported for MMR)".to_string(), + "Cannot extract vector from this item type (only Vector supported for MMR)" + .to_string(), )), } } @@ -134,7 +137,10 @@ impl MMRReranker { } /// Perform MMR selection on the given items. - fn mmr_select<'arena>(&self, items: Vec>) -> RerankerResult>> { + fn mmr_select<'arena>( + &self, + items: Vec>, + ) -> RerankerResult>> { // Create a temporary arena for vector conversions let arena = bumpalo::Bump::new(); if items.is_empty() { @@ -169,7 +175,7 @@ impl MMRReranker { // Calculate relevance term let relevance = if let Some(query) = &self.query_vector { - self.calculate_similarity(item_vec, query)? + self.calculate_similarity(&item_vec, query)? 
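// Relevance term of the MMR formula above: similarity to the query vector
// when one was supplied, otherwise the item's original score (else branch
// below). The diversity term is the max similarity to the already-selected
// set, computed through the cache further down.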
             } else {
                 *relevance_score // Use original score as relevance
             };
@@ -183,7 +189,7 @@ impl MMRReranker {
                     cached
                 } else {
                     let sel_vec = self.extract_vector_data(selected_item, &arena)?;
-                    let sim = self.calculate_similarity(item_vec, sel_vec)?;
+                    let sim = self.calculate_similarity(&item_vec, &sel_vec)?;
                     similarity_cache.insert(cache_key, sim);
                     sim
                 };
@@ -211,7 +217,11 @@ impl MMRReranker {
 }
 
 impl Reranker for MMRReranker {
-    fn rerank<'arena, I>(&self, items: I, _query: Option<&str>) -> RerankerResult<Vec<TraversalValue<'arena>>>
+    fn rerank<'arena, I>(
+        &self,
+        items: I,
+        _query: Option<&str>,
+    ) -> RerankerResult<Vec<TraversalValue<'arena>>>
     where
         I: Iterator<Item = TraversalValue<'arena>>,
     {
@@ -227,12 +237,12 @@ impl Reranker for MMRReranker {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::helix_engine::vector_core::vector::HVector;
+    use crate::helix_engine::vector_core::HVector;
     use bumpalo::Bump;
 
     fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> {
         let slice = arena.alloc_slice_copy(data);
-        HVector::from_slice("test_vector", 0, slice)
+        HVector::from_slice("test_vector", 0, slice, arena)
     }
 
@@ -339,9 +349,7 @@ mod tests {
     fn test_mmr_with_query_vector() {
         let arena = Bump::new();
         let query = vec![1.0, 0.0, 0.0];
-        let mmr = MMRReranker::new(0.7)
-            .unwrap()
-            .with_query_vector(query);
+        let mmr = MMRReranker::new(0.7).unwrap().with_query_vector(query);
 
         let vectors: Vec<TraversalValue> = vec![
             {
@@ -683,10 +691,10 @@ mod tests {
 
         // Verify vector data is preserved
         if let TraversalValue::Vector(v) = &results[0] {
-            assert_eq!(v.data, &[1.5, 2.5, 3.5]);
+            assert_eq!(v.data_borrowed(), &[1.5, 2.5, 3.5]);
         }
         if let TraversalValue::Vector(v) = &results[1] {
-            assert_eq!(v.data, &[4.5, 5.5, 6.5]);
+            assert_eq!(v.data_borrowed(), &[4.5, 5.5, 6.5]);
         }
     }
 
@@ -719,9 +727,7 @@ mod tests {
     fn test_mmr_with_query_vector_relevance() {
         let arena = Bump::new();
         let query = vec![1.0, 0.0];
-        let mmr = MMRReranker::new(0.9)
-            .unwrap()
-            .with_query_vector(query);
+        let mmr = MMRReranker::new(0.9).unwrap().with_query_vector(query);
 
         let vectors: Vec<TraversalValue> = vec![
             {
diff --git a/helix-db/src/helix_engine/reranker/fusion/rrf.rs b/helix-db/src/helix_engine/reranker/fusion/rrf.rs
index 419329582..b93703497 100644
--- a/helix-db/src/helix_engine/reranker/fusion/rrf.rs
+++ b/helix-db/src/helix_engine/reranker/fusion/rrf.rs
@@ -7,14 +7,12 @@
 //! Formula: RRF_score(d) = Σ 1/(k + rank_i(d))
 //! where k is typically 60 (default).
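 //!
 //! As a rough standalone sketch of that formula (illustrative only; `rrf_scores`
 //! and its signature are not part of the HelixDB API):
 //!
 //! ```
 //! use std::collections::HashMap;
 //!
 //! /// Accumulate 1/(k + rank) for every list an id appears in (0-based ranks).
 //! fn rrf_scores(lists: &[Vec<u128>], k: f64) -> HashMap<u128, f64> {
 //!     let mut scores: HashMap<u128, f64> = HashMap::new();
 //!     for list in lists {
 //!         for (rank, id) in list.iter().enumerate() {
 //!             *scores.entry(*id).or_insert(0.0) += 1.0 / (k + rank as f64);
 //!         }
 //!     }
 //!     scores
 //! }
 //!
 //! // An id ranked first in both of two lists scores 2/(60 + 0), beating any
 //! // id that appears in only one list.
 //! ```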
-use crate::{
-    helix_engine::{
-        reranker::{
-            errors::{RerankerError, RerankerResult},
-            reranker::{update_score, Reranker},
-        },
-        traversal_core::traversal_value::TraversalValue,
+use crate::helix_engine::{
+    reranker::{
+        errors::{RerankerError, RerankerResult},
+        reranker::{Reranker, update_score},
     },
+    traversal_core::traversal_value::TraversalValue,
 };
 use std::collections::HashMap;
@@ -55,7 +53,10 @@ impl RRFReranker {
     ///
     /// # Returns
     /// A vector of items reranked by RRF scores
-    pub fn fuse_lists<'arena, I>(lists: Vec<I>, k: f64) -> RerankerResult<Vec<TraversalValue<'arena>>>
+    pub fn fuse_lists<'arena, I>(
+        lists: Vec<I>,
+        k: f64,
+    ) -> RerankerResult<Vec<TraversalValue<'arena>>>
     where
         I: Iterator<Item = TraversalValue<'arena>>,
     {
@@ -112,7 +113,11 @@ impl Default for RRFReranker {
 }
 
 impl Reranker for RRFReranker {
-    fn rerank<'arena, I>(&self, items: I, _query: Option<&str>) -> RerankerResult<Vec<TraversalValue<'arena>>>
+    fn rerank<'arena, I>(
+        &self,
+        items: I,
+        _query: Option<&str>,
+    ) -> RerankerResult<Vec<TraversalValue<'arena>>>
     where
         I: Iterator<Item = TraversalValue<'arena>>,
     {
@@ -143,15 +148,12 @@ impl Reranker for RRFReranker {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{
-        helix_engine::vector_core::vector::HVector,
-        utils::items::Node,
-    };
+    use crate::{helix_engine::vector_core::HVector, utils::items::Node};
     use bumpalo::Bump;
 
     fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> {
         let slice = arena.alloc_slice_copy(data);
-        HVector::from_slice("test_vector", 0, slice)
+        HVector::from_slice("test_vector", 0, slice, arena)
     }
 
@@ -242,11 +244,8 @@ mod tests {
             },
         ];
 
-        let results = RRFReranker::fuse_lists(
-            vec![list1.into_iter(), list2.into_iter()],
-            60.0,
-        )
-        .unwrap();
+        let results =
+            RRFReranker::fuse_lists(vec![list1.into_iter(), list2.into_iter()], 60.0).unwrap();
 
         // Items 1 and 2 appear in both lists, so should have higher scores
         assert_eq!(results.len(), 4);
@@ -280,7 +279,8 @@ mod tests {
 
     #[test]
     fn test_rrf_fuse_empty_lists() {
-        let result = RRFReranker::fuse_lists(Vec::<std::vec::IntoIter<TraversalValue>>::new(), 60.0);
+        let result =
+            RRFReranker::fuse_lists(Vec::<std::vec::IntoIter<TraversalValue>>::new(), 60.0);
         assert!(result.is_err());
     }
 
@@ -400,17 +400,15 @@ mod tests {
             },
         ];
 
-        let results = RRFReranker::fuse_lists(
-            vec![list1.into_iter(), list2.into_iter()],
-            60.0,
-        )
-        .unwrap();
+        let results =
+            RRFReranker::fuse_lists(vec![list1.into_iter(), list2.into_iter()], 60.0).unwrap();
 
         // All items should be present with equal RRF scores for same ranks
         assert_eq!(results.len(), 4);
 
         // Items at rank 0 in their respective lists should have same score
-        if let (TraversalValue::Vector(v1), TraversalValue::Vector(v2)) = (&results[0], &results[1]) {
+        if let (TraversalValue::Vector(v1), TraversalValue::Vector(v2)) = (&results[0], &results[1])
+        {
             let score1 = v1.distance.unwrap();
             let score2 = v2.distance.unwrap();
             assert!((score1 - score2).abs() < 1e-10);
@@ -542,11 +540,8 @@ mod tests {
             })
             .collect();
 
-        let results = RRFReranker::fuse_lists(
-            vec![list1.into_iter(), list2.into_iter()],
-            60.0,
-        )
-        .unwrap();
+        let results =
+            RRFReranker::fuse_lists(vec![list1.into_iter(), list2.into_iter()], 60.0).unwrap();
 
         // Items 5, 6, 7 appear in both lists, should rank higher
         assert_eq!(results.len(), 10);
@@ -573,7 +568,9 @@ mod tests {
 
         // Scores should be monotonically decreasing
         for i in 0..results.len() - 1 {
-            if let (TraversalValue::Vector(v1), TraversalValue::Vector(v2)) = (&results[i], &results[i + 1]) {
+            if let (TraversalValue::Vector(v1), TraversalValue::Vector(v2)) =
+                (&results[i], &results[i + 1])
+            {
                 assert!(v1.distance.unwrap() >= v2.distance.unwrap());
             }
         }
@@ -602,10 +599,10 @@ mod tests {
 
         // Verify vector data is preserved
         if let TraversalValue::Vector(v) = &results[0] {
-            assert_eq!(v.data, &[0.0, 0.0]);
+            assert_eq!(v.data_borrowed(), &[0.0, 0.0]);
         }
         if let TraversalValue::Vector(v) = &results[1] {
-            assert_eq!(v.data, &[1.0, 2.0]);
+            assert_eq!(v.data_borrowed(), &[1.0, 2.0]);
         }
     }
 }
diff --git a/helix-db/src/helix_engine/reranker/models/cross_encoder.rs b/helix-db/src/helix_engine/reranker/models/cross_encoder.rs
index f39698a97..db3feb74b 100644
--- a/helix-db/src/helix_engine/reranker/models/cross_encoder.rs
+++ b/helix-db/src/helix_engine/reranker/models/cross_encoder.rs
@@ -88,7 +88,6 @@ impl CrossEncoderReranker {
             TraversalValue::Node(n) => n.properties,
             TraversalValue::Edge(e) => e.properties,
             TraversalValue::Vector(v) => v.properties,
-            TraversalValue::VectorNodeWithoutVectorData(v) => v.properties,
             TraversalValue::NodeWithScore { node, .. } => node.properties,
             _ => None,
         };
@@ -131,7 +130,11 @@ impl CrossEncoderReranker {
 }
 
 impl Reranker for CrossEncoderReranker {
-    fn rerank<'arena, I>(&self, items: I, query: Option<&str>) -> RerankerResult<Vec<TraversalValue<'arena>>>
+    fn rerank<'arena, I>(
+        &self,
+        items: I,
+        query: Option<&str>,
+    ) -> RerankerResult<Vec<TraversalValue<'arena>>>
     where
         I: Iterator<Item = TraversalValue<'arena>>,
     {
@@ -169,12 +172,12 @@ impl Reranker for CrossEncoderReranker {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::helix_engine::vector_core::vector::HVector;
+    use crate::helix_engine::vector_core::HVector;
    use bumpalo::Bump;
 
     fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> {
         let slice = arena.alloc_slice_copy(data);
-        HVector::from_slice("test_vector", 0, slice)
+        HVector::from_slice("test_vector", 0, slice, arena)
     }
 
     #[ignore]
@@ -213,7 +216,7 @@ mod tests {
         let result = reranker.extract_text(&item);
         assert!(result.is_err());
     }
-    
+
     #[ignore]
     #[test]
     fn test_rerank_without_query() {
@@ -221,7 +224,8 @@ mod tests {
         let config = CrossEncoderConfig::new("test-model");
         let reranker = CrossEncoderReranker::new(config);
 
-        let vectors: Vec<TraversalValue> = vec![TraversalValue::Vector(alloc_vector(&arena, &[1.0]))];
+        let vectors: Vec<TraversalValue> =
+            vec![TraversalValue::Vector(alloc_vector(&arena, &[1.0]))];
 
         let result = reranker.rerank(vectors.into_iter(), None);
         assert!(result.is_err());
diff --git a/helix-db/src/helix_engine/storage_core/graph_visualization.rs b/helix-db/src/helix_engine/storage_core/graph_visualization.rs
index 510d7f6dd..89eaad5bf 100644
--- a/helix-db/src/helix_engine/storage_core/graph_visualization.rs
+++ b/helix-db/src/helix_engine/storage_core/graph_visualization.rs
@@ -3,8 +3,8 @@ use crate::{
     helix_engine::{storage_core::HelixGraphStorage, types::GraphError},
     utils::items::Node,
 };
-use heed3::{types::*, RoIter, RoTxn};
-use sonic_rs::{json, JsonValueMutTrait, Value as JsonValue};
+use heed3::{RoIter, RoTxn, types::*};
+use sonic_rs::{JsonValueMutTrait, Value as JsonValue, json};
 use std::{
     cmp::Ordering,
     collections::{BinaryHeap, HashMap},
@@ -40,9 +40,7 @@ impl GraphVisualization for HelixGraphStorage {
         }
 
         if self.nodes_db.is_empty(txn)? || self.edges_db.is_empty(txn)?
{ - return Err(GraphError::New( - "edges or nodes db is empty!".to_string(), - )); + return Err(GraphError::New("edges or nodes db is empty!".to_string())); } let top_nodes = self.get_nodes_by_cardinality(txn, k)?; @@ -55,7 +53,7 @@ impl GraphVisualization for HelixGraphStorage { let result = json!({ "num_nodes": self.nodes_db.len(txn).unwrap_or(0), "num_edges": self.edges_db.len(txn).unwrap_or(0), - "num_vectors": self.vectors.vectors_db.len(txn).unwrap_or(0), + "num_vectors": self.vectors.stats.num_vectors, }); debug_println!("db stats json: {:?}", result); @@ -133,11 +131,7 @@ impl HelixGraphStorage { BinaryHeap::with_capacity(node_count as usize); // out edges - iterate through nodes by getting each unique node ID from out_edges_db - let out_node_key_iter = out_db - .out_edges_db - .lazily_decode_data() - .iter(txn) - .unwrap(); + let out_node_key_iter = out_db.out_edges_db.lazily_decode_data().iter(txn).unwrap(); for data in out_node_key_iter { match data { Ok((key, _)) => { @@ -260,18 +254,17 @@ impl HelixGraphStorage { if let Some(node_data) = self.nodes_db.get(txn, id)? { let node = Node::from_bincode_bytes(*id, node_data, &arena)?; if let Some(props) = node.properties - && let Some(prop_value) = props.get(prop) { - json_node - .as_object_mut() - .ok_or_else(|| { - GraphError::New("invalid JSON object".to_string()) - })? - .insert( - "label", - sonic_rs::to_value(&prop_value.inner_stringify()) - .unwrap_or_else(|_| sonic_rs::Value::from("")), - ); - } + && let Some(prop_value) = props.get(prop) + { + json_node + .as_object_mut() + .ok_or_else(|| GraphError::New("invalid JSON object".to_string()))? + .insert( + "label", + sonic_rs::to_value(&prop_value.inner_stringify()) + .unwrap_or_else(|_| sonic_rs::Value::from("")), + ); + } } } diff --git a/helix-db/src/helix_engine/storage_core/mod.rs b/helix-db/src/helix_engine/storage_core/mod.rs index 7666fd401..ba1229e22 100644 --- a/helix-db/src/helix_engine/storage_core/mod.rs +++ b/helix-db/src/helix_engine/storage_core/mod.rs @@ -1,11 +1,8 @@ pub mod graph_visualization; pub mod metadata; pub mod storage_methods; -pub mod storage_migration; pub mod version_info; -#[cfg(test)] -mod storage_migration_tests; #[cfg(test)] mod storage_concurrent_tests; @@ -18,10 +15,7 @@ use crate::{ }, traversal_core::config::Config, types::GraphError, - vector_core::{ - hnsw::HNSW, - vector_core::{HNSWConfig, VectorCore}, - }, + vector_core::{HNSWConfig, VectorCore}, }, utils::{ items::{Edge, Node}, @@ -179,7 +173,7 @@ impl HelixGraphStorage { wtxn.commit()?; - let mut storage = Self { + let storage = Self { graph_env, nodes_db, edges_db, @@ -193,8 +187,6 @@ impl HelixGraphStorage { version_info, }; - storage_migration::migrate(&mut storage)?; - Ok(storage) } diff --git a/helix-db/src/helix_engine/storage_core/storage_migration.rs b/helix-db/src/helix_engine/storage_core/storage_migration.rs deleted file mode 100644 index eb5f3da7a..000000000 --- a/helix-db/src/helix_engine/storage_core/storage_migration.rs +++ /dev/null @@ -1,473 +0,0 @@ -use crate::{ - helix_engine::{ - storage_core::HelixGraphStorage, - types::GraphError, - vector_core::{vector::HVector, vector_core}, - }, - protocol::value::Value, - utils::properties::ImmutablePropertiesMap, -}; -use bincode::Options; -use itertools::Itertools; -use std::{collections::HashMap, ops::Bound}; - -use super::metadata::{NATIVE_VECTOR_ENDIANNESS, StorageMetadata, VectorEndianness}; - -pub fn migrate(storage: &mut HelixGraphStorage) -> Result<(), GraphError> { - let mut metadata = { - let txn = 
storage.graph_env.read_txn()?; - StorageMetadata::read(&txn, &storage.metadata_db)? - }; - - loop { - metadata = match metadata { - StorageMetadata::PreMetadata => { - migrate_pre_metadata_to_native_vector_endianness(storage)? - } - StorageMetadata::VectorNativeEndianness { - vector_endianness: NATIVE_VECTOR_ENDIANNESS, - } => { - // If the vectors are in the native vector endianness, we're done migrating them - break; - } - StorageMetadata::VectorNativeEndianness { - vector_endianness: currently_stored_vector_endianness, - } => convert_vectors_to_native_endianness(currently_stored_vector_endianness, storage)?, - }; - } - - verify_vectors_and_repair(storage)?; - remove_orphaned_vector_edges(storage)?; - - Ok(()) -} - -pub(crate) fn migrate_pre_metadata_to_native_vector_endianness( - storage: &mut HelixGraphStorage, -) -> Result { - // In PreMetadata, all vectors are stored as big endian. - // If we are on a big endian machine, all we need to do is store the metadata. - // Otherwise, we need to convert all the vectors and then store the metadata. - - let metadata = StorageMetadata::VectorNativeEndianness { - vector_endianness: NATIVE_VECTOR_ENDIANNESS, - }; - - #[cfg(target_endian = "little")] - { - // On little-endian machines, we need to convert from big-endian to little-endian - convert_all_vectors(VectorEndianness::BigEndian, storage)?; - } - - convert_all_vector_properties(storage)?; - - // Save the metadata - let mut txn = storage.graph_env.write_txn()?; - metadata.save(&mut txn, &storage.metadata_db)?; - txn.commit()?; - - Ok(metadata) -} - -pub(crate) fn convert_vectors_to_native_endianness( - currently_stored_vector_endianness: VectorEndianness, - storage: &mut HelixGraphStorage, -) -> Result { - // Convert all vectors from currently_stored_vector_endianness to native endianness - convert_all_vectors(currently_stored_vector_endianness, storage)?; - - let metadata = StorageMetadata::VectorNativeEndianness { - vector_endianness: NATIVE_VECTOR_ENDIANNESS, - }; - - // Save the updated metadata - let mut txn = storage.graph_env.write_txn()?; - metadata.save(&mut txn, &storage.metadata_db)?; - txn.commit()?; - - Ok(metadata) -} - -pub(crate) fn convert_all_vectors( - source_endianness: VectorEndianness, - storage: &mut HelixGraphStorage, -) -> Result<(), GraphError> { - const BATCH_SIZE: usize = 1024; - - let key_arena = bumpalo::Bump::new(); - let batch_bounds = { - let mut keys = vec![]; - - let txn = storage.graph_env.read_txn()?; - - for (i, kv) in storage - .vectors - .vectors_db - .lazily_decode_data() - .iter(&txn)? - .enumerate() - { - let (key, _) = kv?; - - if i % BATCH_SIZE == 0 { - let key: &[u8] = key_arena.alloc_slice_copy(key); - keys.push(key); - } - } - - let mut ranges = vec![]; - for (start, end) in keys.iter().copied().tuple_windows() { - ranges.push((Bound::Included(start), Bound::Excluded(end))); - } - ranges.extend( - keys.last() - .copied() - .map(|last_batch_end| (Bound::Included(last_batch_end), Bound::Unbounded)), - ); - - ranges - }; - - for bounds in batch_bounds { - let arena = bumpalo::Bump::new(); - - let mut txn = storage.graph_env.write_txn()?; - let mut cursor = storage.vectors.vectors_db.range_mut(&mut txn, &bounds)?; - - while let Some((key, value)) = cursor.next().transpose()? { - if key == vector_core::ENTRY_POINT_KEY { - continue; - } - - let value = convert_vector_endianness(value, source_endianness, &arena)?; - - let success = unsafe { cursor.put_current(key, value)? 
}; - if !success { - return Err(GraphError::New("failed to update value in LMDB".into())); - } - } - drop(cursor); - - txn.commit()?; - } - - Ok(()) -} - -/// Converts a single vector's endianness by reading f64 values in source endianness -/// and writing them in native endianness. Uses arena for allocations. -pub(crate) fn convert_vector_endianness<'arena>( - bytes: &[u8], - source_endianness: VectorEndianness, - arena: &'arena bumpalo::Bump, -) -> Result<&'arena [u8], GraphError> { - use std::{alloc, mem, ptr, slice}; - - if bytes.is_empty() { - // We use unsafe stuff below so best not to risk allocating a layout of size zero etc - return Ok(&[]); - } - - if !bytes.len().is_multiple_of(mem::size_of::()) { - return Err(GraphError::New( - "Vector data length is not a multiple of f64 size".to_string(), - )); - } - - let num_floats = bytes.len() / mem::size_of::(); - - // Allocate space for the converted f64 array in the arena - let layout = alloc::Layout::array::(num_floats) - .map_err(|_| GraphError::New("Failed to create array layout".to_string()))?; - - let data_ptr: ptr::NonNull = arena.alloc_layout(layout); - - let converted_floats: &'arena [f64] = unsafe { - let float_ptr: ptr::NonNull = data_ptr.cast(); - let float_slice = slice::from_raw_parts_mut(float_ptr.as_ptr(), num_floats); - - // Read each f64 in the source endianness and write in native endianness - for (i, float) in float_slice.iter_mut().enumerate() { - let start = i * mem::size_of::(); - let end = start + mem::size_of::(); - let float_bytes: [u8; 8] = bytes[start..end] - .try_into() - .map_err(|_| GraphError::New("Failed to extract f64 bytes".to_string()))?; - - let value = match source_endianness { - VectorEndianness::BigEndian => f64::from_be_bytes(float_bytes), - VectorEndianness::LittleEndian => f64::from_le_bytes(float_bytes), - }; - - *float = value; - } - - slice::from_raw_parts(float_ptr.as_ptr(), num_floats) - }; - - // Convert to bytes using bytemuck - let result_bytes: &[u8] = bytemuck::cast_slice(converted_floats); - - Ok(result_bytes) -} - -pub(crate) fn convert_all_vector_properties( - storage: &mut HelixGraphStorage, -) -> Result<(), GraphError> { - const BATCH_SIZE: usize = 1024; - - let batch_bounds = { - let txn = storage.graph_env.read_txn()?; - let mut keys = vec![]; - - for (i, kv) in storage - .vectors - .vector_properties_db - .lazily_decode_data() - .iter(&txn)? - .enumerate() - { - let (key, _) = kv?; - - if i % BATCH_SIZE == 0 { - keys.push(key); - } - } - - let mut ranges = vec![]; - for (start, end) in keys.iter().copied().tuple_windows() { - ranges.push((Bound::Included(start), Bound::Excluded(end))); - } - ranges.extend( - keys.last() - .copied() - .map(|last_batch_end| (Bound::Included(last_batch_end), Bound::Unbounded)), - ); - - ranges - }; - - for bounds in batch_bounds { - let arena = bumpalo::Bump::new(); - - let mut txn = storage.graph_env.write_txn()?; - let mut cursor = storage - .vectors - .vector_properties_db - .range_mut(&mut txn, &bounds)?; - - while let Some((key, value)) = cursor.next().transpose()? { - let value = convert_old_vector_properties_to_new_format(value, &arena)?; - - let success = unsafe { cursor.put_current(&key, &value)? 
}; - if !success { - return Err(GraphError::New("failed to update value in LMDB".into())); - } - } - drop(cursor); - - txn.commit()?; - } - - Ok(()) -} - -pub(crate) fn convert_old_vector_properties_to_new_format( - property_bytes: &[u8], - arena: &bumpalo::Bump, -) -> Result, GraphError> { - let mut old_properties: HashMap = bincode::DefaultOptions::new() - .with_fixint_encoding() - .allow_trailing_bytes() - .deserialize(property_bytes)?; - - let label = old_properties - .remove("label") - .expect("all old vectors should have label"); - let is_deleted = old_properties - .remove("is_deleted") - .expect("all old vectors should have deleted"); - - let new_properties = ImmutablePropertiesMap::new( - old_properties.len(), - old_properties.iter().map(|(k, v)| (k.as_str(), v.clone())), - arena, - ); - - let new_vector = HVector { - id: 0u128, - label: &label.inner_stringify(), - version: 0, - deleted: is_deleted == true, - level: 0, - distance: None, - data: &[], - properties: Some(new_properties), - }; - - new_vector.to_bincode_bytes().map_err(GraphError::from) -} - -fn verify_vectors_and_repair(storage: &HelixGraphStorage) -> Result<(), GraphError> { - // Verify that all vectors at level > 0 also exist at level 0 and collect ones that need repair - println!("\nVerifying vector integrity after migration..."); - let vectors_to_repair: Vec<(u128, usize)> = { - let txn = storage.graph_env.read_txn()?; - let mut missing = Vec::new(); - - for kv in storage.vectors.vectors_db.iter(&txn)? { - let (key, _) = kv?; - if key.starts_with(b"v:") && key.len() >= 26 { - let id = u128::from_be_bytes(key[2..18].try_into().unwrap()); - let level = usize::from_be_bytes(key[18..26].try_into().unwrap()); - - if level > 0 { - // Check if level 0 exists - let level_0_key = vector_core::VectorCore::vector_key(id, 0); - if storage - .vectors - .vectors_db - .get(&txn, &level_0_key)? - .is_none() - { - println!( - "ERROR: Vector {} exists at level {} but NOT at level 0!", - uuid::Uuid::from_u128(id), - level - ); - missing.push((id, level)); - } - } - } - } - missing - }; - - if !vectors_to_repair.is_empty() { - println!( - "Found {} vectors at level > 0 missing their level 0 counterparts!", - vectors_to_repair.len() - ); - println!("Repairing missing level 0 vectors..."); - - const REPAIR_BATCH_SIZE: usize = 128; - - // Process repairs in batches - for batch in vectors_to_repair.chunks(REPAIR_BATCH_SIZE) { - let mut txn = storage.graph_env.write_txn()?; - - let key_arena = bumpalo::Bump::new(); - - for &(id, source_level) in batch { - // Read vector data from source level - let source_key = vector_core::VectorCore::vector_key(id, source_level); - let vector_data: &[u8] = { - let key = storage - .vectors - .vectors_db - .get(&txn, &source_key)? - .ok_or_else(|| { - GraphError::New(format!( - "Could not read vector {} at level {source_level} for repair", - uuid::Uuid::from_u128(id) - )) - })?; - key_arena.alloc_slice_copy(key) - }; - - // Write to level 0 - let level_0_key = vector_core::VectorCore::vector_key(id, 0); - storage - .vectors - .vectors_db - .put(&mut txn, &level_0_key, vector_data)?; - println!( - " Repaired: Copied vector {} from level {} to level 0", - uuid::Uuid::from_u128(id), - source_level - ); - } - - txn.commit()?; - } - - println!( - "Repair complete! 
Repaired {} vectors.", - vectors_to_repair.len() - ); - } else { - println!("All vectors verified successfully!"); - } - - Ok(()) -} - -fn remove_orphaned_vector_edges(storage: &HelixGraphStorage) -> Result<(), GraphError> { - let txn = storage.graph_env.read_txn()?; - let mut orphaned_edges = Vec::new(); - - for kv in storage.vectors.edges_db.iter(&txn)? { - let (key, _) = kv?; - - // Edge key format: [source_id (16 bytes), level (8 bytes), sink_id (16 bytes)] - // Total: 40 bytes - if key.len() != 40 { - println!( - "WARNING: Vector edge key has unexpected length: {} bytes", - key.len() - ); - continue; - } - - // Extract source_id - let source_id = u128::from_be_bytes(key[0..16].try_into().unwrap()); - - // Extract level - let level = usize::from_be_bytes(key[16..24].try_into().unwrap()); - - // Extract sink_id - let sink_id = u128::from_be_bytes(key[24..40].try_into().unwrap()); - - // Check if source vector exists at level 0 - let source_key = vector_core::VectorCore::vector_key(source_id, 0); - let source_exists = storage.vectors.vectors_db.get(&txn, &source_key)?.is_some(); - - // Check if sink vector exists at level 0 - let sink_key = vector_core::VectorCore::vector_key(sink_id, 0); - let sink_exists = storage.vectors.vectors_db.get(&txn, &sink_key)?.is_some(); - - if !source_exists || !sink_exists { - orphaned_edges.push(( - uuid::Uuid::from_u128(source_id), - level, - uuid::Uuid::from_u128(sink_id), - )); - } - } - - for chunk in orphaned_edges.into_iter().chunks(64).into_iter() { - let mut txn = storage.graph_env.write_txn()?; - - for (source_id, level, sink_id) in chunk { - let edge_key = vector_core::VectorCore::out_edges_key( - source_id.as_u128(), - level, - Some(sink_id.as_u128()), - ); - - storage - .vectors - .edges_db - .get(&txn, &edge_key)? - .ok_or_else(|| { - GraphError::New("edge key doesnt exist when removing orphan".into()) - })?; - - storage.vectors.edges_db.delete(&mut txn, &edge_key)?; - } - - txn.commit()?; - } - - Ok(()) -} diff --git a/helix-db/src/helix_engine/storage_core/storage_migration_tests.rs b/helix-db/src/helix_engine/storage_core/storage_migration_tests.rs deleted file mode 100644 index 3b7e3461f..000000000 --- a/helix-db/src/helix_engine/storage_core/storage_migration_tests.rs +++ /dev/null @@ -1,1037 +0,0 @@ -//! Comprehensive test suite for storage_migration.rs -//! -//! This test module covers: -//! - Unit tests for endianness conversion functions -//! - Unit tests for property conversion functions -//! - Integration tests for full migration scenarios -//! - Property-based tests for correctness validation -//! - Error handling tests for failure modes -//! 
- Performance tests for large datasets - -use super::{ - metadata::{StorageMetadata, VectorEndianness, NATIVE_VECTOR_ENDIANNESS}, - storage_migration::{ - convert_all_vector_properties, convert_old_vector_properties_to_new_format, - convert_vector_endianness, migrate, - }, - HelixGraphStorage, -}; -use crate::{ - helix_engine::{ - storage_core::version_info::VersionInfo, traversal_core::config::Config, - types::GraphError, - }, - protocol::value::Value, -}; -use std::collections::HashMap; -use tempfile::TempDir; - -// ============================================================================ -// Test Utilities and Fixtures -// ============================================================================ - -/// Helper function to create a test storage instance -fn setup_test_storage() -> (HelixGraphStorage, TempDir) { - let temp_dir = TempDir::new().unwrap(); - let config = Config::default(); - let version_info = VersionInfo::default(); - - let storage = - HelixGraphStorage::new(temp_dir.path().to_str().unwrap(), config, version_info).unwrap(); - - (storage, temp_dir) -} - -/// Create test vector data in a specific endianness -fn create_test_vector_bytes(values: &[f64], endianness: VectorEndianness) -> Vec { - let mut bytes = Vec::with_capacity(values.len() * 8); - for &value in values { - let value_bytes = match endianness { - VectorEndianness::BigEndian => value.to_be_bytes(), - VectorEndianness::LittleEndian => value.to_le_bytes(), - }; - bytes.extend_from_slice(&value_bytes); - } - bytes -} - -/// Read f64 values from bytes in a specific endianness -fn read_f64_values(bytes: &[u8], endianness: VectorEndianness) -> Vec { - let mut values = Vec::with_capacity(bytes.len() / 8); - for chunk in bytes.chunks_exact(8) { - let value = match endianness { - VectorEndianness::BigEndian => f64::from_be_bytes(chunk.try_into().unwrap()), - VectorEndianness::LittleEndian => f64::from_le_bytes(chunk.try_into().unwrap()), - }; - values.push(value); - } - values -} - -/// Create old-format vector properties (HashMap-based) -fn create_old_properties( - label: &str, - is_deleted: bool, - extra_props: HashMap, -) -> Vec { - let mut props = HashMap::new(); - props.insert("label".to_string(), Value::String(label.to_string())); - props.insert("is_deleted".to_string(), Value::Boolean(is_deleted)); - - for (k, v) in extra_props { - props.insert(k, v); - } - - bincode::serialize(&props).unwrap() -} - -/// Populate storage with test vectors in a specific endianness -fn populate_test_vectors( - storage: &mut HelixGraphStorage, - count: usize, - endianness: VectorEndianness, -) -> Result<(), GraphError> { - let mut txn = storage.graph_env.write_txn()?; - - for i in 0..count { - let id = i as u128; - let vector_data = create_test_vector_bytes( - &[i as f64, (i + 1) as f64, (i + 2) as f64], - endianness, - ); - - storage - .vectors - .vectors_db - .put(&mut txn, &id.to_be_bytes(), &vector_data)?; - } - - txn.commit()?; - Ok(()) -} - -/// Populate storage with old-format properties -fn populate_old_properties( - storage: &mut HelixGraphStorage, - count: usize, -) -> Result<(), GraphError> { - let mut txn = storage.graph_env.write_txn()?; - - for i in 0..count { - let id = i as u128; - let mut extra_props = HashMap::new(); - extra_props.insert("test_prop".to_string(), Value::F64(i as f64)); - - let property_bytes = - create_old_properties(&format!("label_{}", i), i % 2 == 0, extra_props); - - storage - .vectors - .vector_properties_db - .put(&mut txn, &id, &property_bytes)?; - } - - txn.commit()?; - Ok(()) -} - -/// 
Set storage metadata to a specific state -#[allow(dead_code)] -fn set_metadata( - storage: &mut HelixGraphStorage, - metadata: StorageMetadata, -) -> Result<(), GraphError> { - let mut txn = storage.graph_env.write_txn()?; - metadata.save(&mut txn, &storage.metadata_db)?; - txn.commit()?; - Ok(()) -} - -/// Read all vectors from storage and return as f64 values -fn read_all_vectors( - storage: &HelixGraphStorage, - endianness: VectorEndianness, -) -> Result>, GraphError> { - let txn = storage.graph_env.read_txn()?; - let mut all_vectors = Vec::new(); - - for kv in storage.vectors.vectors_db.iter(&txn)? { - let (_, value) = kv?; - let values = read_f64_values(value, endianness); - all_vectors.push(values); - } - - Ok(all_vectors) -} - -/// Clear all metadata from storage (simulates PreMetadata state) -fn clear_metadata(storage: &mut HelixGraphStorage) -> Result<(), GraphError> { - let mut txn = storage.graph_env.write_txn()?; - storage.metadata_db.clear(&mut txn)?; - txn.commit()?; - Ok(()) -} - -// ============================================================================ -// Unit Tests: Endianness Conversion -// ============================================================================ - -#[test] -fn test_convert_vector_endianness_empty_input() { - let arena = bumpalo::Bump::new(); - let result = convert_vector_endianness(&[], VectorEndianness::BigEndian, &arena); - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), &[] as &[u8]); -} - -#[test] -fn test_convert_vector_endianness_single_f64() { - let arena = bumpalo::Bump::new(); - let value: f64 = 3.14159; - let big_endian_bytes = value.to_be_bytes(); - - let result = - convert_vector_endianness(&big_endian_bytes, VectorEndianness::BigEndian, &arena).unwrap(); - - // Result should be in native endianness - let native_value = f64::from_ne_bytes(result.try_into().unwrap()); - assert_eq!(native_value, value); -} - -#[test] -fn test_convert_vector_endianness_multiple_f64s() { - let arena = bumpalo::Bump::new(); - let values = vec![1.0, 2.5, -3.7, 4.2, 5.9]; - let big_endian_bytes = create_test_vector_bytes(&values, VectorEndianness::BigEndian); - - let result = - convert_vector_endianness(&big_endian_bytes, VectorEndianness::BigEndian, &arena).unwrap(); - - // Read back values in native endianness - let result_values: Vec = result - .chunks_exact(8) - .map(|chunk| f64::from_ne_bytes(chunk.try_into().unwrap())) - .collect(); - - for (original, converted) in values.iter().zip(result_values.iter()) { - assert_eq!(original, converted); - } -} - -#[test] -fn test_convert_vector_endianness_invalid_length() { - let arena = bumpalo::Bump::new(); - let invalid_bytes = vec![1, 2, 3, 4, 5]; // Not a multiple of 8 - - let result = convert_vector_endianness(&invalid_bytes, VectorEndianness::BigEndian, &arena); - - assert!(result.is_err()); - let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("not a multiple")); -} - -#[test] -fn test_convert_vector_endianness_roundtrip() { - let arena = bumpalo::Bump::new(); - let values = vec![1.0, 2.5, -3.7, 100.123, -999.999]; - - // Start with big endian - let big_endian_bytes = create_test_vector_bytes(&values, VectorEndianness::BigEndian); - - // Convert big -> native - let native_bytes = - convert_vector_endianness(&big_endian_bytes, VectorEndianness::BigEndian, &arena).unwrap(); - - // Read values back - let result_values: Vec = native_bytes - .chunks_exact(8) - .map(|chunk| f64::from_ne_bytes(chunk.try_into().unwrap())) - .collect(); - - for (original, converted) in 
values.iter().zip(result_values.iter()) { - assert_eq!(original, converted); - } -} - -#[test] -fn test_convert_vector_endianness_special_values() { - let arena = bumpalo::Bump::new(); - let special_values = vec![ - 0.0, - -0.0, - f64::INFINITY, - f64::NEG_INFINITY, - f64::MIN, - f64::MAX, - f64::EPSILON, - ]; - - let big_endian_bytes = create_test_vector_bytes(&special_values, VectorEndianness::BigEndian); - - let result = - convert_vector_endianness(&big_endian_bytes, VectorEndianness::BigEndian, &arena).unwrap(); - - let result_values: Vec = result - .chunks_exact(8) - .map(|chunk| f64::from_ne_bytes(chunk.try_into().unwrap())) - .collect(); - - for (original, converted) in special_values.iter().zip(result_values.iter()) { - // Use bit equality for special values like NaN and -0.0 - assert_eq!(original.to_bits(), converted.to_bits()); - } -} - -#[test] -fn test_convert_vector_endianness_from_little_endian() { - let arena = bumpalo::Bump::new(); - let values = vec![1.1, 2.2, 3.3]; - let little_endian_bytes = create_test_vector_bytes(&values, VectorEndianness::LittleEndian); - - let result = convert_vector_endianness( - &little_endian_bytes, - VectorEndianness::LittleEndian, - &arena, - ) - .unwrap(); - - let result_values: Vec = result - .chunks_exact(8) - .map(|chunk| f64::from_ne_bytes(chunk.try_into().unwrap())) - .collect(); - - for (original, converted) in values.iter().zip(result_values.iter()) { - assert_eq!(original, converted); - } -} - -// ============================================================================ -// Unit Tests: Property Conversion -// ============================================================================ - -#[test] -fn test_convert_old_properties_basic() { - let arena = bumpalo::Bump::new(); - let old_bytes = create_old_properties("test_label", false, HashMap::new()); - - let result = convert_old_vector_properties_to_new_format(&old_bytes, &arena); - assert!(result.is_ok()); - - // We can't directly deserialize HVector, but we can verify the conversion succeeded - let new_bytes = result.unwrap(); - assert!(!new_bytes.is_empty()); -} - -#[test] -fn test_convert_old_properties_with_deleted_flag() { - let arena = bumpalo::Bump::new(); - let old_bytes = create_old_properties("deleted_vector", true, HashMap::new()); - - let result = convert_old_vector_properties_to_new_format(&old_bytes, &arena); - assert!(result.is_ok()); - assert!(!result.unwrap().is_empty()); -} - -#[test] -fn test_convert_old_properties_with_extra_props() { - let arena = bumpalo::Bump::new(); - let mut extra = HashMap::new(); - extra.insert("name".to_string(), Value::String("test".to_string())); - extra.insert("count".to_string(), Value::F64(42.0)); - extra.insert("active".to_string(), Value::Boolean(true)); - - let old_bytes = create_old_properties("test_label", false, extra); - - let result = convert_old_vector_properties_to_new_format(&old_bytes, &arena); - assert!(result.is_ok()); - assert!(!result.unwrap().is_empty()); -} - -#[test] -fn test_convert_old_properties_empty_extra_props() { - let arena = bumpalo::Bump::new(); - let old_bytes = create_old_properties("minimal", false, HashMap::new()); - - let result = convert_old_vector_properties_to_new_format(&old_bytes, &arena); - assert!(result.is_ok()); - assert!(!result.unwrap().is_empty()); -} - -#[test] -#[should_panic(expected = "all old vectors should have label")] -fn test_convert_old_properties_missing_label() { - let arena = bumpalo::Bump::new(); - let mut props = HashMap::new(); - props.insert("is_deleted".to_string(), 
Value::Boolean(false)); - // Missing "label" - - let bytes = bincode::serialize(&props).unwrap(); - let _ = convert_old_vector_properties_to_new_format(&bytes, &arena); -} - -#[test] -#[should_panic(expected = "all old vectors should have deleted")] -fn test_convert_old_properties_missing_is_deleted() { - let arena = bumpalo::Bump::new(); - let mut props = HashMap::new(); - props.insert("label".to_string(), Value::String("test".to_string())); - // Missing "is_deleted" - - let bytes = bincode::serialize(&props).unwrap(); - let _ = convert_old_vector_properties_to_new_format(&bytes, &arena); -} - -#[test] -fn test_convert_old_properties_invalid_bincode() { - let arena = bumpalo::Bump::new(); - let invalid_bytes = vec![1, 2, 3, 4, 5]; // Not valid bincode - - let result = convert_old_vector_properties_to_new_format(&invalid_bytes, &arena); - assert!(result.is_err()); -} - -// ============================================================================ -// Integration Tests: Full Migration Scenarios -// ============================================================================ - -#[test] -fn test_migrate_empty_database() { - let (storage, _temp_dir) = setup_test_storage(); - - // Storage is already created with migrations run, but let's verify the state - let txn = storage.graph_env.read_txn().unwrap(); - let metadata = StorageMetadata::read(&txn, &storage.metadata_db).unwrap(); - - assert!(matches!( - metadata, - StorageMetadata::VectorNativeEndianness { .. } - )); -} - -#[test] -fn test_migrate_pre_metadata_to_native() { - let (mut storage, _temp_dir) = setup_test_storage(); - - // Clear metadata to simulate PreMetadata state - clear_metadata(&mut storage).unwrap(); - - // Populate with vectors in big-endian format (PreMetadata default) - populate_test_vectors(&mut storage, 10, VectorEndianness::BigEndian).unwrap(); - populate_old_properties(&mut storage, 10).unwrap(); - - // Run migration - let result = migrate(&mut storage); - assert!(result.is_ok()); - - // Verify metadata was updated - { - let txn = storage.graph_env.read_txn().unwrap(); - let metadata = StorageMetadata::read(&txn, &storage.metadata_db).unwrap(); - - match metadata { - StorageMetadata::VectorNativeEndianness { vector_endianness } => { - assert_eq!(vector_endianness, NATIVE_VECTOR_ENDIANNESS); - } - _ => panic!("Expected VectorNativeEndianness metadata"), - } - } // txn dropped here - - // Verify vectors are readable in native endianness - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 10); - - for (i, vector) in vectors.iter().enumerate() { - let expected = vec![i as f64, (i + 1) as f64, (i + 2) as f64]; - assert_eq!(vector, &expected); - } -} - -#[test] -fn test_migrate_single_vector() { - let (mut storage, _temp_dir) = setup_test_storage(); - - // Clear and repopulate - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 1, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 1); - assert_eq!(vectors[0], vec![0.0, 1.0, 2.0]); -} - -#[test] -fn test_migrate_exact_batch_size() { - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 1024, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, 
NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 1024); - - // Verify first and last vectors - assert_eq!(vectors[0], vec![0.0, 1.0, 2.0]); - assert_eq!(vectors[1023], vec![1023.0, 1024.0, 1025.0]); -} - -#[test] -fn test_migrate_multiple_batches() { - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 2500, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 2500); - - // Verify vectors across batch boundaries - assert_eq!(vectors[0], vec![0.0, 1.0, 2.0]); - assert_eq!(vectors[1023], vec![1023.0, 1024.0, 1025.0]); - assert_eq!(vectors[1024], vec![1024.0, 1025.0, 1026.0]); - assert_eq!(vectors[2499], vec![2499.0, 2500.0, 2501.0]); -} - -#[test] -fn test_migrate_already_native_endianness() { - let (mut storage, _temp_dir) = setup_test_storage(); - - // Add vectors already in native endianness - populate_test_vectors(&mut storage, 10, NATIVE_VECTOR_ENDIANNESS).unwrap(); - - // Migration should be a no-op (already done during setup_test_storage) - let result = migrate(&mut storage); - assert!(result.is_ok()); - - // Vectors should remain unchanged - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 10); -} - -#[test] -fn test_migrate_idempotency() { - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 100, VectorEndianness::BigEndian).unwrap(); - - // Run migration multiple times - migrate(&mut storage).unwrap(); - let vectors_after_first = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - - migrate(&mut storage).unwrap(); - let vectors_after_second = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - - migrate(&mut storage).unwrap(); - let vectors_after_third = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - - // All should be identical - assert_eq!(vectors_after_first, vectors_after_second); - assert_eq!(vectors_after_second, vectors_after_third); -} - -#[test] -fn test_migrate_with_properties() { - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 50, VectorEndianness::BigEndian).unwrap(); - populate_old_properties(&mut storage, 50).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - // Verify both vectors and properties were migrated - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 50); - - // Check properties count - let txn = storage.graph_env.read_txn().unwrap(); - let prop_count = storage.vectors.vector_properties_db.len(&txn).unwrap(); - assert_eq!(prop_count, 50); -} - -#[test] -fn test_migrate_cognee_vector_string_dates_error() { - // This test reproduces a bincode I/O error that occurs when migrating - // CogneeVector data where dates were stored as RFC3339 strings instead - // of proper Date types. - // - // Old schema had: - // created_at: String (RFC3339 format via chrono::Utc::now().to_rfc3339()) - // updated_at: String (RFC3339 format) - // - // New schema expects: - // created_at: Date - // updated_at: Date - // - // This mismatch can cause bincode deserialization errors during migration. 
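-    // An illustrative sketch of that mismatch (the Value/Date calls mirror the
-    // test_date_bincode_serialization test below; the literals are made up):
-    //
-    //     use crate::protocol::{date::Date, value::Value};
-    //
-    //     // Old format: the date was written as a plain RFC3339 string.
-    //     let old = Value::String("2024-01-01T12:00:00.000000000Z".to_string());
-    //
-    //     // New schema: a typed date, e.g. built from a unix timestamp.
-    //     let new = Value::Date(Date::new(&Value::I64(1704110400)).unwrap());
-    //
-    //     // Bytes serialized as the former but deserialized while expecting
-    //     // the latter surface as a bincode ConversionError.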
- - let (mut storage, _temp_dir) = setup_test_storage(); - - // Clear metadata to simulate PreMetadata state - clear_metadata(&mut storage).unwrap(); - - // Create old-format CogneeVector properties with dates as strings - // (matching how they were actually created in the old format) - let mut extra_props = HashMap::new(); - - // Add CogneeVector-specific fields - extra_props.insert( - "collection_name".to_string(), - Value::String("test_collection".to_string()), - ); - extra_props.insert( - "data_point_id".to_string(), - Value::String("dp_001".to_string()), - ); - extra_props.insert( - "payload".to_string(), - Value::String(r#"{"id":"123","created_at":"2024-01-01","updated_at":"2024-01-01","ontology_valid":true,"version":1,"topological_rank":0,"type":"DataPoint"}"#.to_string()), - ); - extra_props.insert( - "content".to_string(), - Value::String("Test content for CogneeVector".to_string()), - ); - - // Add dates as strings (RFC3339) - this is the problematic part - // In the old format, these were created as: - // Value::from(chrono::Utc::now().to_rfc3339()) - // which creates Value::String, not Value::Date - extra_props.insert( - "created_at".to_string(), - Value::String("2024-01-01T12:00:00.000000000Z".to_string()), - ); - extra_props.insert( - "updated_at".to_string(), - Value::String("2024-01-01T12:00:00.000000000Z".to_string()), - ); - - // Create old properties with CogneeVector label - let old_bytes = create_old_properties("CogneeVector", false, extra_props); - - // Insert into storage - { - let mut txn = storage.graph_env.write_txn().unwrap(); - let id = 123u128; - storage - .vectors - .vector_properties_db - .put(&mut txn, &id, &old_bytes) - .unwrap(); - txn.commit().unwrap(); - } - - // Verify the data was inserted - { - let txn = storage.graph_env.read_txn().unwrap(); - let stored_bytes = storage - .vectors - .vector_properties_db - .get(&txn, &123u128) - .unwrap(); - assert!(stored_bytes.is_some()); - - // Verify we can deserialize it as old format - let old_props: HashMap = bincode::deserialize(stored_bytes.unwrap()).unwrap(); - assert_eq!(old_props.get("label").unwrap(), &Value::String("CogneeVector".to_string())); - assert_eq!(old_props.get("collection_name").unwrap(), &Value::String("test_collection".to_string())); - - // Verify dates are strings, not Date types - match old_props.get("created_at").unwrap() { - Value::String(s) => assert!(s.contains("2024-01-01")), - _ => panic!("Expected created_at to be Value::String in old format"), - } - } - - // Run migration - this preserves the data as-is - let result = migrate(&mut storage); - - // Migration succeeds because it just copies the HashMap to the new format - match result { - Ok(_) => { - println!("✅ Migration succeeded (preserves old data as-is)"); - - // The real error occurs when trying to deserialize the migrated data - // This simulates what v_from_type does when querying by label - let txn = storage.graph_env.read_txn().unwrap(); - let migrated_bytes = storage - .vectors - .vector_properties_db - .get(&txn, &123u128) - .unwrap() - .unwrap(); - - println!("Migrated data exists: {} bytes", migrated_bytes.len()); - - // Try to deserialize as VectorWithoutData (what v_from_type does) - use crate::helix_engine::vector_core::vector_without_data::VectorWithoutData; - let arena2 = bumpalo::Bump::new(); - let deserialize_result = VectorWithoutData::from_bincode_bytes(&arena2, migrated_bytes, 123u128); - - match deserialize_result { - Ok(vector) => { - println!("⚠️ Deserialization succeeded!"); - println!("Vector label: 
{}", vector.label); - println!("This means bincode preserved the string dates in properties."); - - // Check if dates are accessible - if let Some(created_at) = vector.get_property("created_at") { - println!("created_at type: {:?}", created_at); - match created_at { - Value::String(s) => println!(" Still a string: {}", s), - Value::Date(d) => println!(" Converted to Date: {:?}", d), - _ => println!(" Other type: {:?}", created_at), - } - } - } - Err(e) => { - println!("✅ REPRODUCED THE ERROR during deserialization!"); - println!("Error: {}", e); - println!(); - println!("This error occurs in the v_from_type query path:"); - println!(" 1. Migration preserves dates as Value::String"); - println!(" 2. v_from_type calls VectorWithoutData::from_bincode_bytes"); - println!(" 3. Bincode deserialization expects specific value types"); - println!(" 4. Type mismatch causes ConversionError"); - - // Verify it's the expected error type - let error_str = e.to_string(); - assert!( - error_str.contains("deserializ") || error_str.contains("Conversion"), - "Expected deserialization/conversion error, got: {}", - e - ); - } - } - } - Err(e) => { - println!("❌ Migration failed unexpectedly: {}", e); - panic!("Migration should succeed but preserve old data"); - } - } -} - -// ============================================================================ -// Integration Tests: Batch Boundary Conditions -// ============================================================================ - -#[test] -fn test_migrate_batch_boundary_1023() { - let (mut storage, _temp_dir) = setup_test_storage(); - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 1023, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 1023); -} - -#[test] -fn test_migrate_batch_boundary_1025() { - let (mut storage, _temp_dir) = setup_test_storage(); - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 1025, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 1025); -} - -#[test] -fn test_migrate_batch_boundary_2047() { - let (mut storage, _temp_dir) = setup_test_storage(); - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 2047, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 2047); -} - -#[test] -fn test_migrate_batch_boundary_2048() { - let (mut storage, _temp_dir) = setup_test_storage(); - clear_metadata(&mut storage).unwrap(); - populate_test_vectors(&mut storage, 2048, VectorEndianness::BigEndian).unwrap(); - - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 2048); -} - -// ============================================================================ -// Property-Based Tests -// ============================================================================ - -use proptest::prelude::*; - -proptest! 
{ - #[test] - fn proptest_endianness_conversion_preserves_values( - values in prop::collection::vec(prop::num::f64::ANY, 1..100) - ) { - let arena = bumpalo::Bump::new(); - - // Filter out NaN for equality comparison - let values: Vec = values.into_iter().filter(|v| !v.is_nan()).collect(); - if values.is_empty() { - return Ok(()); - } - - // Test both endianness conversions - for source_endianness in [VectorEndianness::BigEndian, VectorEndianness::LittleEndian] { - let source_bytes = create_test_vector_bytes(&values, source_endianness); - - let result = convert_vector_endianness(&source_bytes, source_endianness, &arena) - .expect("conversion should succeed"); - - let result_values: Vec = result - .chunks_exact(8) - .map(|chunk| f64::from_ne_bytes(chunk.try_into().unwrap())) - .collect(); - - prop_assert_eq!(values.len(), result_values.len()); - - for (original, converted) in values.iter().zip(result_values.iter()) { - prop_assert_eq!(original, converted); - } - } - } - - #[test] - fn proptest_endianness_conversion_valid_length( - byte_count in 1usize..200 - ) { - let arena = bumpalo::Bump::new(); - let bytes = vec![0u8; byte_count]; - - let result = convert_vector_endianness(&bytes, VectorEndianness::BigEndian, &arena); - - if byte_count % 8 == 0 { - prop_assert!(result.is_ok()); - } else { - prop_assert!(result.is_err()); - } - } - - #[test] - fn proptest_property_migration_preserves_data( - label in "[a-z]{1,20}", - is_deleted in any::(), - prop_count in 0usize..10 - ) { - let arena = bumpalo::Bump::new(); - let mut extra_props = HashMap::new(); - - for i in 0..prop_count { - extra_props.insert( - format!("prop_{}", i), - Value::F64(i as f64), - ); - } - - let old_bytes = create_old_properties(&label, is_deleted, extra_props); - let result = convert_old_vector_properties_to_new_format(&old_bytes, &arena) - .expect("property conversion should succeed"); - - // Verify conversion succeeded by checking result is not empty - prop_assert!(!result.is_empty()); - } -} - -// ============================================================================ -// Error Handling Tests -// ============================================================================ - -#[test] -fn test_error_invalid_vector_data_length() { - let arena = bumpalo::Bump::new(); - let invalid_bytes = vec![1, 2, 3, 4, 5, 6, 7]; // 7 bytes, not multiple of 8 - - let result = convert_vector_endianness(&invalid_bytes, VectorEndianness::BigEndian, &arena); - - assert!(result.is_err()); - match result { - Err(GraphError::New(msg)) => { - assert!(msg.contains("not a multiple")); - } - _ => panic!("Expected GraphError::New with length error"), - } -} - -#[test] -fn test_error_corrupted_property_data() { - let arena = bumpalo::Bump::new(); - let corrupted = vec![255u8; 100]; // Random bytes, not valid bincode - - let result = convert_old_vector_properties_to_new_format(&corrupted, &arena); - assert!(result.is_err()); -} - -#[test] -#[ignore] -fn test_date_bincode_serialization() { - // Test that Date values serialize/deserialize correctly with bincode - use crate::protocol::date::Date; - - // Create a Date and wrap it in Value::Date - let date = Date::new(&Value::I64(1609459200)).unwrap(); // 2021-01-01 - let value = Value::Date(date); - - // Serialize with bincode - let serialized = bincode::serialize(&value).unwrap(); - println!("\nValue::Date serialized to {} bytes", serialized.len()); - println!("Format: [variant=12] [i64 timestamp]"); - println!("Bytes: {:?}", serialized); - - // Deserialize - let deserialized: Value = 
bincode::deserialize(&serialized).unwrap(); - - // Verify it's a Date variant with correct value - match deserialized { - Value::Date(d) => { - assert_eq!(d.timestamp(), 1609459200); - assert!(d.to_rfc3339().starts_with("2021-01-01")); - println!("✅ Bincode serialization works correctly!"); - println!(" Date: {}", d.to_rfc3339()); - } - _ => panic!("Expected Value::Date variant"), - } - - // Also test JSON serialization still works - let json = sonic_rs::to_string(&value).unwrap(); - let from_json: Value = sonic_rs::from_str(&json).unwrap(); - // JSON deserializes dates as strings, which is expected - assert!(matches!(from_json, Value::String(_))); - println!("✅ JSON serialization also works (deserializes as Value::String as expected)!"); -} - -#[test] -fn test_error_handling_graceful_failure() { - // Test that errors don't corrupt the database - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - - // Add valid data - populate_test_vectors(&mut storage, 10, VectorEndianness::BigEndian).unwrap(); - - // Now add invalid data manually - { - let mut txn = storage.graph_env.write_txn().unwrap(); - let bad_id = 9999u128; - let bad_data = vec![1, 2, 3]; // Invalid length - - storage - .vectors - .vectors_db - .put(&mut txn, &bad_id.to_be_bytes(), &bad_data) - .unwrap(); - - txn.commit().unwrap(); - } - - // Migration should fail on invalid data - let result = migrate(&mut storage); - assert!(result.is_err()); - - // But the 10 valid vectors should still be there - let txn = storage.graph_env.read_txn().unwrap(); - let count = storage.vectors.vectors_db.len(&txn).unwrap(); - assert_eq!(count, 11); // 10 valid + 1 invalid -} - -// ============================================================================ -// Performance Tests -// ============================================================================ - -#[test] -#[ignore] // Run with: cargo test --release -- --ignored --nocapture -fn test_performance_large_dataset() { - use std::time::Instant; - - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - - // Create 100K vectors - println!("Populating 100K vectors..."); - let start = Instant::now(); - populate_test_vectors(&mut storage, 100_000, VectorEndianness::BigEndian).unwrap(); - println!("Population took: {:?}", start.elapsed()); - - // Migrate - println!("Running migration..."); - let start = Instant::now(); - let result = migrate(&mut storage); - let duration = start.elapsed(); - - assert!(result.is_ok()); - println!("Migration of 100K vectors took: {:?}", duration); - println!("Average: {:?} per vector", duration / 100_000); - - // Verify a sample - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 100_000); - assert_eq!(vectors[0], vec![0.0, 1.0, 2.0]); - assert_eq!(vectors[50_000], vec![50_000.0, 50_001.0, 50_002.0]); - assert_eq!(vectors[99_999], vec![99_999.0, 100_000.0, 100_001.0]); -} - -#[test] -#[ignore] -fn test_performance_property_migration() { - use std::time::Instant; - - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - - println!("Populating 50K properties..."); - populate_old_properties(&mut storage, 50_000).unwrap(); - - println!("Running property migration..."); - let start = Instant::now(); - let result = convert_all_vector_properties(&mut storage); - let duration = start.elapsed(); - - assert!(result.is_ok()); - println!("Property migration of 50K items took: {:?}", 
duration); - println!("Average: {:?} per property", duration / 50_000); -} - -#[test] -fn test_memory_efficiency_batch_processing() { - // This test verifies that batch processing doesn't cause memory issues - let (mut storage, _temp_dir) = setup_test_storage(); - - clear_metadata(&mut storage).unwrap(); - - // Create 5000 vectors (multiple batches) - populate_test_vectors(&mut storage, 5000, VectorEndianness::BigEndian).unwrap(); - - // Migration should complete without OOM - let result = migrate(&mut storage); - assert!(result.is_ok()); - - let vectors = read_all_vectors(&storage, NATIVE_VECTOR_ENDIANNESS).unwrap(); - assert_eq!(vectors.len(), 5000); -}
diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs index d2091cb64..73830d0aa 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs @@ -12,7 +12,6 @@ /// - Multiple inserts at same level could create invalid graph topology /// - Delete during search might return inconsistent results /// - LMDB transaction model provides MVCC but needs validation - use bumpalo::Bump; use heed3::{Env, EnvOpenOptions, RoTxn, RwTxn}; use rand::Rng; @@ -20,11 +19,7 @@ use std::sync::{Arc, Barrier}; use std::thread; use tempfile::TempDir; -use crate::helix_engine::vector_core::{ - hnsw::HNSW, - vector::HVector, - vector_core::{HNSWConfig, VectorCore}, -}; +use crate::helix_engine::vector_core::{HNSWConfig, HVector, VectorCore}; type Filter = fn(&HVector, &RoTxn) -> bool; @@ -50,12 +45,17 @@ fn setup_concurrent_env() -> (TempDir, Env) { /// Generate a random vector of given dimensionality fn random_vector(dim: usize) -> Vec<f64> { - (0..dim).map(|_| rand::rng().random_range(0.0..1.0)).collect() + (0..dim) + .map(|_| rand::rng().random_range(0.0..1.0)) + .collect() } /// Open existing VectorCore databases (for concurrent access) /// Note: create_database opens existing database if it exists -fn open_vector_core(env: &Env, txn: &mut RwTxn) -> Result<VectorCore, VectorError> { +fn open_vector_core( + env: &Env, + txn: &mut RwTxn, +) -> Result<VectorCore, VectorError> { VectorCore::new(env, txn, HNSWConfig::new(None, None, None)) } @@ -100,7 +100,8 @@ fn test_concurrent_inserts_single_label() { // Open the existing databases and insert let index = open_vector_core(&env, &mut wtxn).unwrap(); - index.insert::<Filter>(&mut wtxn, "concurrent_test", data, None, &arena) + index + .insert(&mut wtxn, "concurrent_test", data, None, &arena) .expect("Insert should succeed"); wtxn.commit().expect("Commit should succeed"); } @@ -118,10 +119,10 @@ fn test_concurrent_inserts_single_label() { let index = open_vector_core(&env, &mut wtxn).unwrap(); wtxn.commit().unwrap(); let rtxn = env.read_txn().unwrap(); - let count = index.num_inserted_vectors(&rtxn).unwrap(); + let count = index.num_inserted_vectors(); // Note: count includes entry point (+1), so actual vectors inserted = count - 1 - let expected_inserted = (num_threads * vectors_per_thread) as u64; + let expected_inserted = num_threads * vectors_per_thread; assert!( count == expected_inserted || count == expected_inserted + 1, "Expected {} or {} vectors (with entry point), found {}", @@ -133,7 +134,7 @@ fn test_concurrent_inserts_single_label() { // Additional consistency check: Verify we can perform searches (entry point exists implicitly) let arena = Bump::new(); let query = [0.5; 128]; - let search_result = index.search::<Filter>(&rtxn, &query, 10, "concurrent_test", None, false, &arena); + let search_result = index.search(&rtxn, &query, 10, "concurrent_test", false, &arena); assert!( search_result.is_ok(), "Should be able to search after concurrent inserts (entry point exists)" @@ -161,7 +162,9 @@ fn test_concurrent_searches_during_inserts() { for _ in 0..50 { let vector = random_vector(128); let data = arena.alloc_slice_copy(&vector); - index.insert::<Filter>(&mut txn, "search_test", data, None, &arena).unwrap(); + index + .insert(&mut txn, "search_test", data, None, &arena) + .unwrap(); } txn.commit().unwrap(); } @@ -188,22 +191,14 @@ fn test_concurrent_searches_during_inserts() { // Perform many searches // Open databases once per thread let mut wtxn_init = env.write_txn().unwrap(); - let index = open_vector_core(&env, &mut wtxn_init).unwrap(); + let index: VectorCore = open_vector_core(&env, &mut wtxn_init).unwrap(); wtxn_init.commit().unwrap(); for _ in 0..50 { let rtxn = env.read_txn().unwrap(); let arena = Bump::new(); - match index.search::<Filter>( - &rtxn, - &query[..], - 10, - "search_test", - None, - false, - &arena, - ) { + match index.search(&rtxn, &query[..], 10, "search_test", false, &arena) { Ok(results) => { total_searches += 1; total_results += results.len(); @@ -251,7 +246,8 @@ fn test_concurrent_searches_during_inserts() { let data = arena.alloc_slice_copy(&vector); let index = open_vector_core(&env, &mut wtxn).unwrap(); - index.insert::<Filter>(&mut wtxn, "search_test", data, None, &arena) + index + .insert(&mut wtxn, "search_test", data, None, &arena) .expect("Insert should succeed"); wtxn.commit().expect("Commit should succeed"); @@ -270,7 +266,7 @@ fn test_concurrent_searches_during_inserts() { let index = open_vector_core(&env, &mut wtxn).unwrap(); wtxn.commit().unwrap(); let rtxn = env.read_txn().unwrap(); - let final_count = index.num_inserted_vectors(&rtxn).unwrap(); + let final_count = index.num_inserted_vectors(); assert!( final_count >= 50, @@ -281,9 +277,12 @@ // Verify we can still search successfully let arena = Bump::new(); let results = index - .search::<Filter>(&rtxn, &query[..], 10, "search_test", None, false, &arena) + .search(&rtxn, &query[..], 10, "search_test", false, &arena) .unwrap(); - assert!(!results.is_empty(), "Should find results after concurrent operations"); + assert!( + !results.is_empty(), + "Should find results after concurrent operations" + ); } #[test] @@ -324,9 +323,7 @@ fn test_concurrent_inserts_multiple_labels() { let vector = random_vector(64); let data = arena.alloc_slice_copy(&vector); - index - .insert::<Filter>(&mut wtxn, &label, data, None, &arena) - .unwrap(); + index.insert(&mut wtxn, &label, data, None, &arena).unwrap(); wtxn.commit().unwrap(); if i % 10 == 0 { @@ -353,7 +350,7 @@ // Verify we can search for each label (entry point exists implicitly) let query = [0.5; 64]; - let search_result = index.search::<Filter>(&rtxn, &query, 5, &label, None, false, &arena); + let search_result = index.search(&rtxn, &query, 5, &label, false, &arena); assert!( search_result.is_ok(), "Should be able to search label {}", @@ -361,8 +358,8 @@ ); } - let total_count = index.num_inserted_vectors(&rtxn).unwrap(); - let expected_total = (num_labels * vectors_per_label) as u64; + let total_count = index.num_inserted_vectors(); + let expected_total = num_labels * vectors_per_label; assert!( total_count == expected_total || total_count == expected_total + 1, "Expected {} or {} vectors (with entry point), found {}", @@ -412,7 +409,7 @@ fn test_entry_point_consistency() { let data = arena.alloc_slice_copy(&vector); index - .insert::<Filter>(&mut wtxn, "entry_test", data, None, &arena) + .insert(&mut wtxn, "entry_test", data, None, &arena) .unwrap(); wtxn.commit().unwrap(); } @@ -433,17 +430,26 @@ // If we can successfully search, entry point must be valid let query = [0.5; 32]; - let search_result = index.search::<Filter>(&rtxn, &query, 10, "entry_test", None, false, &arena); - assert!(search_result.is_ok(), "Entry point should exist and be valid"); + let search_result = index.search(&rtxn, &query, 10, "entry_test", false, &arena); + assert!( + search_result.is_ok(), + "Entry point should exist and be valid" + ); let results = search_result.unwrap(); - assert!(!results.is_empty(), "Should return results if entry point is valid"); + assert!( + !results.is_empty(), + "Should return results if entry point is valid" + ); // Verify results have valid properties for result in results.iter() { assert!(result.id > 0, "Result ID should be valid"); assert!(!result.deleted, "Results should not be deleted"); - assert!(!result.data.is_empty(), "Results should have data"); + assert!( + !result.data_borrowed().is_empty(), + "Results should have data" + ); } } @@ -484,7 +490,7 @@ fn test_graph_connectivity_after_concurrent_inserts() { let data = arena.alloc_slice_copy(&vector); index - .insert::<Filter>(&mut wtxn, "connectivity_test", data, None, &arena) + .insert(&mut wtxn, "connectivity_test", data, None, &arena) .unwrap(); wtxn.commit().unwrap(); } @@ -507,15 +513,7 @@ for i in 0..10 { let query = random_vector(64); let results = index - .search::<Filter>( - &rtxn, - &query, - 10, - "connectivity_test", - None, - false, - &arena, - ) + .search(&rtxn, &query, 10, "connectivity_test", false, &arena) .unwrap(); assert!( @@ -553,7 +551,9 @@ fn test_transaction_isolation() { for _ in 0..initial_count { let vector = random_vector(32); let data = arena.alloc_slice_copy(&vector); - index.insert::<Filter>(&mut txn, "isolation_test", data, None, &arena).unwrap(); + index + .insert(&mut txn, "isolation_test", data, None, &arena) + .unwrap(); } txn.commit().unwrap(); } @@ -564,7 +564,7 @@ wtxn_open.commit().unwrap(); let rtxn = env.read_txn().unwrap(); - let count_before = index.num_inserted_vectors(&rtxn).unwrap(); + let count_before = index.num_inserted_vectors(); // Entry point may be included in count (+1) assert!( @@ -585,7 +585,9 @@ let vector = random_vector(32); let data = arena.alloc_slice_copy(&vector); - index.insert::<Filter>(&mut wtxn, "isolation_test", data, None, &arena).unwrap(); + index + .insert(&mut wtxn, "isolation_test", data, None, &arena) + .unwrap(); wtxn.commit().unwrap(); } }); handle.join().unwrap(); // Original read transaction should still see the same count (snapshot isolation) - let count_after = index.num_inserted_vectors(&rtxn).unwrap(); + let count_after = index.num_inserted_vectors(); assert_eq!( count_after, count_before, "Read transaction should see consistent snapshot" @@ -606,13 +608,14 @@ let index_new = open_vector_core(&env, &mut wtxn_new).unwrap(); wtxn_new.commit().unwrap(); - let rtxn_new = env.read_txn().unwrap(); - let count_new = index_new.num_inserted_vectors(&rtxn_new).unwrap(); + let count_new = index_new.num_inserted_vectors(); // Entry point may be included in counts (+1) let expected_new = initial_count + 20; assert!( - count_new == expected_new || count_new == expected_new + 1 || count_new == initial_count + 20 + 1, + count_new == expected_new + || count_new == expected_new + 1 + || count_new == initial_count + 20 + 1, "Expected around {} vectors, got {}", expected_new, count_new
diff --git a/helix-db/src/helix_engine/tests/hnsw_tests.rs b/helix-db/src/helix_engine/tests/hnsw_tests.rs index 78f4a48c4..a0668eb53 100644 --- a/helix-db/src/helix_engine/tests/hnsw_tests.rs +++ b/helix-db/src/helix_engine/tests/hnsw_tests.rs @@ -3,11 +3,7 @@ use heed3::{Env, EnvOpenOptions, RoTxn}; use rand::Rng; use tempfile::TempDir; -use crate::helix_engine::vector_core::{ - hnsw::HNSW, - vector::HVector, - vector_core::{HNSWConfig, VectorCore}, -}; +use crate::helix_engine::vector_core::{HNSWConfig, HVector, VectorCore}; type Filter = fn(&HVector, &RoTxn) -> bool; @@ -36,13 +32,12 @@ fn test_hnsw_insert_and_count() { let arena = Bump::new(); let data = arena.alloc_slice_copy(&vector); let _ = index - .insert::<Filter>(&mut txn, "vector", data, None, &arena) + .insert(&mut txn, "vector", data, None, &arena) .unwrap(); } txn.commit().unwrap(); - let txn = env.read_txn().unwrap(); - assert!(index.num_inserted_vectors(&txn).unwrap() >= 10); + assert!(index.num_inserted_vectors() >= 10); } #[test] @@ -57,7 +52,7 @@ fn test_hnsw_search_returns_results() { let vector: Vec<f64> = (0..4).map(|_| rng.random_range(0.0..1.0)).collect(); let data = arena.alloc_slice_copy(&vector); let _ = index - .insert::<Filter>(&mut txn, "vector", data, None, &arena) + .insert(&mut txn, "vector", data, None, &arena) .unwrap(); } txn.commit().unwrap(); @@ -66,7 +61,7 @@ let txn = env.read_txn().unwrap(); let query = [0.5, 0.5, 0.5, 0.5]; let results = index - .search::<Filter>(&txn, &query, 5, "vector", None, false, &arena) + .search(&txn, &query, 5, "vector", false, &arena) .unwrap(); assert!(!results.is_empty()); }
diff --git a/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs index cd237235b..ce6343e81 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs @@ -25,7 +25,7 @@ use crate::{ traversal_value::TraversalValue, }, types::GraphError, - vector_core::vector::HVector, + vector_core::HVector, }, props, }; @@ -170,7 +170,8 @@ fn test_drop_node() { let edges = G::new(&storage, &txn, &arena) .n_from_id(&node2_id) .in_e("knows") - .collect::<Result<Vec<_>, _>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); println!("edges: {:?}", edges); assert!(edges.is_empty()); } @@ -390,24 +391,22 @@ fn test_vector_deletion_in_existing_graph() { let mut vector_ids = Vec::new(); for _ in 0..10 { let id = match G::new_mut(&storage, &arena, &mut txn) - .insert_v::<FnTy>(&[1.0, 1.0, 1.0, 1.0], "vector", None) + .insert_v(&[1.0, 1.0, 1.0, 1.0], "vector", None) .collect_to_obj() .unwrap() { TraversalValue::Vector(vector) => vector.id, - TraversalValue::VectorNodeWithoutVectorData(vector) => *vector.id(), other => panic!("unexpected value: {other:?}"), }; vector_ids.push(id); } let target_vector_id = match G::new_mut(&storage, &arena, &mut txn) - .insert_v::<FnTy>(&[1.0, 1.0, 1.0, 1.0], "vector", None) + .insert_v(&[1.0, 1.0, 1.0, 1.0], "vector", None) .collect_to_obj() .unwrap() { TraversalValue::Vector(vector) => vector.id, - TraversalValue::VectorNodeWithoutVectorData(vector) => *vector.id(), other => panic!("unexpected value: {other:?}"), }; @@ -443,10 +442,8 @@ fn
test_vector_deletion_in_existing_graph() { .n_from_id(&node_id) .out_vec("knows", false) .filter_ref(|val, _| match val { - Ok(TraversalValue::Vector(vector)) => Ok(*vector.id() == target_vector_id), - Ok(TraversalValue::VectorNodeWithoutVectorData(vector)) => { - Ok(*vector.id() == target_vector_id) - } + Ok(TraversalValue::Vector(vector)) => Ok(vector.id == target_vector_id), + Ok(_) => Ok(false), Err(err) => Err(GraphError::from(err.to_string())), }) diff --git a/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs index 426154d11..dbad15e83 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs @@ -22,7 +22,7 @@ use crate::{ traversal_value::TraversalValue, }, types::GraphError, - vector_core::vector::HVector, + vector_core::HVector, }, props, protocol::value::Value, @@ -58,23 +58,27 @@ fn test_add_edge_creates_relationship() { let source_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let target_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let edge = G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, source_id, target_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let fetched = G::new(&storage, &txn, &arena) .e_from_id(&edge.id()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(fetched.len(), 1); assert_eq!(edge_id(&fetched[0]), edge.id()); } @@ -87,15 +91,18 @@ fn test_out_e_returns_edge() { let source_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let target_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, source_id, target_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); @@ -103,7 +110,8 @@ fn test_out_e_returns_edge() { let edges = G::new(&storage, &txn, &arena) .n_from_id(&source_id) .out_e("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(edges.len(), 1); assert_eq!(edges[0].id(), edge_id(&edges[0])); } @@ -116,15 +124,18 @@ fn test_in_e_returns_edge() { let source_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let target_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, source_id, target_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); @@ -132,7 +143,8 @@ fn test_in_e_returns_edge() { let edges = G::new(&storage, &txn, &arena) .n_from_id(&target_id) .in_e("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(edges.len(), 1); assert_eq!(edge_id(&edges[0]), edges[0].id()); } @@ -145,15 +157,18 @@ fn 
test_out_node_returns_neighbor() { let source_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let neighbor_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, source_id, neighbor_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); @@ -161,7 +176,8 @@ fn test_out_node_returns_neighbor() { let neighbors = G::new(&storage, &txn, &arena) .n_from_id(&source_id) .out_node("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(neighbors.len(), 1); assert_eq!(neighbors[0].id(), neighbor_id); } @@ -174,11 +190,13 @@ fn test_edge_properties_can_be_read() { let source_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let target_id = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -188,14 +206,16 @@ fn test_edge_properties_can_be_read() { target_id, false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let edge = G::new(&storage, &txn, &arena) .e_from_type("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(edge.len(), 1); if let TraversalValue::Edge(edge) = &edge[0] { match edge.properties.as_ref().unwrap().get("since").unwrap() { @@ -216,19 +236,21 @@ fn test_vector_edges_roundtrip() { let node_id = G::new_mut(&storage, &arena, &mut txn) .add_n("doc", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let vector_id = match G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 0.0, 0.0], "embedding", None) - .collect_to_obj().unwrap() + .insert_v(&[1.0, 0.0, 0.0], "embedding", None) + .collect_to_obj() + .unwrap() { TraversalValue::Vector(vector) => vector.id, - TraversalValue::VectorNodeWithoutVectorData(vector) => *vector.id(), other => panic!("unexpected traversal value: {other:?}"), }; G::new_mut(&storage, &arena, &mut txn) .add_edge("has_vector", None, node_id, vector_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); @@ -236,11 +258,11 @@ fn test_vector_edges_roundtrip() { let vectors = G::new(&storage, &txn, &arena) .n_from_id(&node_id) .out_vec("has_vector", true) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(vectors.len(), 1); match &vectors[0] { - TraversalValue::Vector(vec) => assert_eq!(*vec.id(), vector_id), - TraversalValue::VectorNodeWithoutVectorData(vec) => assert_eq!(*vec.id(), vector_id), + TraversalValue::Vector(vec) => assert_eq!(vec.id, vector_id), other => panic!("unexpected traversal value: {other:?}"), } } diff --git a/helix-db/src/helix_engine/tests/traversal_tests/util_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/util_tests.rs index 737e9cbbd..67030994a 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/util_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/util_tests.rs @@ -1,30 +1,24 @@ -use std::sync::Arc; use super::test_utils::props_option; +use std::sync::Arc; use 
crate::{ helix_engine::{ storage_core::HelixGraphStorage, - traversal_core::{ - ops::{ - g::G, - out::{out::OutAdapter, out_e::OutEdgesAdapter}, - source::{ - add_e::AddEAdapter, - add_n::AddNAdapter, - n_from_type::NFromTypeAdapter, - }, - util::{dedup::DedupAdapter, order::OrderByAdapter}, - vectors::{insert::InsertVAdapter, search::SearchVAdapter}, - }, + traversal_core::ops::{ + g::G, + out::{out::OutAdapter, out_e::OutEdgesAdapter}, + source::{add_e::AddEAdapter, add_n::AddNAdapter, n_from_type::NFromTypeAdapter}, + util::{dedup::DedupAdapter, order::OrderByAdapter}, + vectors::{insert::InsertVAdapter, search::SearchVAdapter}, }, - vector_core::vector::HVector, + vector_core::HVector, }, props, }; +use bumpalo::Bump; use heed3::RoTxn; use tempfile::TempDir; -use bumpalo::Bump; fn setup_test_db() -> (TempDir, Arc) { let temp_dir = TempDir::new().unwrap(); let db_path = temp_dir.path().to_str().unwrap(); @@ -45,15 +39,18 @@ fn test_order_node_by_asc() { let node = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 20 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 10 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -61,7 +58,8 @@ fn test_order_node_by_asc() { let traversal = G::new(&storage, &txn, &arena) .n_from_type("person") .order_by_asc("age") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 3); assert_eq!(traversal[0].id(), node3.id()); @@ -77,15 +75,18 @@ fn test_order_node_by_desc() { let node = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 20 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 10 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -93,7 +94,8 @@ fn test_order_node_by_desc() { let traversal = G::new(&storage, &txn, &arena) .n_from_type("person") .order_by_desc("age") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 3); assert_eq!(traversal[0].id(), node.id()); @@ -109,15 +111,18 @@ fn test_order_edge_by_asc() { let node = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 20 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! 
{ "age" => 10 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let edge = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -127,7 +132,8 @@ fn test_order_edge_by_asc() { node2.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let edge2 = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -137,7 +143,8 @@ fn test_order_edge_by_asc() { node2.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -146,7 +153,8 @@ fn test_order_edge_by_asc() { .n_from_type("person") .out_e("knows") .order_by_asc("since") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 2); assert_eq!(traversal[0].id(), edge.id()); @@ -161,15 +169,18 @@ fn test_order_edge_by_desc() { let node = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 20 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 10 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let edge = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -179,7 +190,8 @@ fn test_order_edge_by_desc() { node2.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let edge2 = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -189,7 +201,8 @@ fn test_order_edge_by_desc() { node2.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -198,7 +211,8 @@ fn test_order_edge_by_desc() { .n_from_type("person") .out_e("knows") .order_by_desc("since") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 2); assert_eq!(traversal[0].id(), edge2.id()); @@ -213,16 +227,31 @@ fn test_order_vector_by_asc() { type FnTy = fn(&HVector, &RoTxn) -> bool; let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "vector", props_option(&arena, props! { "age" => 30 })) - .collect_to_obj().unwrap(); + .insert_v( + &[1.0, 2.0, 3.0], + "vector", + props_option(&arena, props! { "age" => 30 }), + ) + .collect_to_obj() + .unwrap(); let vector2 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "vector", props_option(&arena, props! { "age" => 20 })) - .collect_to_obj().unwrap(); + .insert_v( + &[1.0, 2.0, 3.0], + "vector", + props_option(&arena, props! { "age" => 20 }), + ) + .collect_to_obj() + .unwrap(); let vector3 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "vector", props_option(&arena, props! { "age" => 10 })) - .collect_to_obj().unwrap(); + .insert_v( + &[1.0, 2.0, 3.0], + "vector", + props_option(&arena, props! 
{ "age" => 10 }), + ) + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -230,7 +259,8 @@ fn test_order_vector_by_asc() { let traversal = G::new(&storage, &txn, &arena) .search_v::(&[1.0, 2.0, 3.0], 10, "vector", None) .order_by_asc("age") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 3); assert_eq!(traversal[0].id(), vector3.id()); @@ -246,16 +276,31 @@ fn test_order_vector_by_desc() { type FnTy = fn(&HVector, &RoTxn) -> bool; let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "vector", props_option(&arena, props! { "age" => 30 })) - .collect_to_obj().unwrap(); + .insert_v( + &[1.0, 2.0, 3.0], + "vector", + props_option(&arena, props! { "age" => 30 }), + ) + .collect_to_obj() + .unwrap(); let vector2 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "vector", props_option(&arena, props! { "age" => 20 })) - .collect_to_obj().unwrap(); + .insert_v( + &[1.0, 2.0, 3.0], + "vector", + props_option(&arena, props! { "age" => 20 }), + ) + .collect_to_obj() + .unwrap(); let vector3 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "vector", props_option(&arena, props! { "age" => 10 })) - .collect_to_obj().unwrap(); + .insert_v( + &[1.0, 2.0, 3.0], + "vector", + props_option(&arena, props! { "age" => 10 }), + ) + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -263,7 +308,8 @@ fn test_order_vector_by_desc() { let traversal = G::new(&storage, &txn, &arena) .search_v::(&[1.0, 2.0, 3.0], 10, "vector", None) .order_by_desc("age") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 3); assert_eq!(traversal[0].id(), vector.id()); @@ -279,15 +325,18 @@ fn test_dedup() { let node = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 20 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! 
{ "age" => 10 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let _edge = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -297,7 +346,8 @@ fn test_dedup() { node2.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let _edge2 = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -307,7 +357,8 @@ fn test_dedup() { node2.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -315,7 +366,8 @@ fn test_dedup() { let traversal = G::new(&storage, &txn, &arena) .n_from_type("person") .out_node("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 2); @@ -323,7 +375,8 @@ fn test_dedup() { .n_from_type("person") .out_node("knows") .dedup() - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 1); assert_eq!(traversal[0].id(), node2.id()); diff --git a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs index ed49fdacb..a572dc4d3 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs @@ -13,7 +13,8 @@ use crate::{ out::{out::OutAdapter, out_e::OutEdgesAdapter}, source::{ add_e::AddEAdapter, add_n::AddNAdapter, e_from_type::EFromTypeAdapter, - n_from_id::NFromIdAdapter, v_from_id::VFromIdAdapter, v_from_type::VFromTypeAdapter, + n_from_id::NFromIdAdapter, v_from_id::VFromIdAdapter, + v_from_type::VFromTypeAdapter, }, util::drop::Drop, vectors::{ @@ -22,7 +23,7 @@ use crate::{ }, }, types::GraphError, - vector_core::vector::HVector, + vector_core::HVector, }, utils::properties::ImmutablePropertiesMap, }; @@ -48,7 +49,7 @@ fn test_insert_and_fetch_vector() { let mut txn = storage.graph_env.write_txn().unwrap(); let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[0.1, 0.2, 0.3], "embedding", None) + .insert_v(&[0.1, 0.2, 0.3], "embedding", None) .collect_to_obj() .unwrap(); txn.commit().unwrap(); @@ -81,7 +82,7 @@ fn test_vector_edges_from_and_to_node() { .unwrap()[0] .id(); let vector_id = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 0.0, 0.0], "embedding", None) + .insert_v(&[1.0, 0.0, 0.0], "embedding", None) .collect_to_obj() .unwrap() .id(); @@ -122,7 +123,7 @@ fn test_brute_force_vector_search_orders_by_distance() { let mut vector_ids = Vec::new(); for vector in vectors { let vec_id = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&vector, "vector", None) + .insert_v(&vector, "vector", None) .collect_to_obj() .unwrap() .id(); @@ -159,7 +160,7 @@ fn test_drop_vector_removes_edges() { .unwrap()[0] .id(); let vector_id = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[0.5, 0.5, 0.5], "vector", None) + .insert_v(&[0.5, 0.5, 0.5], "vector", None) .collect_to_obj() .unwrap() .id(); @@ -210,7 +211,7 @@ fn test_v_from_type_basic_with_vector_data() { // Insert a vector with label "test_label" let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "test_label", None) + .insert_v(&[1.0, 2.0, 3.0], "test_label", None) .collect_to_obj() .unwrap(); let vector_id = vector.id(); @@ -228,9 +229,11 @@ fn test_v_from_type_basic_with_vector_data() { assert_eq!(results[0].id(), vector_id); // Verify it's a full HVector with data - if let crate::helix_engine::traversal_core::traversal_value::TraversalValue::Vector(v) = &results[0] { - assert_eq!(v.data.len(), 
3); - assert_eq!(v.data[0], 1.0); + if let crate::helix_engine::traversal_core::traversal_value::TraversalValue::Vector(v) = + &results[0] + { + assert_eq!(v.data_borrowed().len(), 3); + assert_eq!(v.data_borrowed()[0], 1.0); } else { panic!("Expected TraversalValue::Vector"); } @@ -244,7 +247,7 @@ fn test_v_from_type_without_vector_data() { // Insert a vector with label "no_data_label" let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[4.0, 5.0, 6.0], "no_data_label", None) + .insert_v(&[4.0, 5.0, 6.0], "no_data_label", None) .collect_to_obj() .unwrap(); let vector_id = vector.id(); @@ -263,9 +266,10 @@ fn test_v_from_type_without_vector_data() { // Verify it's a VectorWithoutData match &results[0] { - crate::helix_engine::traversal_core::traversal_value::TraversalValue::VectorNodeWithoutVectorData(v) => { - assert_eq!(*v.id(), vector_id); - assert_eq!(v.label(), "no_data_label"); + crate::helix_engine::traversal_core::traversal_value::TraversalValue::Vector(v) => { + assert_eq!(v.id, vector_id); + assert_eq!(v.label, "no_data_label"); + assert!(v.data.is_none()); } _ => panic!("Expected TraversalValue::VectorNodeWithoutVectorData"), } @@ -279,15 +283,15 @@ fn test_v_from_type_multiple_same_label() { // Insert multiple vectors with the same label let v1 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "shared_label", None) + .insert_v(&[1.0, 2.0, 3.0], "shared_label", None) .collect_to_obj() .unwrap(); let v2 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[4.0, 5.0, 6.0], "shared_label", None) + .insert_v(&[4.0, 5.0, 6.0], "shared_label", None) .collect_to_obj() .unwrap(); let v3 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[7.0, 8.0, 9.0], "shared_label", None) + .insert_v(&[7.0, 8.0, 9.0], "shared_label", None) .collect_to_obj() .unwrap(); @@ -319,15 +323,15 @@ fn test_v_from_type_multiple_different_labels() { // Insert vectors with different labels let v1 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "label_a", None) + .insert_v(&[1.0, 2.0, 3.0], "label_a", None) .collect_to_obj() .unwrap(); let _v2 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[4.0, 5.0, 6.0], "label_b", None) + .insert_v(&[4.0, 5.0, 6.0], "label_b", None) .collect_to_obj() .unwrap(); let _v3 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[7.0, 8.0, 9.0], "label_c", None) + .insert_v(&[7.0, 8.0, 9.0], "label_c", None) .collect_to_obj() .unwrap(); txn.commit().unwrap(); @@ -352,7 +356,7 @@ fn test_v_from_type_nonexistent_label() { // Insert a vector with a different label let _vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "existing_label", None) + .insert_v(&[1.0, 2.0, 3.0], "existing_label", None) .collect_to_obj() .unwrap(); txn.commit().unwrap(); @@ -385,8 +389,8 @@ fn test_v_from_type_empty_database() { #[test] fn test_v_from_type_with_properties() { - use std::collections::HashMap; use crate::protocol::value::Value; + use std::collections::HashMap; let (_temp_dir, storage) = setup_test_db(); let arena = Bump::new(); @@ -398,21 +402,26 @@ fn test_v_from_type_with_properties() { properties.insert("count".to_string(), Value::I64(42)); properties.insert("score".to_string(), Value::F64(3.14)); properties.insert("active".to_string(), Value::Boolean(true)); - properties.insert("tags".to_string(), Value::Array(vec![ - Value::String("tag1".to_string()), - Value::String("tag2".to_string()), - ])); + properties.insert( + "tags".to_string(), + Value::Array(vec![ + 
Value::String("tag1".to_string()), + Value::String("tag2".to_string()), + ]), + ); // Convert to ImmutablePropertiesMap let props_map = ImmutablePropertiesMap::new( properties.len(), - properties.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + properties + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); // Insert vector with properties let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "prop_label", Some(props_map)) + .insert_v(&[1.0, 2.0, 3.0], "prop_label", Some(props_map)) .collect_to_obj() .unwrap(); let vector_id = vector.id(); @@ -430,9 +439,14 @@ fn test_v_from_type_with_properties() { assert_eq!(results[0].id(), vector_id); // Verify properties are preserved - if let crate::helix_engine::traversal_core::traversal_value::TraversalValue::VectorNodeWithoutVectorData(v) = &results[0] { + if let crate::helix_engine::traversal_core::traversal_value::TraversalValue::Vector(v) = + &results[0] + { let props = v.properties.as_ref().unwrap(); - assert_eq!(props.get("name"), Some(&Value::String("test_vector".to_string()))); + assert_eq!( + props.get("name"), + Some(&Value::String("test_vector".to_string())) + ); assert_eq!(props.get("count"), Some(&Value::I64(42))); assert_eq!(props.get("score"), Some(&Value::F64(3.14))); assert_eq!(props.get("active"), Some(&Value::Boolean(true))); @@ -449,11 +463,11 @@ fn test_v_from_type_deleted_vectors_filtered() { // Insert two vectors with the same label let v1 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "delete_test", None) + .insert_v(&[1.0, 2.0, 3.0], "delete_test", None) .collect_to_obj() .unwrap(); let v2 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[4.0, 5.0, 6.0], "delete_test", None) + .insert_v(&[4.0, 5.0, 6.0], "delete_test", None) .collect_to_obj() .unwrap(); txn.commit().unwrap(); @@ -506,11 +520,11 @@ fn test_v_from_type_with_edges_and_nodes() { // Create vectors and connect them to the node let v1 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 0.0, 0.0], "embedding", None) + .insert_v(&[1.0, 0.0, 0.0], "embedding", None) .collect_to_obj() .unwrap(); let v2 = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[0.0, 1.0, 0.0], "embedding", None) + .insert_v(&[0.0, 1.0, 0.0], "embedding", None) .collect_to_obj() .unwrap(); @@ -549,191 +563,6 @@ fn test_v_from_type_with_edges_and_nodes() { assert_eq!(from_node.len(), 2); } -#[test] -fn test_v_from_type_after_migration() { - use std::collections::HashMap; - use crate::protocol::value::Value; - use crate::helix_engine::storage_core::storage_migration::migrate; - - // Helper to create old-format vector properties (HashMap-based) - fn create_old_properties( - label: &str, - is_deleted: bool, - extra_props: HashMap, - ) -> Vec { - let mut props = HashMap::new(); - props.insert("label".to_string(), Value::String(label.to_string())); - props.insert("is_deleted".to_string(), Value::Boolean(is_deleted)); - - for (k, v) in extra_props { - props.insert(k, v); - } - - bincode::serialize(&props).unwrap() - } - - // Helper to clear metadata (simulates PreMetadata state) - fn clear_metadata(storage: &mut crate::helix_engine::storage_core::HelixGraphStorage) -> Result<(), crate::helix_engine::types::GraphError> { - let mut txn = storage.graph_env.write_txn()?; - storage.metadata_db.clear(&mut txn)?; - txn.commit()?; - Ok(()) - } - - let (_temp_dir, storage) = setup_test_db(); - let mut storage_mut = match Arc::try_unwrap(storage) { - Ok(s) => s, - Err(_) => panic!("Failed to 
unwrap Arc - there are multiple references"), - }; - - // Clear metadata to simulate PreMetadata state (before migration) - clear_metadata(&mut storage_mut).unwrap(); - - // Create old-format vectors with various properties - { - let mut txn = storage_mut.graph_env.write_txn().unwrap(); - - // Vector 1: Simple vector with test label - let mut props1 = HashMap::new(); - props1.insert("name".to_string(), Value::String("vector1".to_string())); - props1.insert("count".to_string(), Value::I64(100)); - let old_bytes1 = create_old_properties("test_migration", false, props1); - storage_mut - .vectors - .vector_properties_db - .put(&mut txn, &1u128, &old_bytes1) - .unwrap(); - - // Add actual vector data with proper key format - let vector_data1: Vec = vec![1.0, 2.0, 3.0]; - let bytes1: Vec = vector_data1.iter().flat_map(|f| f.to_be_bytes()).collect(); - let key1 = [b"v:".as_slice(), &1u128.to_be_bytes(), &0usize.to_be_bytes()].concat(); - storage_mut - .vectors - .vectors_db - .put(&mut txn, &key1, &bytes1) - .unwrap(); - - // Vector 2: Another vector with same label - let mut props2 = HashMap::new(); - props2.insert("name".to_string(), Value::String("vector2".to_string())); - props2.insert("score".to_string(), Value::F64(0.95)); - let old_bytes2 = create_old_properties("test_migration", false, props2); - storage_mut - .vectors - .vector_properties_db - .put(&mut txn, &2u128, &old_bytes2) - .unwrap(); - - // Add actual vector data with proper key format - let vector_data2: Vec = vec![4.0, 5.0, 6.0]; - let bytes2: Vec = vector_data2.iter().flat_map(|f| f.to_be_bytes()).collect(); - let key2 = [b"v:".as_slice(), &2u128.to_be_bytes(), &0usize.to_be_bytes()].concat(); - storage_mut - .vectors - .vectors_db - .put(&mut txn, &key2, &bytes2) - .unwrap(); - - // Vector 3: Different label - let mut props3 = HashMap::new(); - props3.insert("name".to_string(), Value::String("vector3".to_string())); - let old_bytes3 = create_old_properties("other_label", false, props3); - storage_mut - .vectors - .vector_properties_db - .put(&mut txn, &3u128, &old_bytes3) - .unwrap(); - - // Add actual vector data with proper key format - let vector_data3: Vec = vec![7.0, 8.0, 9.0]; - let bytes3: Vec = vector_data3.iter().flat_map(|f| f.to_be_bytes()).collect(); - let key3 = [b"v:".as_slice(), &3u128.to_be_bytes(), &0usize.to_be_bytes()].concat(); - storage_mut - .vectors - .vectors_db - .put(&mut txn, &key3, &bytes3) - .unwrap(); - - txn.commit().unwrap(); - } - - // Run migration - let result = migrate(&mut storage_mut); - assert!(result.is_ok(), "Migration should succeed"); - - // Now query using v_from_type on the migrated data - let storage = Arc::new(storage_mut); - let arena = Bump::new(); - let txn = storage.graph_env.read_txn().unwrap(); - - // Query for "test_migration" label - should find 2 vectors - let results_with_data = G::new(&storage, &txn, &arena) - .v_from_type("test_migration", true) - .collect::, _>>() - .unwrap(); - - assert_eq!(results_with_data.len(), 2, "Should find 2 vectors with test_migration label"); - - // Verify we got the right vectors - let ids: Vec = results_with_data.iter().map(|v| v.id()).collect(); - assert!(ids.contains(&1u128), "Should contain vector 1"); - assert!(ids.contains(&2u128), "Should contain vector 2"); - - // Verify vector data is accessible - if let crate::helix_engine::traversal_core::traversal_value::TraversalValue::Vector(v) = &results_with_data[0] { - assert_eq!(v.data.len(), 3, "Vector should have 3 dimensions"); - } else { - panic!("Expected TraversalValue::Vector"); - 
} - - // Query without vector data to check properties - let arena2 = Bump::new(); - let results_without_data = G::new(&storage, &txn, &arena2) - .v_from_type("test_migration", false) - .collect::, _>>() - .unwrap(); - - assert_eq!(results_without_data.len(), 2, "Should still find 2 vectors"); - - // Verify properties are preserved after migration - for result in &results_without_data { - if let crate::helix_engine::traversal_core::traversal_value::TraversalValue::VectorNodeWithoutVectorData(v) = result { - assert_eq!(v.label(), "test_migration"); - - // Check that properties are accessible - let props = v.properties.as_ref().unwrap(); - let name = props.get("name"); - assert!(name.is_some(), "name property should exist"); - - // Verify it's a string - match name.unwrap() { - Value::String(s) => assert!(s == "vector1" || s == "vector2"), - _ => panic!("Expected name to be a string"), - } - } - } - - // Query for "other_label" - should find 1 vector - let arena3 = Bump::new(); - let other_results = G::new(&storage, &txn, &arena3) - .v_from_type("other_label", true) - .collect::, _>>() - .unwrap(); - - assert_eq!(other_results.len(), 1, "Should find 1 vector with other_label"); - assert_eq!(other_results[0].id(), 3u128); - - // Query for non-existent label after migration - let arena4 = Bump::new(); - let empty_results = G::new(&storage, &txn, &arena4) - .v_from_type("nonexistent", true) - .collect::, _>>() - .unwrap(); - - assert!(empty_results.is_empty(), "Should find no vectors with nonexistent label"); -} - // ============================================================================ // Error Tests for v_from_id // ============================================================================ @@ -790,7 +619,7 @@ fn test_v_from_id_with_deleted_vector() { // Create a vector let vector = G::new_mut(&storage, &arena, &mut txn) - .insert_v::(&[1.0, 2.0, 3.0], "test_vector", None) + .insert_v(&[1.0, 2.0, 3.0], "test_vector", None) .collect_to_obj() .unwrap(); let vector_id = vector.id(); diff --git a/helix-db/src/helix_engine/tests/vector_tests.rs b/helix-db/src/helix_engine/tests/vector_tests.rs index 4799902d6..bda172d8a 100644 --- a/helix-db/src/helix_engine/tests/vector_tests.rs +++ b/helix-db/src/helix_engine/tests/vector_tests.rs @@ -1,18 +1,19 @@ -use crate::helix_engine::vector_core::vector_distance::{MAX_DISTANCE, MIN_DISTANCE, ORTHOGONAL}; - -use crate::helix_engine::vector_core::vector::HVector; +use crate::helix_engine::vector_core::{ + HVector, + distance::{MAX_DISTANCE, MIN_DISTANCE, ORTHOGONAL}, +}; use bumpalo::Bump; fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> { let slice = arena.alloc_slice_copy(data); - HVector::from_slice("vector", 0, slice) + HVector::from_slice("vector", 0, slice, arena) } #[test] fn test_hvector_from_slice() { let arena = Bump::new(); let vector = alloc_vector(&arena, &[1.0, 2.0, 3.0]); - assert_eq!(vector.data, &[1.0, 2.0, 3.0]); + assert_eq!(vector.data_borrowed(), &[1.0, 2.0, 3.0]); } #[test] diff --git a/helix-db/src/helix_engine/traversal_core/ops/in_/in_.rs b/helix-db/src/helix_engine/traversal_core/ops/in_/in_.rs index 92fe6d19a..b64a77dd7 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/in_/in_.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/in_/in_.rs @@ -88,7 +88,7 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr .vectors .get_vector_properties(self.txn, item_id, self.arena) { - return Some(Ok(TraversalValue::VectorNodeWithoutVectorData(vec))); + return Some(Ok(TraversalValue::Vector(vec))); } None } else { 
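The `in_.rs` hunk above and the matching hunks in `to_v.rs`, `from_v.rs`, `out.rs`, and `v_from_id.rs` below all make the same substitution: with `HVector` now carrying its float payload as an `Option`, a vector fetched "without data" is simply a `Vector` whose `data` is `None`, so the separate `VectorNodeWithoutVectorData` variant can be dropped. A minimal sketch of the consolidated shape (the names `HVector` and `data_borrowed` follow the diff; the struct layout and panic message here are simplified assumptions, not the real definition):

// Simplified sketch; the real HVector is arena-allocated and has more fields.
struct HVector<'a> {
    id: u128,
    label: &'a str,
    deleted: bool,
    // None when the vector was fetched without its data -- the case the
    // old VectorNodeWithoutVectorData variant used to represent.
    data: Option<&'a [f64]>,
}

impl<'a> HVector<'a> {
    // Mirrors the data_borrowed() accessor the tests call; panics if the
    // vector was fetched without data (hypothetical panic message).
    fn data_borrowed(&self) -> &'a [f64] {
        self.data.expect("vector fetched without data")
    }
}

fn main() {
    let floats = [1.0, 2.0, 3.0];
    let full = HVector { id: 1, label: "embedding", deleted: false, data: Some(&floats) };
    let bare = HVector { id: 1, label: "embedding", deleted: false, data: None };
    assert_eq!(full.data_borrowed(), &[1.0, 2.0, 3.0]);
    assert!(bare.data.is_none() && !bare.deleted); // one variant covers both cases
    println!("label = {}", full.label);
}

Collapsing the two variants this way removes the duplicated match arms seen in `drop.rs`, `traversal_value.rs`, and the tests, at the cost of a runtime check wherever the payload is actually required.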
diff --git a/helix-db/src/helix_engine/traversal_core/ops/in_/to_v.rs b/helix-db/src/helix_engine/traversal_core/ops/in_/to_v.rs index 0c627a605..f0c3dc6c2 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/in_/to_v.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/in_/to_v.rs @@ -47,9 +47,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE item.to_node, self.arena, ) { - Ok(Some(vector)) => { - Some(Ok(TraversalValue::VectorNodeWithoutVectorData(vector))) - } + Ok(Some(vector)) => Some(Ok(TraversalValue::Vector(vector))), Ok(None) => None, Err(e) => Some(Err(GraphError::from(e))), }
diff --git a/helix-db/src/helix_engine/traversal_core/ops/out/from_v.rs b/helix-db/src/helix_engine/traversal_core/ops/out/from_v.rs index a3753d69f..5b2fed9a8 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/out/from_v.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/out/from_v.rs @@ -53,7 +53,7 @@ where item.from_node, self.arena, ) { - Ok(Some(vector)) => TraversalValue::VectorNodeWithoutVectorData(vector), + Ok(Some(vector)) => TraversalValue::Vector(vector), Ok(None) => { return Some(Err(GraphError::from(VectorError::VectorNotFound( item.from_node.to_string(),
diff --git a/helix-db/src/helix_engine/traversal_core/ops/out/out.rs b/helix-db/src/helix_engine/traversal_core/ops/out/out.rs index 4fcc6c3ae..eecf82dc5 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/out/out.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/out/out.rs @@ -88,7 +88,7 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr .vectors .get_vector_properties(self.txn, item_id, self.arena) { - return Some(Ok(TraversalValue::VectorNodeWithoutVectorData(vec))); + return Some(Ok(TraversalValue::Vector(vec))); } None } else {
diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_id.rs b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_id.rs index 3de40fbce..4bfe07da9 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_id.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_id.rs @@ -66,7 +66,7 @@ where if vec.deleted { Err(GraphError::from(VectorError::VectorDeleted)) } else { - Ok(TraversalValue::VectorNodeWithoutVectorData(vec)) + Ok(TraversalValue::Vector(vec)) } } Ok(None) => Err(GraphError::from(VectorError::VectorNotFound(
diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs index 76be4d092..cae3bbff7 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs @@ -1,8 +1,10 @@ use crate::helix_engine::{ - traversal_core::{LMDB_STRING_HEADER_LENGTH, traversal_iter::RoTraversalIterator, traversal_value::TraversalValue}, - types::{GraphError, VectorError}, - vector_core::{vector_without_data::VectorWithoutData}, - }; + traversal_core::{ + LMDB_STRING_HEADER_LENGTH, traversal_iter::RoTraversalIterator, + traversal_value::TraversalValue, + }, + types::{GraphError, VectorError}, +}; pub trait VFromTypeAdapter<'db, 'arena, 'txn>: Iterator<Item = Result<TraversalValue<'arena>, GraphError>> @@ -43,63 +45,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE .vector_properties_db .iter(self.txn) .unwrap() - .filter_map(move |item| { - if let Ok((id, value)) = item { - - - // get label via bytes directly - assert!( - value.len() >= LMDB_STRING_HEADER_LENGTH, - "value length does not contain header which means the `label` field was missing from the node on insertion" - ); - let length_of_label_in_lmdb = u64::from_le_bytes(value[..LMDB_STRING_HEADER_LENGTH].try_into().unwrap()) as usize; - assert!( - value.len() >= length_of_label_in_lmdb + LMDB_STRING_HEADER_LENGTH, - "value length is not at least the header length plus the label length meaning there has been a corruption on node insertion" - ); - let label_in_lmdb = &value[LMDB_STRING_HEADER_LENGTH - ..LMDB_STRING_HEADER_LENGTH + length_of_label_in_lmdb]; - - - // get deleted via bytes directly - - // skip single byte for version - let version_index = length_of_label_in_lmdb + LMDB_STRING_HEADER_LENGTH; - - // get bool for deleted - let deleted_index = version_index + 1; - let deleted = value[deleted_index] == 1; - - if deleted { - return None; - } - - if label_in_lmdb == label_bytes { - let vector_without_data = VectorWithoutData::from_bincode_bytes(self.arena, value, id) - .map_err(|e| VectorError::ConversionError(e.to_string())) - .ok()?; - - if get_vector_data { - let mut vector = match self.storage.vectors.get_raw_vector_data(self.txn, id, label, self.arena) { - Ok(bytes) => bytes, - Err(VectorError::VectorDeleted) => return None, - Err(e) => return Some(Err(GraphError::from(e))), - }; - vector.expand_from_vector_without_data(vector_without_data); - return Some(Ok(TraversalValue::Vector(vector))); - } else { - return Some(Ok(TraversalValue::VectorNodeWithoutVectorData( - vector_without_data - ))); - } - } else { - return None; - } - - } - None - }); + .filter_map(move |item| todo!()); RoTraversalIterator { storage: self.storage,
diff --git a/helix-db/src/helix_engine/traversal_core/ops/util/drop.rs b/helix-db/src/helix_engine/traversal_core/ops/util/drop.rs index b2f5ec86d..a69d5068d 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/util/drop.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/util/drop.rs @@ -42,12 +42,6 @@ where Ok(_) => Ok(()), Err(e) => Err(e), }, - TraversalValue::VectorNodeWithoutVectorData(vector) => { - match storage.drop_vector(txn, &vector.id) { - Ok(_) => Ok(()), - Err(e) => Err(e), - } - } TraversalValue::Empty => Ok(()), _ => Err(GraphError::ConversionError(format!( "Incorrect Type: {item:?}"
diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs index 7aa0dd24d..8273a8713 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs @@ -1,7 +1,10 @@ use crate::helix_engine::{ traversal_core::{traversal_iter::RoTraversalIterator, traversal_value::TraversalValue}, types::GraphError, - vector_core::vector_distance::cosine_similarity, + vector_core::{ + distance::{Cosine, Distance}, + node::Item, + }, }; use itertools::Itertools; @@ -40,12 +43,16 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE K: TryInto<usize>, K::Error: std::fmt::Debug, { + let arena = bumpalo::Bump::new(); let iter = self .inner .filter_map(|v| match v { Ok(TraversalValue::Vector(mut v)) => { - let d = cosine_similarity(v.data, query).unwrap(); - v.set_distance(d); + let d = Cosine::distance( + v.data.as_ref().unwrap(), + &Item::<Cosine>::from(query, &arena), + ); + v.set_distance(d as f64); Some(v) } _ => None, }) @@ -56,10 +63,11 @@ match self .storage .vectors - .get_vector_properties(self.txn, *item.id(), self.arena) + .get_vector_properties(self.txn, item.id, self.arena) { Ok(Some(vector_without_data)) => { + // todo! + // item.expand_from_vector_without_data(vector_without_data); Some(item) }
diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs index 3c167ef1e..50f257a18 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs @@ -2,16 +2,15 @@ use crate::{ helix_engine::{ traversal_core::{traversal_iter::RwTraversalIterator, traversal_value::TraversalValue}, types::GraphError, - vector_core::{hnsw::HNSW, vector::HVector}, + vector_core::HVector, }, utils::properties::ImmutablePropertiesMap, }; -use heed3::RoTxn; pub trait InsertVAdapter<'db, 'arena, 'txn>: Iterator<Item = Result<TraversalValue<'arena>, GraphError>> { - fn insert_v<F>( + fn insert_v( self, query: &'arena [f64], label: &'arena str, @@ -21,15 +20,13 @@ pub trait InsertVAdapter<'db, 'arena, 'txn>: 'arena, 'txn, impl Iterator<Item = Result<TraversalValue<'arena>, GraphError>>, - > - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool; + >; } impl<'db, 'arena, 'txn, I: Iterator<Item = Result<TraversalValue<'arena>, GraphError>>> InsertVAdapter<'db, 'arena, 'txn> for RwTraversalIterator<'db, 'arena, 'txn, I> { - fn insert_v<F>( + fn insert_v( self, query: &'arena [f64], label: &'arena str, @@ -39,14 +36,11 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE 'arena, 'txn, impl Iterator<Item = Result<TraversalValue<'arena>, GraphError>>, - > - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - { + > { let vector: Result<HVector<'arena>, crate::helix_engine::types::VectorError> = self .storage .vectors - .insert::<F>(self.txn, label, query, properties, self.arena); + .insert(self.txn, label, query, properties, self.arena); let result = match vector { Ok(vector) => Ok(TraversalValue::Vector(vector)),
diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs index df8e619ea..1f414f7e8 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs @@ -3,7 +3,7 @@ use heed3::RoTxn; use crate::helix_engine::{ traversal_core::{traversal_iter::RoTraversalIterator, traversal_value::TraversalValue}, types::{GraphError, VectorError}, - vector_core::{hnsw::HNSW, vector::HVector}, + vector_core::HVector, }; use std::iter::once; @@ -53,7 +53,6 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE query, k.try_into().unwrap(), label, - filter, false, self.arena, ); @@ -61,6 +60,7 @@ let iter = match vectors { Ok(vectors) => vectors .into_iter() + // copying here!
.map(|vector| Ok::<TraversalValue, GraphError>(TraversalValue::Vector(vector))) .collect::<Vec<_>>() .into_iter(),
diff --git a/helix-db/src/helix_engine/traversal_core/traversal_value.rs b/helix-db/src/helix_engine/traversal_core/traversal_value.rs index 2440b95e6..8f28ba02b 100644 --- a/helix-db/src/helix_engine/traversal_core/traversal_value.rs +++ b/helix-db/src/helix_engine/traversal_core/traversal_value.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::{ - helix_engine::vector_core::{vector::HVector, vector_without_data::VectorWithoutData}, + helix_engine::vector_core::HVector, protocol::value::Value, utils::items::{Edge, Node}, }; @@ -18,8 +18,6 @@ pub enum TraversalValue<'arena> { Edge(Edge<'arena>), /// A vector in the graph Vector(HVector<'arena>), - /// Vector node without vector data - VectorNodeWithoutVectorData(VectorWithoutData<'arena>), /// A count of the number of items /// A path between two nodes in the graph Path((Vec<Node<'arena>>, Vec<Edge<'arena>>)), @@ -38,7 +36,6 @@ impl<'arena> TraversalValue<'arena> { TraversalValue::Node(node) => node.id, TraversalValue::Edge(edge) => edge.id, TraversalValue::Vector(vector) => vector.id, - TraversalValue::VectorNodeWithoutVectorData(vector) => vector.id, TraversalValue::Empty => 0, _ => 0, } } @@ -49,7 +46,6 @@ match self { TraversalValue::Node(node) => node.label, TraversalValue::Edge(edge) => edge.label, TraversalValue::Vector(vector) => vector.label, - TraversalValue::VectorNodeWithoutVectorData(vector) => vector.label, TraversalValue::Empty => "", _ => "", } } @@ -69,10 +65,9 @@ } } - pub fn data(&self) -> &'arena [f64] { + pub fn data(&'arena self) -> &'arena [f64] { match self { - TraversalValue::Vector(vector) => vector.data, - TraversalValue::VectorNodeWithoutVectorData(_) => &[], + TraversalValue::Vector(vector) => vector.data_borrowed(), _ => unimplemented!(), } } @@ -80,7 +75,6 @@ pub fn score(&self) -> f64 { match self { TraversalValue::Vector(vector) => vector.score(), - TraversalValue::VectorNodeWithoutVectorData(_) => 2f64, _ => unimplemented!(), } } @@ -90,7 +84,6 @@ match self { TraversalValue::Node(node) => node.label, TraversalValue::Edge(edge) => edge.label, TraversalValue::Vector(vector) => vector.label, - TraversalValue::VectorNodeWithoutVectorData(vector) => vector.label, TraversalValue::Empty => "", _ => "", } } @@ -101,7 +94,6 @@ match self { TraversalValue::Node(node) => node.get_property(property), TraversalValue::Edge(edge) => edge.get_property(property), TraversalValue::Vector(vector) => vector.get_property(property), - TraversalValue::VectorNodeWithoutVectorData(vector) => vector.get_property(property), TraversalValue::Empty => None, _ => None, } } @@ -114,7 +106,6 @@ impl Hash for TraversalValue<'_> { TraversalValue::Node(node) => node.id.hash(state), TraversalValue::Edge(edge) => edge.id.hash(state), TraversalValue::Vector(vector) => vector.id.hash(state), - TraversalValue::VectorNodeWithoutVectorData(vector) => vector.id.hash(state), TraversalValue::Empty => state.write_u8(0), _ => state.write_u8(0), } } @@ -128,20 +119,8 @@ impl PartialEq for TraversalValue<'_> { (TraversalValue::Node(node1), TraversalValue::Node(node2)) => node1.id == node2.id, (TraversalValue::Edge(edge1), TraversalValue::Edge(edge2)) => edge1.id == edge2.id, (TraversalValue::Vector(vector1), TraversalValue::Vector(vector2)) => { - vector1.id() == vector2.id() + vector1.id == vector2.id } - ( -
TraversalValue::VectorNodeWithoutVectorData(vector1), - TraversalValue::VectorNodeWithoutVectorData(vector2), - ) => vector1.id() == vector2.id(), - ( - TraversalValue::Vector(vector1), - TraversalValue::VectorNodeWithoutVectorData(vector2), - ) => vector1.id() == vector2.id(), - ( - TraversalValue::VectorNodeWithoutVectorData(vector1), - TraversalValue::Vector(vector2), - ) => vector1.id() == vector2.id(), (TraversalValue::Empty, TraversalValue::Empty) => true, _ => false, } diff --git a/helix-db/src/helix_engine/types.rs b/helix-db/src/helix_engine/types.rs index 5be25a1d7..ce0797dd8 100644 --- a/helix-db/src/helix_engine/types.rs +++ b/helix-db/src/helix_engine/types.rs @@ -1,4 +1,8 @@ -use crate::{helix_gateway::router::router::IoContFn, helixc::parser::errors::ParserError}; +use crate::{ + helix_engine::vector_core::{ItemId, LayerId, key::Key, node_id::NodeMode}, + helix_gateway::router::router::IoContFn, + helixc::parser::errors::ParserError, +}; use core::fmt; use heed3::Error as HeedError; use sonic_rs::Error as SonicError; @@ -30,8 +34,6 @@ pub enum GraphError { ParamNotFound(&'static str), IoNeeded(IoContFn), RerankerError(String), - - } impl std::error::Error for GraphError {} @@ -155,6 +157,31 @@ pub enum VectorError { ConversionError(String), VectorCoreError(String), VectorAlreadyDeleted(String), + InvalidVecDimension { + expected: usize, + received: usize, + }, + MissingKey { + /// The index that caused the error + index: u16, + /// The kind of item that was being queried + mode: &'static str, + /// The item ID queried + item: ItemId, + /// The item's layer + layer: LayerId, + }, + Io(String), + NeedBuild(u16), + /// The user is trying to query a database with a distance that is not of the right type. + UnmatchingDistance { + /// The expected distance type. + expected: String, + /// The distance given by the user. + received: &'static str, + }, + MissingMetadata(u16), + HasNoData, } impl std::error::Error for VectorError {} @@ -170,6 +197,50 @@ impl fmt::Display for VectorError { VectorError::ConversionError(msg) => write!(f, "Conversion error: {msg}"), VectorError::VectorCoreError(msg) => write!(f, "Vector core error: {msg}"), VectorError::VectorAlreadyDeleted(id) => write!(f, "Vector already deleted: {id}"), + VectorError::InvalidVecDimension { expected, received } => { + write!( + f, + "Invalid vector dimension: expected {expected}, received {received}" + ) + } + VectorError::MissingKey { + index, mode, item, .. + } => write!( + f, + "Internal error: {mode}({item}) is missing in index `{index}`" + ), + VectorError::Io(error) => write!(f, "IO error: {error}"), + VectorError::NeedBuild(idx) => write!( + f, + "The graph has not been built after an update on index {idx}" + ), + VectorError::UnmatchingDistance { expected, received } => { + write!( + f, + "Invalid distance provided. 
Got {received} but expected {expected}"
+                )
+            }
+            VectorError::MissingMetadata(idx) => write!(
+                f,
+                "Metadata is missing on index {idx}; you must build your database before attempting to read it"
+            ),
+            VectorError::HasNoData => write!(f, "Trying to access data where there is none"),
+        }
+    }
+}
+
+impl VectorError {
+    pub(crate) fn missing_key(key: Key) -> Self {
+        Self::MissingKey {
+            index: key.index,
+            mode: match key.node.mode {
+                NodeMode::Item => "Item",
+                NodeMode::Links => "Links",
+                NodeMode::Metadata => "Metadata",
+                NodeMode::Updated => "Updated",
+            },
+            item: key.node.item,
+            layer: key.node.layer,
        }
    }
}
@@ -203,3 +274,9 @@ impl From<bincode::Error> for VectorError {
         VectorError::ConversionError(format!("bincode error: {error}"))
     }
 }
+
+impl From<std::io::Error> for VectorError {
+    fn from(error: std::io::Error) -> Self {
+        VectorError::Io(error.to_string())
+    }
+}
diff --git a/helix-db/src/helix_engine/vector_core/binary_heap.rs b/helix-db/src/helix_engine/vector_core/binary_heap.rs
deleted file mode 100644
index 5c802f1f5..000000000
--- a/helix-db/src/helix_engine/vector_core/binary_heap.rs
+++ /dev/null
@@ -1,567 +0,0 @@
-use core::mem::{ManuallyDrop, swap};
-use core::ptr;
-use core::slice;
-use std::iter::FusedIterator;
-pub struct BinaryHeap<'arena, T> {
-    pub arena: &'arena bumpalo::Bump,
-    data: bumpalo::collections::Vec<'arena, T>,
-}
-
-impl<'arena, T: Ord> BinaryHeap<'arena, T> {
-    pub fn new(arena: &'arena bumpalo::Bump) -> BinaryHeap<'arena, T> {
-        BinaryHeap {
-            arena,
-            data: bumpalo::collections::Vec::with_capacity_in(0, arena),
-        }
-    }
-
-    pub fn with_capacity(arena: &'arena bumpalo::Bump, capacity: usize) -> BinaryHeap<'arena, T> {
-        BinaryHeap {
-            arena,
-            data: bumpalo::collections::Vec::with_capacity_in(capacity, arena),
-        }
-    }
-
-    #[inline]
-    pub fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
-        let guard = RebuildOnDrop {
-            rebuild_from: self.len(),
-            heap: self,
-        };
-        guard.heap.data.extend(iter);
-    }
-
-    pub fn pop(&mut self) -> Option<T> {
-        self.data.pop().map(|mut item| {
-            if !self.is_empty() {
-                swap(&mut item, &mut self.data[0]);
-                // SAFETY: !self.is_empty() means that self.len() > 0
-                unsafe { self.sift_down_to_bottom(0) };
-            }
-            item
-        })
-    }
-
-    #[must_use]
-    pub fn peek(&self) -> Option<&T> {
-        self.data.first()
-    }
-
-    pub fn from(
-        arena: &'arena bumpalo::Bump,
-        data: bumpalo::collections::Vec<'arena, T>,
-    ) -> BinaryHeap<'arena, T> {
-        BinaryHeap { arena, data }
-    }
-
-    pub fn push(&mut self, item: T) {
-        let old_len = self.len();
-        self.data.push(item);
-        // SAFETY: Since we pushed a new item it means that
-        // old_len = self.len() - 1 < self.len()
-        unsafe { self.sift_up(0, old_len) };
-    }
-
-    // The implementations of sift_up and sift_down use unsafe blocks in
-    // order to move an element out of the vector (leaving behind a
-    // hole), shift along the others and move the removed element back into the
-    // vector at the final location of the hole.
-    // The `Hole` type is used to represent this, and make sure
-    // the hole is filled back at the end of its scope, even on panic.
-    // Using a hole reduces the constant factor compared to using swaps,
-    // which involves twice as many moves.
-
-    /// # Safety
-    ///
-    /// The caller must guarantee that `pos < self.len()`.
-    ///
-    /// Returns the new position of the element.
-    unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
-        // Take out the value at `pos` and create a hole.
- // SAFETY: The caller guarantees that pos < self.len() - let mut hole = unsafe { Hole::new(&mut self.data, pos) }; - - while hole.pos() > start { - let parent = (hole.pos() - 1) / 2; - - // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0 - // and so hole.pos() - 1 can't underflow. - // This guarantees that parent < hole.pos() so - // it's a valid index and also != hole.pos(). - if hole.element() <= unsafe { hole.get(parent) } { - break; - } - - // SAFETY: Same as above - unsafe { hole.move_to(parent) }; - } - - hole.pos() - } - - /// Take an element at `pos` and move it down the heap, - /// while its children are larger. - /// - /// Returns the new position of the element. - /// - /// # Safety - /// - /// The caller must guarantee that `pos < end <= self.len()`. - unsafe fn sift_down_range(&mut self, pos: usize, end: usize) -> usize { - // SAFETY: The caller guarantees that pos < end <= self.len(). - let mut hole = unsafe { Hole::new(&mut self.data, pos) }; - let mut child = 2 * hole.pos() + 1; - - // Loop invariant: child == 2 * hole.pos() + 1. - while child <= end.saturating_sub(2) { - // compare with the greater of the two children - // SAFETY: child < end - 1 < self.len() and - // child + 1 < end <= self.len(), so they're valid indexes. - // child == 2 * hole.pos() + 1 != hole.pos() and - // child + 1 == 2 * hole.pos() + 2 != hole.pos(). - // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow - // if T is a ZST - child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize; - - // if we are already in order, stop. - // SAFETY: child is now either the old child or the old child+1 - // We already proven that both are < self.len() and != hole.pos() - if hole.element() >= unsafe { hole.get(child) } { - return hole.pos(); - } - - // SAFETY: same as above. - unsafe { hole.move_to(child) }; - child = 2 * hole.pos() + 1; - } - - // SAFETY: && short circuit, which means that in the - // second condition it's already true that child == end - 1 < self.len(). - if child == end - 1 && hole.element() < unsafe { hole.get(child) } { - // SAFETY: child is already proven to be a valid index and - // child == 2 * hole.pos() + 1 != hole.pos(). - unsafe { hole.move_to(child) }; - } - - hole.pos() - } - - /// # Safety - /// - /// The caller must guarantee that `pos < self.len()`. - unsafe fn sift_down(&mut self, pos: usize) -> usize { - let len = self.len(); - // SAFETY: pos < len is guaranteed by the caller and - // obviously len = self.len() <= self.len(). - unsafe { self.sift_down_range(pos, len) } - } - - /// Take an element at `pos` and move it all the way down the heap, - /// then sift it up to its position. - /// - /// Note: This is faster when the element is known to be large / should - /// be closer to the bottom. - /// - /// # Safety - /// - /// The caller must guarantee that `pos < self.len()`. - unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) { - let end = self.len(); - let start = pos; - - // SAFETY: The caller guarantees that pos < self.len(). - let mut hole = unsafe { Hole::new(&mut self.data, pos) }; - let mut child = 2 * hole.pos() + 1; - - // Loop invariant: child == 2 * hole.pos() + 1. - while child <= end.saturating_sub(2) { - // SAFETY: child < end - 1 < self.len() and - // child + 1 < end <= self.len(), so they're valid indexes. - // child == 2 * hole.pos() + 1 != hole.pos() and - // child + 1 == 2 * hole.pos() + 2 != hole.pos(). 
- // FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow - // if T is a ZST - child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize; - - // SAFETY: Same as above - unsafe { hole.move_to(child) }; - child = 2 * hole.pos() + 1; - } - - if child == end - 1 { - // SAFETY: child == end - 1 < self.len(), so it's a valid index - // and child == 2 * hole.pos() + 1 != hole.pos(). - unsafe { hole.move_to(child) }; - } - pos = hole.pos(); - drop(hole); - - // SAFETY: pos is the position in the hole and was already proven - // to be a valid index. - unsafe { self.sift_up(start, pos) }; - } - - /// Rebuild assuming data[0..start] is still a proper heap. - fn rebuild_tail(&mut self, start: usize) { - if start == self.len() { - return; - } - - let tail_len = self.len() - start; - - #[inline(always)] - fn log2_fast(x: usize) -> usize { - (usize::BITS - x.leading_zeros() - 1) as usize - } - - // `rebuild` takes O(self.len()) operations - // and about 2 * self.len() comparisons in the worst case - // while repeating `sift_up` takes O(tail_len * log(start)) operations - // and about 1 * tail_len * log_2(start) comparisons in the worst case, - // assuming start >= tail_len. For larger heaps, the crossover point - // no longer follows this reasoning and was determined empirically. - let better_to_rebuild = if start < tail_len { - true - } else if self.len() <= 2048 { - 2 * self.len() < tail_len * log2_fast(start) - } else { - 2 * self.len() < tail_len * 11 - }; - - if better_to_rebuild { - self.rebuild(); - } else { - for i in start..self.len() { - // SAFETY: The index `i` is always less than self.len(). - unsafe { self.sift_up(0, i) }; - } - } - } - - fn rebuild(&mut self) { - let mut n = self.len() / 2; - while n > 0 { - n -= 1; - // SAFETY: n starts from self.len() / 2 and goes down to 0. - // The only case when !(n < self.len()) is if - // self.len() == 0, but it's ruled out by the loop condition. - unsafe { self.sift_down(n) }; - } - } - - /// Moves all the elements of `other` into `self`, leaving `other` empty. - /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// use std::collections::BinaryHeap; - /// - /// let mut a = BinaryHeap::from([-10, 1, 2, 3, 3]); - /// let mut b = BinaryHeap::from([-20, 5, 43]); - /// - /// a.append(&mut b); - /// - /// assert_eq!(a.into_sorted_vec(), [-20, -10, 1, 2, 3, 3, 5, 43]); - /// assert!(b.is_empty()); - /// ``` - pub fn append(&mut self, other: &mut Self) { - if self.len() < other.len() { - swap(self, other); - } - - let start = self.data.len(); - - self.data.append(&mut other.data); - - self.rebuild_tail(start); - } - - /// Returns the length of the binary heap. - /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// use std::collections::BinaryHeap; - /// let heap = BinaryHeap::from([1, 3]); - /// - /// assert_eq!(heap.len(), 2); - /// ``` - #[must_use] - pub fn len(&self) -> usize { - self.data.len() - } - - /// Checks if the binary heap is empty. - /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// use std::collections::BinaryHeap; - /// let mut heap = BinaryHeap::new(); - /// - /// assert!(heap.is_empty()); - /// - /// heap.push(3); - /// heap.push(5); - /// heap.push(1); - /// - /// assert!(!heap.is_empty()); - /// ``` - #[must_use] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Clears the binary heap, returning an iterator over the removed elements - /// in arbitrary order. 
If the iterator is dropped before being fully - /// consumed, it drops the remaining elements in arbitrary order. - /// - /// The returned iterator keeps a mutable borrow on the heap to optimize - /// its implementation. - /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// use std::collections::BinaryHeap; - /// let mut heap = BinaryHeap::from([1, 3]); - /// - /// assert!(!heap.is_empty()); - /// - /// for x in heap.drain() { - /// println!("{x}"); - /// } - /// - /// assert!(heap.is_empty()); - /// ``` - #[inline] - pub fn drain(&'arena mut self) -> Drain<'arena, 'arena, T> { - Drain { - iter: self.data.drain(..), - } - } - - pub fn reserve(&mut self, additional: usize) { - self.data.reserve(additional); - } - - pub fn iter(&self) -> Iter<'_, T> { - Iter { - iter: self.data.iter(), - } - } -} - -/// Hole represents a hole in a slice i.e., an index without valid value -/// (because it was moved from or duplicated). -/// In drop, `Hole` will restore the slice by filling the hole -/// position with the value that was originally removed. -struct Hole<'a, T: 'a> { - data: &'a mut [T], - elt: ManuallyDrop, - pos: usize, -} - -impl<'a, T> Hole<'a, T> { - /// Creates a new `Hole` at index `pos`. - /// - /// Unsafe because pos must be within the data slice. - #[inline] - unsafe fn new(data: &'a mut [T], pos: usize) -> Self { - debug_assert!(pos < data.len()); - // SAFE: pos should be inside the slice - let elt = unsafe { ptr::read(data.get_unchecked(pos)) }; - Hole { - data, - elt: ManuallyDrop::new(elt), - pos, - } - } - - #[inline] - fn pos(&self) -> usize { - self.pos - } - - /// Returns a reference to the element removed. - #[inline] - fn element(&self) -> &T { - &self.elt - } - - /// Returns a reference to the element at `index`. - /// - /// Unsafe because index must be within the data slice and not equal to pos. - #[inline] - unsafe fn get(&self, index: usize) -> &T { - debug_assert!(index != self.pos); - debug_assert!(index < self.data.len()); - unsafe { self.data.get_unchecked(index) } - } - - /// Move hole to new location - /// - /// Unsafe because index must be within the data slice and not equal to pos. 
- #[inline] - unsafe fn move_to(&mut self, index: usize) { - debug_assert!(index != self.pos); - debug_assert!(index < self.data.len()); - unsafe { - let ptr = self.data.as_mut_ptr(); - let index_ptr: *const _ = ptr.add(index); - let hole_ptr = ptr.add(self.pos); - ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1); - } - self.pos = index; - } -} - -impl Drop for Hole<'_, T> { - #[inline] - fn drop(&mut self) { - // fill the hole again - unsafe { - let pos = self.pos; - ptr::copy_nonoverlapping(&*self.elt, self.data.get_unchecked_mut(pos), 1); - } - } -} - -#[derive(Debug)] -pub struct Drain<'a, 'arena, T: 'a> { - iter: bumpalo::collections::vec::Drain<'a, 'arena, T>, -} - -impl<'arena, T> Iterator for Drain<'_, 'arena, T> { - type Item = T; - - #[inline] - fn next(&mut self) -> Option { - self.iter.next() - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl<'arena, T> DoubleEndedIterator for Drain<'_, 'arena, T> { - #[inline] - fn next_back(&mut self) -> Option { - self.iter.next_back() - } -} - -impl<'arena, T> FusedIterator for Drain<'_, 'arena, T> {} - -pub struct Iter<'a, T: 'a> { - iter: slice::Iter<'a, T>, -} - -impl<'a, T> Iterator for Iter<'a, T> { - type Item = &'a T; - - #[inline] - fn next(&mut self) -> Option<&'a T> { - self.iter.next() - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - - #[inline] - fn last(self) -> Option<&'a T> { - self.iter.last() - } -} - -impl<'a, T> DoubleEndedIterator for Iter<'a, T> { - #[inline] - fn next_back(&mut self) -> Option<&'a T> { - self.iter.next_back() - } -} -impl FusedIterator for Iter<'_, T> {} - -struct RebuildOnDrop<'a, 'arena, T: Ord> { - heap: &'a mut BinaryHeap<'arena, T>, - rebuild_from: usize, -} - -impl<'arena, T: Ord> Drop for RebuildOnDrop<'_, 'arena, T> { - fn drop(&mut self) { - self.heap.rebuild_tail(self.rebuild_from); - } -} - -/// An owning iterator over the elements of a `BinaryHeap`. -/// -/// This `struct` is created by [`BinaryHeap::into_iter()`] -/// (provided by the [`IntoIterator`] trait). See its documentation for more. -/// -/// [`into_iter`]: BinaryHeap::into_iter -pub struct IntoIter<'arena, T> { - iter: bumpalo::collections::vec::IntoIter<'arena, T>, -} - -impl<'arena, T> Iterator for IntoIter<'arena, T> { - type Item = T; - - #[inline] - fn next(&mut self) -> Option { - self.iter.next() - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl<'arena, T> DoubleEndedIterator for IntoIter<'arena, T> { - #[inline] - fn next_back(&mut self) -> Option { - self.iter.next_back() - } -} - -impl FusedIterator for IntoIter<'_, T> {} - -impl<'arena, T> IntoIterator for BinaryHeap<'arena, T> { - type Item = T; - type IntoIter = IntoIter<'arena, T>; - - /// Creates a consuming iterator, that is, one that moves each value out of - /// the binary heap in arbitrary order. The binary heap cannot be used - /// after calling this. 
- /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// use std::collections::BinaryHeap; - /// let heap = BinaryHeap::from([1, 2, 3, 4]); - /// - /// // Print 1, 2, 3, 4 in arbitrary order - /// for x in heap.into_iter() { - /// // x has type i32, not &i32 - /// println!("{x}"); - /// } - /// ``` - fn into_iter(self) -> IntoIter<'arena, T> { - IntoIter { - iter: self.data.into_iter(), - } - } -} diff --git a/helix-db/src/helix_engine/vector_core/distance/cosine.rs b/helix-db/src/helix_engine/vector_core/distance/cosine.rs new file mode 100644 index 000000000..cdb1fa51f --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/distance/cosine.rs @@ -0,0 +1,65 @@ +use std::fmt; + +use bytemuck::{Pod, Zeroable}; +use serde::Serialize; + +use crate::helix_engine::vector_core::{ + distance::Distance, node::Item, spaces::simple::dot_product, unaligned_vector::UnalignedVector, +}; + +/// The Cosine similarity is a measure of similarity between two +/// non-zero vectors defined in an inner product space. Cosine similarity +/// is the cosine of the angle between the vectors. +#[derive(Debug, Serialize, Clone)] +pub enum Cosine {} + +/// The header of Cosine item nodes. +#[repr(C)] +#[derive(Pod, Serialize, Zeroable, Clone, Copy)] +pub struct NodeHeaderCosine { + norm: f32, +} +impl fmt::Debug for NodeHeaderCosine { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NodeHeaderCosine") + .field("norm", &format!("{:.4}", self.norm)) + .finish() + } +} + +impl Distance for Cosine { + type Header = NodeHeaderCosine; + type VectorCodec = f32; + + fn name() -> &'static str { + "cosine" + } + + fn new_header(vector: &UnalignedVector) -> Self::Header { + NodeHeaderCosine { + norm: Self::norm_no_header(vector), + } + } + + fn distance(p: &Item, q: &Item) -> f32 { + let pn = p.header.norm; + let qn = q.header.norm; + let pq = dot_product(&p.vector, &q.vector); + let pnqn = pn * qn; + if pnqn > f32::EPSILON { + let cos = pq / pnqn; + let cos = cos.clamp(-1.0, 1.0); + // cos is [-1; 1] + // cos = 0. -> 0.5 + // cos = -1. -> 1.0 + // cos = 1. -> 0.0 + (1.0 - cos) / 2.0 + } else { + 0.0 + } + } + + fn norm_no_header(v: &UnalignedVector) -> f32 { + dot_product(v, v).sqrt() + } +} diff --git a/helix-db/src/helix_engine/vector_core/distance/mod.rs b/helix-db/src/helix_engine/vector_core/distance/mod.rs new file mode 100644 index 000000000..6ea16cf88 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/distance/mod.rs @@ -0,0 +1,42 @@ +use core::fmt; + +use bytemuck::{Pod, Zeroable}; + +use crate::helix_engine::vector_core::{ + node::Item, + unaligned_vector::{UnalignedVector, VectorCodec}, +}; + +pub use cosine::{Cosine, NodeHeaderCosine}; + +mod cosine; + +pub type DistanceValue = f32; + +pub const MAX_DISTANCE: f64 = 2.0; +pub const ORTHOGONAL: f64 = 1.0; +pub const MIN_DISTANCE: f64 = 0.0; + +pub trait Distance: Send + Sync + Sized + Clone + fmt::Debug + 'static { + /// A header structure with informations related to the + type Header: Pod + Zeroable + fmt::Debug; + type VectorCodec: VectorCodec; + + /// The name of the distance. + /// + /// Note that the name is used to identify the distance and will help some performance improvements. + /// For example, the "cosine" distance is matched against the "binary quantized cosine" to avoid + /// recomputing links when moving from the former to the latter distance. + fn name() -> &'static str; + + fn new_header(vector: &UnalignedVector) -> Self::Header; + + /// Returns a non-normalized distance. 
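+    /// For the `Cosine` implementation above, the cosine of the angle
+    /// `cos` in `[-1, 1]` is mapped to `(1 - cos) / 2`, so identical
+    /// directions score `0.0`, orthogonal vectors `0.5`, and opposite
+    /// directions `1.0`.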
+ fn distance(p: &Item, q: &Item) -> DistanceValue; + + fn norm(item: &Item) -> f32 { + Self::norm_no_header(&item.vector) + } + + fn norm_no_header(v: &UnalignedVector) -> f32; +} diff --git a/helix-db/src/helix_engine/vector_core/hnsw.rs b/helix-db/src/helix_engine/vector_core/hnsw.rs index e110f2489..efa75cdb4 100644 --- a/helix-db/src/helix_engine/vector_core/hnsw.rs +++ b/helix-db/src/helix_engine/vector_core/hnsw.rs @@ -1,63 +1,564 @@ -use crate::helix_engine::vector_core::vector::HVector; -use crate::{helix_engine::types::VectorError, utils::properties::ImmutablePropertiesMap}; - -use heed3::{RoTxn, RwTxn}; - -pub trait HNSW { - /// Search for the k nearest neighbors of a query vector - /// - /// # Arguments - /// - /// * `txn` - The transaction to use - /// * `query` - The query vector - /// * `k` - The number of nearest neighbors to search for - /// - /// # Returns - /// - /// A vector of tuples containing the id and distance of the nearest neighbors - fn search<'db, 'arena, 'txn, F>( - &'db self, - txn: &'txn RoTxn<'db>, - query: &'arena [f64], - k: usize, - label: &'arena str, - filter: Option<&'arena [F]>, - should_trickle: bool, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> +use core::fmt; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::marker::PhantomData; +use std::{borrow::Cow, fmt::Debug}; + +use heed3::RwTxn; +use min_max_heap::MinMaxHeap; +use papaya::HashMap; +use rand::Rng; +use rand::distr::Distribution; +use rand::distr::weighted::WeightedIndex; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use roaring::RoaringBitmap; +use tinyvec::{ArrayVec, array_vec}; + +use crate::helix_engine::vector_core::node::{Item, Node}; +use crate::helix_engine::vector_core::{ + CoreDatabase, ItemId, + distance::Distance, + key::Key, + node::Links, + ordered_float::OrderedFloat, + parallel::{ImmutableItems, ImmutableLinks}, + stats::BuildStats, + writer::{BuildOption, FrozenReader}, +}; +use crate::helix_engine::vector_core::{VectorCoreResult, VectorError}; + +pub(crate) type ScoredLink = (OrderedFloat, ItemId); + +pub struct NodeState { + links: ArrayVec<[ScoredLink; M]>, +} + +impl Debug for NodeState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // from [crate::unaligned_vector] + struct Number(f32); + impl fmt::Debug for Number { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:0.3}", self.0) + } + } + let mut list = f.debug_list(); + + for &(OrderedFloat(dist), id) in &self.links { + let tup = (id, Number(dist)); + list.entry(&tup); + } + + list.finish() + } +} + +pub struct HnswBuilder { + assign_probas: Vec, + ef_construction: usize, + alpha: f32, + pub max_level: usize, + pub entry_points: Vec, + pub layers: Vec>>, + distance: PhantomData, +} + +impl HnswBuilder { + pub fn new(opts: &BuildOption) -> Self { + let assign_probas = Self::get_default_probas(); + Self { + assign_probas, + ef_construction: opts.ef_construction, + alpha: opts.alpha, + max_level: 0, + entry_points: Vec::new(), + layers: vec![], + distance: PhantomData, + } + } + + pub fn with_entry_points(mut self, entry_points: Vec) -> Self { + self.entry_points = entry_points; + self + } + + pub fn with_max_level(mut self, max_level: usize) -> Self { + self.max_level = max_level; + self + } + + // can probably even be u8's ... 
+ fn get_random_level(&mut self, rng: &mut R) -> usize where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - 'db: 'arena, - 'arena: 'txn; - - /// Insert a new vector into the index - /// - /// # Arguments - /// - /// * `txn` - The transaction to use - /// * `data` - The vector data - /// - /// # Returns - /// - /// An HVector of the data inserted - fn insert<'db, 'arena, 'txn, F>( - &'db self, - txn: &'txn mut RwTxn<'db>, - label: &'arena str, - data: &'arena [f64], - properties: Option>, - arena: &'arena bumpalo::Bump, - ) -> Result, VectorError> + R: Rng + ?Sized, + { + let dist = WeightedIndex::new(&self.assign_probas).unwrap(); + dist.sample(rng) + } + + fn get_default_probas() -> Vec { + let mut assign_probas = Vec::with_capacity(M); + let level_factor = 1.0 / (M as f32 + f32::EPSILON).ln(); + let mut level = 0; + loop { + // P(L( + &mut self, + mut to_insert: RoaringBitmap, + to_delete: &RoaringBitmap, + database: CoreDatabase, + index: u16, + wtxn: &mut RwTxn, + rng: &mut R, + ) -> VectorCoreResult> where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - 'db: 'arena, - 'arena: 'txn; - - /// Delete a vector from the index - /// - /// # Arguments - /// - /// * `txn` - The transaction to use - /// * `id` - The id of the vector - fn delete(&self, txn: &mut RwTxn, id: u128, arena: &bumpalo::Bump) -> Result<(), VectorError>; + R: Rng + ?Sized, + { + let mut build_stats = BuildStats::new(); + + let items = ImmutableItems::new(wtxn, database, index)?; + let links = ImmutableLinks::new(wtxn, database, index, database.len(wtxn)?)?; + let lmdb = FrozenReader { + index, + items: &items, + links: &links, + }; + + // Generate a random level for each point + let mut cur_max_level = usize::MIN; + let mut levels: Vec<_> = to_insert + .iter() + .map(|item_id| { + let level = self.get_random_level(rng); + cur_max_level = cur_max_level.max(level); + (item_id, level) + }) + .collect(); + + let ok_eps = + self.prepare_levels_and_entry_points(&mut levels, cur_max_level, to_delete, &lmdb)?; + to_insert |= ok_eps; + + let level_groups: Vec<_> = levels.chunk_by(|(_, la), (_, lb)| la == lb).collect(); + + // Insert layers L...0 multi-threaded + level_groups.into_iter().try_for_each(|grp| { + grp.into_par_iter().try_for_each(|&(item_id, lvl)| { + self.insert(item_id, lvl, &lmdb, &build_stats)?; + Ok(()) as Result<(), VectorError> + })?; + + build_stats.layer_dist.insert(grp[0].1, grp.len()); + + Ok(()) as Result<(), VectorError> + })?; + + self.maybe_patch_old_links(&lmdb, to_delete)?; + + // Single-threaded write to lmdb + for lvl in 0..=self.max_level { + let Some(map) = self.layers.get(lvl) else { + break; + }; + let map_guard = map.pin(); + + for (item_id, node_state) in &map_guard { + let key = Key::links(index, *item_id, lvl as u8); + let links = Links { + links: Cow::Owned(RoaringBitmap::from_iter( + node_state.links.iter().map(|(_, i)| *i), + )), + }; + + database.put(wtxn, &key, &Node::Links(links))?; + } + } + + build_stats.compute_mean_degree(wtxn, &database, index)?; + Ok(build_stats) + } + + fn prepare_levels_and_entry_points( + &mut self, + levels: &mut Vec<(u32, usize)>, + cur_max_level: usize, + to_delete: &RoaringBitmap, + lmdb: &FrozenReader, + ) -> VectorCoreResult { + let old_eps = RoaringBitmap::from_iter(self.entry_points.iter()); + let mut ok_eps = &old_eps - to_delete; + + // If any old entry points were deleted we need to replace them + for _ in (old_eps & to_delete).iter() { + let mut l = self.max_level; + loop { + for result in lmdb.links.iter_layer(l as u8) { + let ((item_id, _), 
_) = result?; + + if !to_delete.contains(item_id) && ok_eps.insert(item_id) { + break; + } + } + + // no points found in layer, continue to next one + l = match l.checked_sub(1) { + Some(new_level) => new_level, + None => break, + }; + } + } + // If the loop above added no points, we must have deleted the entire prev graph! + if ok_eps.is_empty() { + self.max_level = 0; + } + + // Schedule old entry point ids for re-indexing, otherwise we end up building a completely + // isolated sub-graph. + levels.extend(ok_eps.iter().map(|id| (id, self.max_level))); + + if cur_max_level > self.max_level { + self.entry_points.clear(); + } + + self.max_level = self.max_level.max(cur_max_level); + for _ in 0..=self.max_level { + self.layers.push(HashMap::new()); + } + + levels.sort_unstable_by(|(_, a), (_, b)| b.cmp(a)); + + let upper_layer: Vec<_> = levels + .iter() + .take_while(|(_, l)| *l == self.max_level) + .filter(|&(item_id, _)| !self.entry_points.contains(item_id)) + .collect(); + + for &(item_id, _) in upper_layer { + ok_eps.insert(item_id); + self.add_in_layers_below(item_id, self.max_level); + } + + self.entry_points = ok_eps.iter().collect(); + Ok(ok_eps) + } + + fn insert( + &self, + query: ItemId, + level: usize, + lmdb: &FrozenReader<'_, D>, + build_stats: &BuildStats, + ) -> VectorCoreResult<()> { + let mut eps = Vec::from_iter(self.entry_points.clone()); + + let q = lmdb.get_item(query)?; + + // Greedy search with: ef = 1 + for lvl in (level + 1..=self.max_level).rev() { + let neighbours = self.walk_layer(&q, &eps, lvl, 1, lmdb, build_stats)?; + let closest = neighbours + .peek_min() + .map(|(_, n)| *n) + .expect("No neighbor was found"); + eps = vec![closest]; + } + + self.add_in_layers_below(query, level); + + // Beam search with: ef = ef_construction + for lvl in (0..=level).rev() { + let neighbours = self + .walk_layer(&q, &eps, lvl, self.ef_construction, lmdb, build_stats)? + .into_vec(); + + eps.clear(); + for (dist, n) in self.robust_prune(neighbours, level, self.alpha, lmdb)? { + // add links in both directions + self.add_link(query, (dist, n), lvl, lmdb)?; + self.add_link(n, (dist, query), lvl, lmdb)?; + eps.push(n); + + build_stats.incr_link_count(2); + } + } + + Ok(()) + } + + /// During incremental updates we store a working copy of potential links to the new items. At + /// the end of indexing we need to merge the old and new links and prune ones pointing to + /// deleted items. + /// Algorithm 4 from FreshDiskANN paper. + fn maybe_patch_old_links( + &mut self, + lmdb: &FrozenReader, + to_delete: &RoaringBitmap, + ) -> VectorCoreResult<()> { + let links_in_db: Vec<_> = lmdb + .links + .iter() + .map(|result| { + result.map(|((id, lvl), v)| { + // Resize the layers if necessary. We must do this to accomodate links from + // previous builds that exist on levels larger than our current one. + if self.layers.len() <= lvl as usize { + self.layers.resize_with(lvl as usize + 1, HashMap::new); + } + ((id, lvl as usize), v.into_owned()) + }) + }) + .collect(); + + links_in_db.into_par_iter().try_for_each(|result| { + let ((id, lvl), links) = result?; + + // Since we delete links AFTER a build (we need to do this to apply diskann-approach + // for patching), links belonging to deleted items may still be present. We don't + // care about patching them. + if to_delete.contains(id) { + return Ok(()) as Result<(), VectorError>; + } + let del_subset = &links & to_delete; + + // This is safe because we resized layers above. 
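+            // Patch step, following Algorithm 4 of the FreshDiskANN paper:
+            // for every deleted neighbour `d` of this node, splice `d`'s own
+            // neighbours into the candidate set, drop candidates that are
+            // themselves deleted, then robust-prune back down to the layer's
+            // degree bound.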
+            let map_guard = self.layers[lvl].pin();
+            let mut new_links = map_guard
+                .get(&id)
+                .map(|s| s.links.to_vec())
+                .unwrap_or_default();
+
+            // No work to be done, continue
+            if del_subset.is_empty() && new_links.is_empty() {
+                return Ok(());
+            }
+
+            // Iterate through each deleted item and explore its neighbours
+            let mut bitmap = RoaringBitmap::new();
+            for item_id in del_subset.iter() {
+                bitmap.extend(lmdb.get_links(item_id, lvl)?.iter());
+            }
+            bitmap |= links;
+            bitmap -= to_delete;
+
+            // TODO: abstract this layer search and pruning bit as it's duplicated a lot in
+            // this file
+            for other in bitmap {
+                let dist = D::distance(&lmdb.get_item(id)?, &lmdb.get_item(other)?);
+                new_links.push((OrderedFloat(dist), other));
+            }
+            let pruned = self.robust_prune(new_links, lvl, self.alpha, lmdb)?;
+            let _ = map_guard.insert(
+                id,
+                NodeState {
+                    links: ArrayVec::from_iter(pruned),
+                },
+            );
+            Ok(())
+        })?;
+
+        Ok(())
+    }
+
+    /// Rather than simply insert, we make this a no-op so we can re-insert the same item
+    /// without overwriting its links in memory. This is useful in cases like a Vamana build.
+    fn add_in_layers_below(&self, item_id: ItemId, level: usize) {
+        for level in 0..=level {
+            let Some(map) = self.layers.get(level) else {
+                break;
+            };
+            map.pin().get_or_insert(
+                item_id,
+                NodeState {
+                    links: array_vec![],
+                },
+            );
+        }
+    }
+
+    /// Returns only the IDs of our neighbours. Always checks LMDB first.
+    fn get_neighbours(
+        &self,
+        lmdb: &FrozenReader<'_, D>,
+        item_id: ItemId,
+        level: usize,
+        build_stats: &BuildStats,
+    ) -> VectorCoreResult<Vec<ItemId>> {
+        let mut res = Vec::new();
+
+        // O(1) from the FrozenReader
+        if let Ok(Links { links }) = lmdb.get_links(item_id, level) {
+            build_stats.incr_lmdb_hits();
+            res.extend(links.iter());
+        }
+
+        // O(1) from self.layers
+        let Some(map) = self.layers.get(level) else {
+            return Ok(res);
+        };
+        match map.pin().get(&item_id) {
+            Some(node_state) => res.extend(node_state.links.iter().map(|(_, i)| *i)),
+            None => {
+                if res.is_empty() {
+                    build_stats.incr_link_misses();
+                }
+            }
+        }
+
+        Ok(res)
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    fn walk_layer(
+        &self,
+        query: &Item<D>,
+        eps: &[ItemId],
+        level: usize,
+        ef: usize,
+        lmdb: &FrozenReader<'_, D>,
+        build_stats: &BuildStats,
+    ) -> VectorCoreResult<MinMaxHeap<ScoredLink>> {
+        let mut candidates = BinaryHeap::new();
+        let mut res = MinMaxHeap::with_capacity(ef);
+        let mut visited = RoaringBitmap::new();
+
+        // Register all entry points as visited and populate candidates
+        for &ep in eps {
+            let ve = lmdb.get_item(ep)?;
+            let dist = D::distance(query, &ve);
+
+            candidates.push((Reverse(OrderedFloat(dist)), ep));
+            res.push((OrderedFloat(dist), ep));
+            visited.insert(ep);
+        }
+
+        while let Some(&(Reverse(OrderedFloat(f)), _)) = candidates.peek() {
+            let &(OrderedFloat(f_max), _) = res.peek_max().unwrap();
+            if f > f_max {
+                break;
+            }
+            let (_, c) = candidates.pop().unwrap(); // Now safe to pop
+
+            // Get the neighbourhood of the candidate either from self or LMDB
+            let proximity = self.get_neighbours(lmdb, c, level, build_stats)?;
+            for point in proximity {
+                if !visited.insert(point) {
+                    continue;
+                }
+                // If the item isn't in the FrozenReader it must have been deleted from the
+                // index, in which case it's OK not to explore it
+                let item = match lmdb.get_item(point) {
+                    Ok(item) => item,
+                    Err(VectorError::MissingKey { ..
}) => continue, + Err(e) => return Err(e), + }; + let dist = D::distance(query, &item); + + if res.len() < ef || dist < f_max { + candidates.push((Reverse(OrderedFloat(dist)), point)); + + if res.len() == ef { + let _ = res.push_pop_max((OrderedFloat(dist), point)); + } else { + res.push((OrderedFloat(dist), point)); + } + } + } + } + + Ok(res) + } + + /// Tries to add a new link between nodes in a single direction. + // TODO: prevent duplicate links the other way. I think this arises ONLY for entrypoints since + // we pre-emptively add them in each layer before + fn add_link( + &self, + p: ItemId, + q: ScoredLink, + level: usize, + lmdb: &FrozenReader<'_, D>, + ) -> VectorCoreResult<()> { + if p == q.1 { + return Ok(()); + } + + let Some(map) = self.layers.get(level) else { + return Ok(()); + }; + let map_guard = map.pin(); + + // 'pure' links update function + let _add_link = |node_state: &NodeState| { + let mut links = node_state.links; + let cap = if level == 0 { M0 } else { M }; + + if links.len() < cap { + links.push(q); + return NodeState { links }; + } + + let new_links = self + .robust_prune(links.to_vec(), level, self.alpha, lmdb) + .map(ArrayVec::from_iter) + .unwrap_or_else(|_| node_state.links); + + NodeState { links: new_links } + }; + + map_guard.update_or_insert_with(p, _add_link, || NodeState { + links: array_vec!([ScoredLink; M0] => q), + }); + + Ok(()) + } + + /// Naively choosing the nearest neighbours performs poorly on clustered data since we can never + /// escape our local neighbourhood. "Sparse Neighbourhood Graph" (SNG) condition sufficient for + /// quick convergence. + fn robust_prune( + &self, + mut candidates: Vec, + level: usize, + alpha: f32, + lmdb: &FrozenReader<'_, D>, + ) -> VectorCoreResult> { + let cap = if level == 0 { M0 } else { M }; + candidates.sort_by(|a, b| b.cmp(a)); + let mut selected: Vec = Vec::with_capacity(cap); + + while let Some((dist_to_query, c)) = candidates.pop() { + if selected.len() == cap { + break; + } + + // ensure we're closer to the query than we are to other candidates + let mut ok_to_add = true; + for i in selected.iter().map(|(_, i)| *i) { + let d = D::distance(&lmdb.get_item(c)?, &lmdb.get_item(i)?); + if OrderedFloat(d * alpha) < dist_to_query { + ok_to_add = false; + break; + } + } + + if ok_to_add { + selected.push((dist_to_query, c)); + } + } + + Ok(selected) + } } diff --git a/helix-db/src/helix_engine/vector_core/item_iter.rs b/helix-db/src/helix_engine/vector_core/item_iter.rs new file mode 100644 index 000000000..e6fccebe0 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/item_iter.rs @@ -0,0 +1,56 @@ +use heed3::RoTxn; + +use crate::helix_engine::vector_core::{ + CoreDatabase, ItemId, LmdbResult, + distance::Distance, + key::{KeyCodec, Prefix, PrefixCodec}, + node::{Item, Node, NodeCodec}, +}; + +// used by the reader +pub struct ItemIter<'t, D: Distance> { + pub inner: heed3::RoPrefix<'t, KeyCodec, NodeCodec>, + dimensions: usize, + arena: &'t bumpalo::Bump, +} + +impl<'t, D: Distance> ItemIter<'t, D> { + pub fn new( + database: CoreDatabase, + index: u16, + dimensions: usize, + rtxn: &'t RoTxn, + arena: &'t bumpalo::Bump, + ) -> heed3::Result { + Ok(ItemIter { + inner: database + .remap_key_type::() + .prefix_iter(rtxn, &Prefix::item(index))? 
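+                // Keys are encoded big-endian as (index, mode, item, layer),
+                // so this prefix scan over (index, NodeMode::Item) visits
+                // exactly the item nodes of this index, in ascending ItemId
+                // order.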
+ .remap_key_type::(), + dimensions, + arena, + }) + } +} + +impl<'t, D: Distance> Iterator for ItemIter<'t, D> { + type Item = LmdbResult<(ItemId, bumpalo::collections::Vec<'t, f32>)>; + + fn next(&mut self) -> Option { + match self.inner.next() { + Some(Ok((key, node))) => match node { + Node::Item(Item { header: _, vector }) => { + let mut vector = vector.to_vec(&self.arena); + if vector.len() != self.dimensions { + // quantized codecs pad to 8-bytes so we truncate to recover len + vector.truncate(self.dimensions); + } + Some(Ok((key.node.item, vector))) + } + Node::Links(_) => unreachable!("Node must not be a link"), + }, + Some(Err(e)) => Some(Err(e.into())), + None => None, + } + } +} diff --git a/helix-db/src/helix_engine/vector_core/key.rs b/helix-db/src/helix_engine/vector_core/key.rs new file mode 100644 index 000000000..8d32ff318 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/key.rs @@ -0,0 +1,174 @@ +use std::borrow::Cow; +use std::mem::size_of; + +use byteorder::{BigEndian, ByteOrder}; +use heed3::BoxedError; + +use crate::helix_engine::vector_core::node_id::{NodeId, NodeMode}; + +/// This whole structure must fit in an u64 so we can tell LMDB to optimize its storage. +/// The `index` is specified by the user and is used to differentiate between multiple hannoy indexes. +/// The `mode` indicates what we're looking at. +/// The `item` point to a specific node. +/// If the mode is: +/// - `Item`: we're looking at an `Item` node. +/// - `Links`: we're looking at the `Links` bitmap of neighbours for a node +/// - `Updated`: The list of items that has been updated since the last build of the database. +/// - `Metadata`: There is only one item at `0` that contains the header required to read the index. +#[derive(Debug, Copy, Clone)] +pub struct Key { + /// The prefix specified by the user. + pub index: u16, + pub node: NodeId, +} + +impl Key { + pub const fn new(index: u16, node: NodeId) -> Self { + Self { index, node } + } + + pub const fn metadata(index: u16) -> Self { + Self::new(index, NodeId::metadata()) + } + + pub const fn version(index: u16) -> Self { + Self::new(index, NodeId::version()) + } + + pub const fn updated(index: u16, item: u32) -> Self { + Self::new(index, NodeId::updated(item)) + } + + pub const fn item(index: u16, item: u32) -> Self { + Self::new(index, NodeId::item(item)) + } + + pub const fn links(index: u16, item: u32, layer: u8) -> Self { + Self::new(index, NodeId::links(item, layer)) + } +} + +/// The heed codec used internally to encode/decoding the internal key type. +pub enum KeyCodec {} + +impl<'a> heed3::BytesEncode<'a> for KeyCodec { + type EItem = Key; + + fn bytes_encode(item: &'a Self::EItem) -> Result, BoxedError> { + let mut output = Vec::with_capacity(size_of::()); + output.extend_from_slice(&item.index.to_be_bytes()); + output.extend_from_slice(&(item.node.mode as u8).to_be_bytes()); + output.extend_from_slice(&item.node.item.to_be_bytes()); + output.extend_from_slice(&(item.node.layer).to_be_bytes()); + + Ok(Cow::Owned(output)) + } +} + +impl heed3::BytesDecode<'_> for KeyCodec { + type DItem = Key; + + fn bytes_decode(bytes: &[u8]) -> Result { + let prefix = BigEndian::read_u16(bytes); + let bytes = &bytes[size_of::()..]; + let mode = bytes[0].try_into()?; + let bytes = &bytes[size_of::()..]; + let item = BigEndian::read_u32(bytes); + let bytes = &bytes[size_of::()..]; + let layer = bytes[0]; + + Ok(Key { + index: prefix, + node: NodeId { mode, item, layer }, + }) + } +} + +/// This is used to query part of a key. 
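+/// Because `KeyCodec` writes the index first and the mode second, a 2-byte
+/// prefix (index only, see `Prefix::all`) or a 3-byte prefix (index + mode)
+/// always selects a contiguous range of keys in LMDB.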
+#[derive(Debug, Copy, Clone)] +pub struct Prefix { + /// The index specified by the user. + index: u16, + // Indicate what the item represent. + mode: Option, +} + +impl Prefix { + pub const fn all(index: u16) -> Self { + Self { index, mode: None } + } + + pub const fn item(index: u16) -> Self { + Self { + index, + mode: Some(NodeMode::Item), + } + } + + pub const fn links(index: u16) -> Self { + Self { + index, + mode: Some(NodeMode::Links), + } + } + + pub const fn updated(index: u16) -> Self { + Self { + index, + mode: Some(NodeMode::Updated), + } + } +} + +pub enum PrefixCodec {} + +impl<'a> heed3::BytesEncode<'a> for PrefixCodec { + type EItem = Prefix; + + fn bytes_encode(item: &'a Self::EItem) -> Result, BoxedError> { + let mode_used = item.mode.is_some() as usize; + let mut output = Vec::with_capacity(size_of::() + mode_used); + + output.extend_from_slice(&item.index.to_be_bytes()); + if let Some(mode) = item.mode { + output.extend_from_slice(&(mode as u8).to_be_bytes()); + } + + Ok(Cow::Owned(output)) + } +} + +#[cfg(test)] +mod test { + use heed3::{BytesDecode, BytesEncode}; + + use super::*; + + #[test] + fn check_size_of_types() { + let key = Key::metadata(0); + let encoded = KeyCodec::bytes_encode(&key).unwrap(); + assert_eq!(encoded.len(), size_of::()); + } + + // TODO: fuzz this + #[test] + fn test_links_key() { + let key = Key::links(0, 1, 42); + let bytes = KeyCodec::bytes_encode(&key).unwrap(); + let key2 = KeyCodec::bytes_decode(&bytes).unwrap(); + assert_eq!(key.node.item, key2.node.item); + assert_eq!(key.node.layer, key2.node.layer); + assert_eq!(key.node.mode, key2.node.mode); + } + + #[test] + fn test_item_key() { + let key = Key::item(0, 42); + let bytes = KeyCodec::bytes_encode(&key).unwrap(); + let key2 = KeyCodec::bytes_decode(&bytes).unwrap(); + assert_eq!(key.node.item, key2.node.item); + assert_eq!(key.node.layer, key2.node.layer); + assert_eq!(key.node.mode, key2.node.mode); + } +} diff --git a/helix-db/src/helix_engine/vector_core/metadata.rs b/helix-db/src/helix_engine/vector_core/metadata.rs new file mode 100644 index 000000000..a0c21645a --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/metadata.rs @@ -0,0 +1,75 @@ +use std::{borrow::Cow, ffi::CStr}; + +use byteorder::{BigEndian, ByteOrder}; +use heed3::BoxedError; +use roaring::RoaringBitmap; + +use crate::helix_engine::vector_core::node::ItemIds; + +#[derive(Debug)] +pub struct Metadata<'a> { + pub dimensions: u32, + pub items: RoaringBitmap, + pub distance: &'a str, + pub entry_points: ItemIds<'a>, + pub max_level: u8, +} + +pub enum MetadataCodec {} + +impl<'a> heed3::BytesEncode<'a> for MetadataCodec { + type EItem = Metadata<'a>; + + fn bytes_encode(item: &'a Self::EItem) -> Result, BoxedError> { + let Metadata { + dimensions, + items, + entry_points, + distance, + max_level, + } = item; + debug_assert!(!distance.as_bytes().iter().any(|&b| b == 0)); + + let mut output = Vec::with_capacity( + size_of::() + + items.serialized_size() + + entry_points.len() * size_of::() + + distance.len() + + 1, + ); + output.extend_from_slice(distance.as_bytes()); + output.push(0); + output.extend_from_slice(&dimensions.to_be_bytes()); + output.extend_from_slice(&(items.serialized_size() as u32).to_be_bytes()); + items.serialize_into(&mut output)?; + output.extend_from_slice(entry_points.raw_bytes()); + output.push(*max_level); + + Ok(Cow::Owned(output)) + } +} + +impl<'a> heed3::BytesDecode<'a> for MetadataCodec { + type DItem = Metadata<'a>; + + fn bytes_decode(bytes: &'a [u8]) -> Result { + let distance = 
CStr::from_bytes_until_nul(bytes)?.to_str()?; + let bytes = &bytes[distance.len() + 1..]; + let dimensions = BigEndian::read_u32(bytes); + let bytes = &bytes[size_of::()..]; + let items_size = BigEndian::read_u32(bytes) as usize; + let bytes = &bytes[size_of::()..]; + let items = RoaringBitmap::deserialize_from(&bytes[..items_size])?; + let bytes = &bytes[items_size..]; + let entry_points = ItemIds::from_bytes(&bytes[..bytes.len() - 1]); + let max_level = bytes[bytes.len() - 1]; + + Ok(Metadata { + dimensions, + items, + distance, + entry_points, + max_level, + }) + } +} diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 279803d89..fa6dc31b4 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -1,7 +1,323 @@ -pub mod binary_heap; +use std::{borrow::Cow, cmp::Ordering}; + +use bincode::Options; +use byteorder::BE; +use hashbrown::HashMap; +use heed3::{ + Database, Env, Error as LmdbError, RoTxn, RwTxn, + types::{Bytes, U128}, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + helix_engine::{ + types::VectorError, + vector_core::{ + distance::{Cosine, Distance, DistanceValue}, + key::{Key, KeyCodec}, + node::{Item, NodeCodec}, + node_id::NodeMode, + reader::Reader, + unaligned_vector::UnalignedVector, + writer::Writer, + }, + }, + protocol::{ + custom_serde::vector_serde::{VectoWithoutDataDeSeed, VectorDeSeed}, + value::Value, + }, + utils::{id::v6_uuid, properties::ImmutablePropertiesMap}, +}; + +pub mod distance; pub mod hnsw; -pub mod utils; -pub mod vector; -pub mod vector_core; -pub mod vector_distance; -pub mod vector_without_data; +pub mod item_iter; +pub mod key; +pub mod metadata; +pub mod node; +pub mod node_id; +pub mod ordered_float; +pub mod parallel; +pub mod reader; +pub mod spaces; +pub mod stats; +pub mod unaligned_vector; +pub mod version; +pub mod writer; + +pub type ItemId = u32; + +pub type LayerId = u8; + +pub type VectorCoreResult = std::result::Result; + +pub type LmdbResult = std::result::Result; + +pub type CoreDatabase = heed3::Database>; + +#[derive(Debug, Serialize, Clone)] +pub struct HVector<'arena> { + pub id: u128, + pub distance: Option, + pub label: &'arena str, + pub deleted: bool, + pub version: u8, + pub level: usize, + pub properties: Option>, + pub data: Option>, +} + +impl<'arena> HVector<'arena> { + // FIXME: this allocates twice + pub fn data(&self, arena: &'arena bumpalo::Bump) -> &'arena [f64] { + let vec_f32 = self.data.as_ref().unwrap().vector.as_ref().to_vec(arena); + + arena.alloc_slice_fill_iter(vec_f32.iter().map(|&x| x as f64)) + } + + pub fn data_borrowed(&self) -> &[f64] { + bytemuck::cast_slice(self.data.as_ref().unwrap().vector.as_ref().as_bytes()) + } + + pub fn from_slice( + label: &'arena str, + level: usize, + data: &'arena [f64], + arena: &'arena bumpalo::Bump, + ) -> Self { + let id = v6_uuid(); + HVector { + id, + version: 1, + level, + label, + data: Some(Item::::from(data, arena)), + distance: None, + properties: None, + deleted: false, + } + } + + pub fn score(&self) -> f64 { + self.distance.unwrap_or(2.0) + } + + /// Converts HVector's data to a vec of bytes by accessing the data field directly + /// and converting each f64 to a byte slice + #[inline(always)] + pub fn vector_data_to_bytes(&self) -> VectorCoreResult<&[u8]> { + Ok(self.data.as_ref().unwrap().vector.as_ref().as_bytes()) + } + + /// Deserializes bytes into an vector using a custom deserializer that allocates into the provided arena + /// 
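+    /// When `get_data` is `false`, only the id and properties are decoded
+    /// (via `VectoWithoutDataDeSeed`) and the vector payload is skipped.
+    ///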
+    /// Both the properties bytes (if present) and the raw vector data are combined to generate the final vector struct
+    ///
+    /// NOTE: in this method, fixint encoding is used
+    #[inline]
+    pub fn from_bincode_bytes<'txn>(
+        arena: &'arena bumpalo::Bump,
+        properties: Option<&'txn [u8]>,
+        raw_vector_data: &'txn [u8],
+        id: u128,
+        get_data: bool,
+    ) -> Result<Self, VectorError> {
+        if get_data {
+            bincode::options()
+                .with_fixint_encoding()
+                .allow_trailing_bytes()
+                .deserialize_seed(
+                    VectorDeSeed {
+                        arena,
+                        id,
+                        raw_vector_data,
+                    },
+                    properties.unwrap_or(&[]),
+                )
+                .map_err(|e| {
+                    VectorError::ConversionError(format!("Error deserializing vector: {e}"))
+                })
+        } else {
+            bincode::options()
+                .with_fixint_encoding()
+                .allow_trailing_bytes()
+                .deserialize_seed(
+                    VectoWithoutDataDeSeed { arena, id },
+                    properties.unwrap_or(&[]),
+                )
+                .map_err(|e| {
+                    VectorError::ConversionError(format!("Error deserializing vector: {e}"))
+                })
+        }
+    }
+
+    #[inline(always)]
+    pub fn to_bincode_bytes(&self) -> Result<Vec<u8>, bincode::Error> {
+        bincode::serialize(self)
+    }
+
+    pub fn distance_to(&self, rhs: &HVector<'arena>) -> VectorCoreResult<f64> {
+        todo!()
+    }
+
+    pub fn set_distance(&mut self, distance: f64) {
+        self.distance = Some(distance);
+    }
+
+    pub fn get_distance(&self) -> f64 {
+        self.distance.unwrap()
+    }
+
+    pub fn len(&self) -> usize {
+        self.data.as_ref().unwrap().vector.len()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.data.as_ref().unwrap().vector.is_empty()
+    }
+
+    #[inline(always)]
+    pub fn get_property(&self, key: &str) -> Option<&'arena Value> {
+        self.properties.as_ref().and_then(|value| value.get(key))
+    }
+
+    pub fn cast_raw_vector_data<'txn>(
+        arena: &'arena bumpalo::Bump,
+        raw_vector_data: &'txn [u8],
+    ) -> &'txn [f64] {
+        todo!()
+    }
+
+    pub fn from_raw_vector_data<'txn>(
+        arena: &'arena bumpalo::Bump,
+        raw_vector_data: &'txn [u8],
+        label: &'arena str,
+        id: u128,
+    ) -> Result<Self, VectorError> {
+        todo!()
+    }
+}
+
+impl PartialEq for HVector<'_> {
+    fn eq(&self, other: &Self) -> bool {
+        self.id == other.id
+    }
+}
+impl Eq for HVector<'_> {}
+impl PartialOrd for HVector<'_> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl Ord for HVector<'_> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        other
+            .distance
+            .partial_cmp(&self.distance)
+            .unwrap_or(Ordering::Equal)
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HNSWConfig {
+    pub m: usize,             // max num of bi-directional links per element
+    pub m_max_0: usize,       // max num of links for lower layers
+    pub ef_construct: usize,  // size of the dynamic candidate list for construction
+    pub m_l: f64,             // level generation factor
+    pub ef: usize,            // search param, num of cands to search
+    pub min_neighbors: usize, // for get_neighbors, always 512
+}
+
+impl HNSWConfig {
+    /// Constructor for the config of the HNSW vector similarity search algorithm
+    /// - m (5 <= m <= 48): max num of bi-directional links per element
+    /// - m_max_0 (2 * m): max num of links for level 0 (level that stores all vecs)
+    /// - ef_construct (40 <= ef_construct <= 512): size of the dynamic candidate list
+    ///   for construction
+    /// - m_l (1/ln(m)): level generation factor (multiplied by a random number)
+    /// - ef (10 <= ef <= 512): num of candidates to search
+    pub fn new(m: Option<usize>, ef_construct: Option<usize>, ef: Option<usize>) -> Self {
+        let m = m.unwrap_or(16).clamp(5, 48);
+        let ef_construct = ef_construct.unwrap_or(128).clamp(40, 512);
+        let ef = ef.unwrap_or(768).clamp(10, 512);
+
+        Self {
+            m,
+            m_max_0: 2 * m,
ef_construct, + m_l: 1.0 / (m as f64).ln(), + ef, + min_neighbors: 512, + } + } +} + +pub struct VectorCoreStats { + // Do it atomical? + pub num_vectors: usize, +} + +// TODO: Properties filters +// TODO: Support different distances for each database +pub struct VectorCore { + /// One HNSW index per label + hsnw_index: HashMap>, + pub stats: VectorCoreStats, + pub vector_properties_db: Database, Bytes>, +} + +impl VectorCore { + pub fn new(env: &Env, txn: &mut RwTxn, config: HNSWConfig) -> VectorCoreResult { + todo!() + } + pub fn search_by_vector<'a>(&self, txn: &RoTxn, vector: &'a [f32]) {} + + pub fn search<'arena>( + &self, + txn: &RoTxn, + query: &'arena [f64], + k: usize, + label: &'arena str, + should_trickle: bool, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult>> { + todo!() + } + + pub fn insert<'arena>( + &self, + txn: &mut RwTxn, + label: &'arena str, + data: &'arena [f64], + properties: Option>, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult> { + todo!() + } + + pub fn delete(&self, txn: &RwTxn, id: u128, arena: &bumpalo::Bump) -> VectorCoreResult<()> { + Ok(()) + } + + pub fn get_full_vector<'arena>( + &self, + txn: &RoTxn, + id: u128, + arena: &bumpalo::Bump, + ) -> VectorCoreResult> { + todo!() + } + + pub fn get_vector_properties<'arena>( + &self, + txn: &RoTxn, + id: u128, + arena: &bumpalo::Bump, + ) -> VectorCoreResult>> { + todo!() + } + + pub fn num_inserted_vectors(&self) -> usize { + self.stats.num_vectors + } +} diff --git a/helix-db/src/helix_engine/vector_core/node.rs b/helix-db/src/helix_engine/vector_core/node.rs new file mode 100644 index 000000000..b2a7948c5 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/node.rs @@ -0,0 +1,286 @@ +use core::fmt; +use std::{borrow::Cow, ops::Deref}; + +use bumpalo::collections::CollectIn; +use bytemuck::{bytes_of, cast_slice, pod_read_unaligned}; +use byteorder::{ByteOrder, NativeEndian}; +use heed3::{BoxedError, BytesDecode, BytesEncode}; +use roaring::RoaringBitmap; +use serde::Serialize; + +use crate::helix_engine::vector_core::{ + ItemId, distance::Distance, unaligned_vector::UnalignedVector, +}; + +#[derive(Clone, Debug)] +pub enum Node<'a, D: Distance> { + Item(Item<'a, D>), + Links(Links<'a>), +} + +const NODE_TAG: u8 = 0; +const LINKS_TAG: u8 = 1; + +impl<'a, D: Distance> Node<'a, D> { + pub fn item(self) -> Option> { + if let Node::Item(item) = self { + Some(item) + } else { + None + } + } + + pub fn links(self) -> Option> { + if let Node::Links(links) = self { + Some(links) + } else { + None + } + } +} + +/// An item node which corresponds to the vector inputed +/// by the user and the distance header. +#[derive(Serialize)] +pub struct Item<'a, D: Distance> { + /// The header of this item. + pub header: D::Header, + /// The vector of this item. + pub vector: Cow<'a, UnalignedVector>, +} + +impl fmt::Debug for Item<'_, D> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Item") + .field("header", &self.header) + .field("vector", &self.vector) + .finish() + } +} + +impl Clone for Item<'_, D> { + fn clone(&self) -> Self { + Self { + header: self.header, + vector: self.vector.clone(), + } + } +} + +impl Item<'_, D> { + /// Converts the item into an owned version of itself by cloning + /// the internal vector. Doing so will make it mutable. + pub fn into_owned(self) -> Item<'static, D> { + Item { + header: self.header, + vector: Cow::Owned(self.vector.into_owned()), + } + } + + /// Builds a new item from a `Vec`. 
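+    ///
+    /// A minimal sketch, assuming a `bumpalo` arena and the `Cosine` distance:
+    ///
+    /// ```ignore
+    /// let arena = bumpalo::Bump::new();
+    /// let item = Item::<Cosine>::new(bumpalo::vec![in &arena; 1.0f32, 2.0]);
+    /// assert_eq!(item.vector.len(), 2);
+    /// ```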
+ pub fn new(vec: bumpalo::collections::Vec) -> Self { + let vector = UnalignedVector::from_vec(vec); + let header = D::new_header(&vector); + Self { header, vector } + } + + pub fn from<'arena>(vec: &[f64], arena: &'arena bumpalo::Bump) -> Self { + Self::new(vec.into_iter().map(|x| *x as f32).collect_in(arena)) + } +} + +#[derive(Clone, Debug)] +pub struct Links<'a> { + pub links: Cow<'a, RoaringBitmap>, +} + +impl<'a> Deref for Links<'a> { + type Target = Cow<'a, RoaringBitmap>; + fn deref(&self) -> &Self::Target { + &self.links + } +} + +#[derive(Clone)] +pub struct ItemIds<'a> { + bytes: &'a [u8], +} + +impl<'a> ItemIds<'a> { + pub fn from_slice(slice: &[u32]) -> ItemIds<'_> { + ItemIds::from_bytes(cast_slice(slice)) + } + + pub fn from_bytes(bytes: &[u8]) -> ItemIds<'_> { + ItemIds { bytes } + } + + pub fn raw_bytes(&self) -> &[u8] { + self.bytes + } + + pub fn len(&self) -> usize { + self.bytes.len() / size_of::() + } + + pub fn iter(&self) -> impl Iterator + 'a { + self.bytes + .chunks_exact(size_of::()) + .map(NativeEndian::read_u32) + } +} + +impl fmt::Debug for ItemIds<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut list = f.debug_list(); + self.iter().for_each(|integer| { + list.entry(&integer); + }); + list.finish() + } +} + +/// The codec used internally to encode and decode nodes. +pub struct NodeCodec(D); + +impl<'a, D: Distance> BytesEncode<'a> for NodeCodec { + type EItem = Node<'a, D>; + + fn bytes_encode(item: &Self::EItem) -> Result, BoxedError> { + let mut bytes = Vec::new(); + match item { + Node::Item(Item { header, vector }) => { + bytes.push(NODE_TAG); + bytes.extend_from_slice(bytes_of(header)); + bytes.extend(vector.as_bytes()); + } + Node::Links(Links { links }) => { + bytes.push(LINKS_TAG); + links.serialize_into(&mut bytes)?; + } + } + Ok(Cow::Owned(bytes)) + } +} + +impl<'a, D: Distance> BytesDecode<'a> for NodeCodec { + type DItem = Node<'a, D>; + + fn bytes_decode(bytes: &'a [u8]) -> Result { + match bytes { + [NODE_TAG, bytes @ ..] => { + let (header_bytes, remaining) = bytes.split_at(size_of::()); + let header = pod_read_unaligned(header_bytes); + let vector = UnalignedVector::::from_bytes(remaining)?; + + Ok(Node::Item(Item { header, vector })) + } + [LINKS_TAG, bytes @ ..] => { + let links: Cow<'_, RoaringBitmap> = + Cow::Owned(RoaringBitmap::deserialize_from(bytes).unwrap()); + Ok(Node::Links(Links { links })) + } + + [unknown_tag, ..] 
=> Err(Box::new(InvalidNodeDecoding { + unknown_tag: Some(*unknown_tag), + })), + [] => Err(Box::new(InvalidNodeDecoding { unknown_tag: None })), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub struct InvalidNodeDecoding { + unknown_tag: Option, +} + +impl fmt::Display for InvalidNodeDecoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.unknown_tag { + Some(unknown_tag) => write!(f, "Invalid node decoding: unknown tag {unknown_tag}"), + None => write!(f, "Invalid node decoding: empty array of bytes"), + } + } +} + +#[cfg(test)] +mod tests { + use crate::helix_engine::vector_core::{ + distance::{Cosine, Distance}, + unaligned_vector::UnalignedVector, + }; + + use super::{Item, Links, Node, NodeCodec}; + use bumpalo::Bump; + use heed3::{BytesDecode, BytesEncode}; + use roaring::RoaringBitmap; + use std::borrow::Cow; + + #[test] + fn check_bytes_encode_decode() { + type D = Cosine; + + let b = Bump::new(); + let vector = UnalignedVector::from_vec(bumpalo::vec![in &b; 1.0, 2.0]); + let header = D::new_header(&vector); + let item = Item { vector, header }; + let db_item = Node::Item(item); + + let bytes = NodeCodec::::bytes_encode(&db_item); + assert!(bytes.is_ok()); + let bytes = bytes.unwrap(); + dbg!("{}, {}", std::mem::size_of_val(&db_item), bytes.len()); + // dbg!("{:?}", &bytes); + + let db_item2 = NodeCodec::::bytes_decode(bytes.as_ref()); + assert!(db_item2.is_ok()); + let db_item2 = db_item2.unwrap(); + + dbg!("{:?}", &db_item2); + dbg!("{:?}", &db_item); + } + + #[test] + fn test_codec() { + type D = Cosine; + + let b = Bump::new(); + let vector = UnalignedVector::from_vec(bumpalo::vec![in &b; 1.0, 2.0]); + let header = D::new_header(&vector); + let item = Item { vector, header }; + let db_item = Node::Item(item.clone()); + + let bytes = NodeCodec::::bytes_encode(&db_item); + assert!(bytes.is_ok()); + let bytes = bytes.unwrap(); + + let new_item = NodeCodec::::bytes_decode(bytes.as_ref()); + assert!(new_item.is_ok()); + let new_item = new_item.unwrap().item().unwrap(); + + assert!(matches!(new_item.vector, Cow::Borrowed(_))); + assert_eq!(new_item.vector.as_bytes(), item.vector.as_bytes()); + } + + #[test] + fn test_bitmap_codec() { + let mut bitmap = RoaringBitmap::new(); + bitmap.insert(1); + bitmap.insert(42); + + let links = Links { + links: Cow::Owned(bitmap), + }; + let db_item = Node::Links(links); + let bytes = NodeCodec::::bytes_encode(&db_item).unwrap(); + + let node = NodeCodec::::bytes_decode(&bytes).unwrap(); + assert!(matches!(node, Node::Links(_))); + let new_links = match node { + Node::Links(links) => links, + _ => unreachable!(), + }; + assert!(new_links.links.contains(1)); + assert!(new_links.links.contains(42)); + } +} diff --git a/helix-db/src/helix_engine/vector_core/node_id.rs b/helix-db/src/helix_engine/vector_core/node_id.rs new file mode 100644 index 000000000..024b8caf4 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/node_id.rs @@ -0,0 +1,160 @@ +use core::fmt; + +use byteorder::{BigEndian, ByteOrder}; + +use crate::helix_engine::vector_core::{ItemId, LayerId}; + +/// /!\ Changing the value of the enum can be DB-breaking /!\ +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum NodeMode { + /// Stores the metadata under the `ItemId` 0 + Metadata = 0, + /// Stores the list of all the `ItemId` that have been updated. + /// We only stores `Unit` values under the keys. 
+    Updated = 1,
+    /// The graph edges are stored under this id.
+    Links = 2,
+    /// The original vectors are stored under this id in `Item` structures.
+    Item = 3,
+}
+
+impl TryFrom<u8> for NodeMode {
+    type Error = String;
+
+    fn try_from(v: u8) -> std::result::Result<Self, Self::Error> {
+        match v {
+            v if v == NodeMode::Item as u8 => Ok(NodeMode::Item),
+            v if v == NodeMode::Links as u8 => Ok(NodeMode::Links),
+            v if v == NodeMode::Updated as u8 => Ok(NodeMode::Updated),
+            v if v == NodeMode::Metadata as u8 => Ok(NodeMode::Metadata),
+            v => Err(format!("Could not convert {v} into a `NodeMode`.")),
+        }
+    }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct NodeId {
+    /// Indicates what the item represents.
+    pub mode: NodeMode,
+    /// The item we want to get.
+    pub item: ItemId,
+    /// The HNSW layer ID is stored after the ItemId for co-locality of
+    /// (vector, its links) in LMDB (?).
+    /// Safe to store in a u8 since the graph will never have more than 256 layers.
+    pub layer: LayerId,
+}
+
+impl fmt::Debug for NodeId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:?}({},{})", self.mode, self.item, self.layer)
+    }
+}
+
+impl NodeId {
+    pub const fn metadata() -> Self {
+        Self {
+            mode: NodeMode::Metadata,
+            item: 0,
+            layer: 0,
+        }
+    }
+
+    pub const fn version() -> Self {
+        Self {
+            mode: NodeMode::Metadata,
+            item: 1,
+            layer: 0,
+        }
+    }
+
+    pub const fn updated(item: u32) -> Self {
+        Self {
+            mode: NodeMode::Updated,
+            item,
+            layer: 0,
+        }
+    }
+
+    pub const fn links(item: u32, layer: u8) -> Self {
+        Self {
+            mode: NodeMode::Links,
+            item,
+            layer,
+        }
+    }
+
+    pub const fn item(item: u32) -> Self {
+        Self {
+            mode: NodeMode::Item,
+            item,
+            layer: 0,
+        }
+    }
+
+    /// Returns the underlying `ItemId` if it is an item.
+    /// Panics otherwise.
+    #[track_caller]
+    pub fn unwrap_item(&self) -> ItemId {
+        assert_eq!(self.mode, NodeMode::Item);
+        self.item
+    }
+
+    /// Returns the underlying `ItemId` if it is a links node.
+    /// Panics otherwise.
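+    /// A quick behavioural sketch (assuming the constructors above):
+    ///
+    /// ```ignore
+    /// let id = NodeId::links(7, 2);
+    /// assert_eq!(id.unwrap_node(), (7, 2));
+    /// ```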
+ #[track_caller] + pub fn unwrap_node(&self) -> (ItemId, LayerId) { + assert_eq!(self.mode, NodeMode::Links); + (self.item, self.layer) + } + + pub fn to_bytes(self) -> [u8; 6] { + let mut output = [0; 6]; + + output[0] = self.mode as u8; + output[1] = self.layer; + let item_bytes = self.item.to_be_bytes(); + output[2..=5].copy_from_slice(&item_bytes); + + output + } + + pub fn from_bytes(bytes: &[u8]) -> (Self, &[u8]) { + let mode = NodeMode::try_from(bytes[0]).expect("Could not parse the node mode"); + let layer = bytes[1]; + let item = BigEndian::read_u32(&bytes[2..]); + + ( + Self { mode, item, layer }, + &bytes[size_of::() + size_of::()..], + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn check_node_id_ordering() { + // NOTE: `layer`s take precedence over item_ids + assert!(NodeId::item(0) == NodeId::item(0)); + assert!(NodeId::item(1) > NodeId::item(0)); + assert!(NodeId::item(0) < NodeId::item(1)); + + assert!(NodeId::links(0, 0) == NodeId::links(0, 0)); + assert!(NodeId::links(1, 0) > NodeId::links(0, 0)); + assert!(NodeId::links(0, 1) > NodeId::links(0, 0)); + assert!(NodeId::links(1, 0) > NodeId::links(0, 1)); + + assert!(NodeId::updated(0) == NodeId::updated(0)); + assert!(NodeId::updated(1) > NodeId::updated(0)); + assert!(NodeId::updated(0) < NodeId::updated(1)); + + assert!(NodeId::links(u32::MAX, 0) < NodeId::item(0)); + + assert!(NodeId::metadata() == NodeId::metadata()); + assert!(NodeId::metadata() < NodeId::links(u32::MIN, 0)); + assert!(NodeId::metadata() < NodeId::updated(u32::MIN)); + assert!(NodeId::metadata() < NodeId::item(u32::MIN)); + } +} diff --git a/helix-db/src/helix_engine/vector_core/ordered_float.rs b/helix-db/src/helix_engine/vector_core/ordered_float.rs new file mode 100644 index 000000000..641094dc6 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/ordered_float.rs @@ -0,0 +1,47 @@ +/// A wrapper type around f32s implementing `Ord` +/// +/// Since distance metrics satisfy d(x,x)=0 and d(x,y)>0 for x!=y we don't need to operate on the +/// full range of f32's. Comparing the u32 representation of a non-negative f32 should suffice and +/// is actually a lot quicker. +/// +/// https://en.wikipedia.org/wiki/IEEE_754-1985#NaN +#[derive(Default, Debug, Clone, Copy)] +pub struct OrderedFloat(pub f32); + +impl PartialEq for OrderedFloat { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits().eq(&other.0.to_bits()) + } +} + +impl Eq for OrderedFloat {} + +impl PartialOrd for OrderedFloat { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for OrderedFloat { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.to_bits().cmp(&other.0.to_bits()) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use crate::helix_engine::vector_core::ordered_float::OrderedFloat; + + proptest! 
{
+    #[test]
+    fn ordering_makes_sense(
+        (upper, lower) in (0.0f32..=f32::MAX).prop_flat_map(|u| {
+            (Just(u), 0.0f32..=u)
+        })
+    ) {
+        // `lower` may equal `upper`, so the comparison must not be strict.
+        assert!(OrderedFloat(upper) >= OrderedFloat(lower));
+    }
+}
}
diff --git a/helix-db/src/helix_engine/vector_core/parallel.rs b/helix-db/src/helix_engine/vector_core/parallel.rs
new file mode 100644
index 000000000..a12d63a6f
--- /dev/null
+++ b/helix-db/src/helix_engine/vector_core/parallel.rs
@@ -0,0 +1,172 @@
+use core::slice;
+use std::borrow::Cow;
+use std::marker;
+
+use hashbrown::HashMap;
+use heed3::types::Bytes;
+use heed3::{BytesDecode, RoTxn};
+use roaring::RoaringBitmap;
+use rustc_hash::FxBuildHasher;
+
+use crate::helix_engine::vector_core::distance::Distance;
+use crate::helix_engine::vector_core::key::{KeyCodec, Prefix, PrefixCodec};
+use crate::helix_engine::vector_core::node::{Item, Links, Node, NodeCodec};
+use crate::helix_engine::vector_core::{CoreDatabase, ItemId, LayerId, LmdbResult};
+
+/// A structure used to keep a list of the item nodes in the graph.
+///
+/// It is safe to share between threads as the pointers point into the
+/// mmapped file and the transaction is kept here and therefore
+/// no longer touches the database.
+pub struct ImmutableItems<'t, D> {
+    items: HashMap<ItemId, *const u8, FxBuildHasher>,
+    constant_length: Option<usize>,
+    _marker: marker::PhantomData<(&'t (), D)>,
+}
+
+// NOTE: this previously took an arg `items: &RoaringBitmap` which corresponded to the `to_insert`.
+// When building the HNSW in multiple dumps we need vectors from previous dumps in order to "glue"
+// things together.
+// To accommodate this we use a cursor over ALL Key::items in the db.
+impl<'t, D: Distance> ImmutableItems<'t, D> {
+    /// Creates the structure by fetching all the item vector pointers
+    /// and keeping the transaction making the pointers valid.
+    /// Do not take more items than memory allows.
+    pub fn new(rtxn: &'t RoTxn, database: CoreDatabase<D>, index: u16) -> LmdbResult<Self> {
+        let mut map =
+            HashMap::with_capacity_and_hasher(database.len(rtxn)? as usize, FxBuildHasher);
+        let mut constant_length = None;
+
+        let cursor = database
+            .remap_types::<PrefixCodec, Bytes>()
+            .prefix_iter(rtxn, &Prefix::item(index))?
+            .remap_key_type::<KeyCodec>();
+
+        for res in cursor {
+            let (item_id, bytes) = res?;
+            assert_eq!(*constant_length.get_or_insert(bytes.len()), bytes.len());
+            let ptr = bytes.as_ptr();
+            map.insert(item_id.node.item, ptr);
+        }
+
+        Ok(ImmutableItems {
+            items: map,
+            constant_length,
+            _marker: marker::PhantomData,
+        })
+    }
+
+    /// Returns the item identified by the given ID.
+    pub fn get(&self, item_id: ItemId) -> LmdbResult<Option<Item<'t, D>>> {
+        let len = match self.constant_length {
+            Some(len) => len,
+            None => return Ok(None),
+        };
+        let ptr = match self.items.get(&item_id) {
+            Some(ptr) => *ptr,
+            None => return Ok(None),
+        };
+
+        // safety:
+        // - ptr: The pointer comes from LMDB. Since the database cannot be written to, it is still valid.
+        // - len: All the items share the same dimensions and are the same size
+        let bytes = unsafe { slice::from_raw_parts(ptr, len) };
+        NodeCodec::bytes_decode(bytes)
+            .map_err(heed3::Error::Decoding)
+            .map(|node| node.item())
+    }
+}
+
+unsafe impl<D> Sync for ImmutableItems<'_, D> {}
+
+/// A structure used to keep a list of all the links.
+/// It is safe to share between threads as the pointers point into the
+/// mmapped file and the transaction is kept here and therefore
+/// no longer touches the database.
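+/// A minimal read sketch (assuming a `Cosine` index, an open read
+/// transaction and a known link count; error handling elided):
+///
+/// ```ignore
+/// let links = ImmutableLinks::<Cosine>::new(&rtxn, database, 0, nb_links)?;
+/// if let Some(Links { links }) = links.get(item_id, 0)? {
+///     for neighbour in links.iter() { /* visit each linked ItemId */ }
+/// }
+/// ```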
+pub struct ImmutableLinks<'t, D> { + links: HashMap<(u32, u8), (usize, *const u8), FxBuildHasher>, + _marker: marker::PhantomData<(&'t (), D)>, +} + +impl<'t, D: Distance> ImmutableLinks<'t, D> { + /// Creates the structure by fetching all the root pointers + /// and keeping the transaction making the pointers valid. + pub fn new( + rtxn: &'t RoTxn, + database: CoreDatabase, + index: u16, + nb_links: u64, + ) -> LmdbResult { + let mut links = HashMap::with_capacity_and_hasher(nb_links as usize, FxBuildHasher); + + let iter = database + .remap_types::() + .prefix_iter(rtxn, &Prefix::links(index))? + .remap_key_type::(); + + for result in iter { + let (key, bytes) = result?; + let links_id = key.node.unwrap_node(); + links.insert(links_id, (bytes.len(), bytes.as_ptr())); + } + + Ok(ImmutableLinks { + links, + _marker: marker::PhantomData, + }) + } + + /// Returns the node identified by the given ID. + pub fn get(&self, item_id: ItemId, level: LayerId) -> LmdbResult>> { + let key = (item_id, level); + let (ptr, len) = match self.links.get(&key) { + Some((len, ptr)) => (*ptr, *len), + None => return Ok(None), + }; + + // safety: + // - ptr: The pointer comes from LMDB. Since the database cannot be written to, it is still valid. + // - len: The len cannot change either + let bytes = unsafe { slice::from_raw_parts(ptr, len) }; + NodeCodec::bytes_decode(bytes) + .map_err(heed3::Error::Decoding) + .map(|node: Node<'t, D>| node.links()) + } + + pub fn iter(&self) -> impl Iterator)>> { + self.links.keys().map(|&k| { + let (item_id, level) = k; + match self.get(item_id, level) { + Ok(Some(Links { links })) => Ok((k, links)), + Ok(None) => { + unreachable!("link at level {level} with item_id {item_id} not found") + } + Err(e) => Err(e), + } + }) + } + + /// `Iter`s only over links in a given level + pub fn iter_layer( + &self, + layer: u8, + ) -> impl Iterator)>> { + self.links.keys().filter_map(move |&k| { + let (item_id, level) = k; + if level != layer { + return None; + } + + match self.get(item_id, level) { + Ok(Some(Links { links })) => Some(Ok((k, links))), + Ok(None) => { + unreachable!("link at level {level} with item_id {item_id} not found") + } + Err(e) => Some(Err(e)), + } + }) + } +} + +unsafe impl Sync for ImmutableLinks<'_, D> {} diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs new file mode 100644 index 000000000..dad45417a --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/reader.rs @@ -0,0 +1,754 @@ +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::marker; +use std::num::NonZeroUsize; + +use bumpalo::collections::CollectIn; +use heed3::RoTxn; +use heed3::types::Bytes; +use heed3::types::DecodeIgnore; +use min_max_heap::MinMaxHeap; +use roaring::RoaringBitmap; +use tracing::warn; + +use crate::helix_engine::vector_core::VectorCoreResult; +use crate::helix_engine::vector_core::VectorError; +use crate::helix_engine::vector_core::distance::Distance; +use crate::helix_engine::vector_core::distance::DistanceValue; +use crate::helix_engine::vector_core::hnsw::ScoredLink; +use crate::helix_engine::vector_core::item_iter::ItemIter; +use crate::helix_engine::vector_core::key::{Key, KeyCodec, Prefix, PrefixCodec}; +#[cfg(not(windows))] +use crate::helix_engine::vector_core::metadata::Metadata; +use crate::helix_engine::vector_core::metadata::MetadataCodec; +use crate::helix_engine::vector_core::node::Node; +use crate::helix_engine::vector_core::node::{Item, Links}; +use 
crate::helix_engine::vector_core::ordered_float::OrderedFloat;
+use crate::helix_engine::vector_core::unaligned_vector::{UnalignedVector, VectorCodec};
+use crate::helix_engine::vector_core::version::{Version, VersionCodec};
+use crate::helix_engine::vector_core::{CoreDatabase, ItemId};
+
+/// A good default value for the `ef` parameter.
+const DEFAULT_EF_SEARCH: usize = 100;
+
+#[cfg(not(windows))]
+const READER_AVAILABLE_MEMORY: &str = "HANNOY_READER_PREFETCH_MEMORY";
+
+#[cfg(not(test))]
+/// The threshold at which linear search is used instead of the HNSW algorithm.
+const LINEAR_SEARCH_THRESHOLD: u64 = 1000;
+#[cfg(test)]
+/// Note that for testing purposes, we set this threshold
+/// to zero to make sure we test the HNSW algorithm.
+const LINEAR_SEARCH_THRESHOLD: u64 = 0;
+
+/// Container storing the result of a nearest-neighbour search.
+#[derive(Debug)]
+pub struct Searched<'arena> {
+    /// The nearest neighbours for the performed query
+    pub nns: bumpalo::collections::Vec<'arena, (ItemId, f32)>,
+}
+
+impl<'arena> Searched<'arena> {
+    pub(crate) fn new(nns: bumpalo::collections::Vec<'arena, (ItemId, f32)>) -> Self {
+        Searched { nns }
+    }
+
+    /// Consumes `self` and returns the vector of nearest neighbours.
+    pub fn into_nns(self) -> bumpalo::collections::Vec<'arena, (ItemId, f32)> {
+        self.nns
+    }
+}
+
+/// Options used to make a query against a hannoy [`Reader`].
+pub struct QueryBuilder<'a, D: Distance> {
+    reader: &'a Reader<D>,
+    candidates: Option<&'a RoaringBitmap>,
+    count: usize,
+    ef: usize,
+}
+
+impl<'a, D: Distance> QueryBuilder<'a, D> {
+    pub fn by_item<'arena>(
+        &self,
+        rtxn: &RoTxn,
+        item: ItemId,
+        arena: &'arena bumpalo::Bump,
+    ) -> VectorCoreResult<Option<Searched<'arena>>> {
+        let res = self
+            .reader
+            .nns_by_item(rtxn, item, self, arena)?
+            .map(Searched::new);
+        Ok(res)
+    }
+
+    pub fn by_vector<'arena>(
+        &self,
+        rtxn: &RoTxn,
+        vector: &'a [f32],
+        arena: &'arena bumpalo::Bump,
+    ) -> VectorCoreResult<Searched<'arena>> {
+        if vector.len() != self.reader.dimensions() {
+            return Err(VectorError::InvalidVecDimension {
+                expected: self.reader.dimensions(),
+                received: vector.len(),
+            });
+        }
+
+        let vector = UnalignedVector::from_slice(vector);
+        let item = Item {
+            header: D::new_header(&vector),
+            vector,
+        };
+
+        let neighbours = self.reader.nns_by_vec(rtxn, &item, self, arena)?;
+
+        Ok(Searched::new(neighbours))
+    }
+
+    /// Specify a subset of candidates to inspect. Filters out everything else.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// # use hannoy::{Reader, distances::Euclidean};
+    /// # let (reader, rtxn): (Reader<Euclidean>, heed::RoTxn) = todo!();
+    /// let candidates = roaring::RoaringBitmap::from_iter([1, 3, 4, 5, 6, 7, 8, 9, 15, 16]);
+    /// reader.nns(20).candidates(&candidates).by_item(&rtxn, 6);
+    /// ```
+    pub fn candidates(&mut self, candidates: &'a RoaringBitmap) -> &mut Self {
+        self.candidates = Some(candidates);
+        self
+    }
+
+    /// Specify a search buffer size from which the closest elements are returned. Increasing this
+    /// value improves search relevancy but increases latency, as more neighbours need to be
+    /// searched.
+    /// In an ideal graph `ef` = `count` would suffice.
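+    /// Real graphs are rarely ideal, so values between `count` and a few
+    /// multiples of `count` are a reasonable starting point.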
+ /// + /// # Examples + /// + /// ```no_run + /// # use hannoy::{Reader, distances::Euclidean}; + /// # let (reader, rtxn): (Reader, heed::RoTxn) = todo!(); + /// reader.nns(20).ef_search(21).by_item(&rtxn, 6); + /// ``` + pub fn ef_search(&mut self, ef: usize) -> &mut Self { + self.ef = ef.max(self.count); + self + } +} + +struct Visitor<'a> { + pub eps: Vec, + pub level: usize, + pub ef: usize, + pub candidates: Option<&'a RoaringBitmap>, +} +impl<'a> Visitor<'a> { + pub fn new( + eps: Vec, + level: usize, + ef: usize, + candidates: Option<&'a RoaringBitmap>, + ) -> Self { + Self { + eps, + level, + ef, + candidates, + } + } + + /// Iteratively traverse a given level of the HNSW graph, updating the search path history. + /// Returns a Min-Max heap of size ef nearest neighbours to the query in that layer. + #[allow(clippy::too_many_arguments)] + pub fn visit( + &self, + query: &Item, + reader: &Reader, + rtxn: &RoTxn, + path: &mut RoaringBitmap, + ) -> VectorCoreResult> { + let mut search_queue = BinaryHeap::new(); + let mut res = MinMaxHeap::with_capacity(self.ef); + + // Register all entry points as visited and populate candidates + for &ep in &self.eps[..] { + let ve = get_item(reader.database, reader.index, rtxn, ep)?.unwrap(); + let dist = D::distance(query, &ve); + + search_queue.push((Reverse(OrderedFloat(dist)), ep)); + path.insert(ep); + + if self.candidates.is_none_or(|c| c.contains(ep)) { + res.push((OrderedFloat(dist), ep)); + } + } + + // Stop occurs either once we've done at least ef searches and notice no improvements, or + // when we've exhausted the search queue. + while let Some(&(Reverse(OrderedFloat(f)), _)) = search_queue.peek() { + let f_max = res + .peek_max() + .map(|&(OrderedFloat(d), _)| d) + .unwrap_or(f32::MAX); + if f > f_max { + break; + } + let (_, c) = search_queue.pop().unwrap(); + + let Links { links } = get_links(rtxn, reader.database, reader.index, c, self.level)? + .expect("Links must exist"); + + for point in links.iter() { + if !path.insert(point) { + continue; + } + let dist = D::distance( + query, + &get_item(reader.database, reader.index, rtxn, point)?.unwrap(), + ); + + // The search queue can take points that aren't included in the (optional) + // candidates bitmap, but the final result must *not* include them. + if res.len() < self.ef || dist < f_max { + search_queue.push((Reverse(OrderedFloat(dist)), point)); + if let Some(c) = self.candidates { + if !c.contains(point) { + continue; + } + } + if res.len() == self.ef { + let _ = res.push_pop_max((OrderedFloat(dist), point)); + } else { + res.push((OrderedFloat(dist), point)); + } + } + } + } + + Ok(res) + } +} + +/// A reader over the hannoy hnsw graph +#[derive(Debug)] +pub struct Reader { + pub(crate) database: CoreDatabase, + pub(crate) index: u16, + entry_points: Vec, + max_level: usize, + dimensions: usize, + items: RoaringBitmap, + version: Version, + _marker: marker::PhantomData, +} + +impl Reader { + /// Returns a reader over the database with the specified [`Distance`] type. + pub fn open( + rtxn: &RoTxn, + index: u16, + database: CoreDatabase, + ) -> VectorCoreResult> { + let metadata_key = Key::metadata(index); + + let metadata = match database + .remap_data_type::() + .get(rtxn, &metadata_key)? + { + Some(metadata) => metadata, + None => return Err(VectorError::MissingMetadata(index)), + }; + let version = match database + .remap_data_type::() + .get(rtxn, &Key::version(index))? 
+ { + Some(version) => version, + None => Version { + major: 0, + minor: 0, + patch: 0, + }, + }; + + if D::name() != metadata.distance { + return Err(VectorError::UnmatchingDistance { + expected: metadata.distance.to_owned(), + received: D::name(), + }); + } + + // check if we need to rebuild + if database + .remap_types::() + .prefix_iter(rtxn, &Prefix::updated(index))? + .remap_key_type::() + .next() + .is_some() + { + return Err(VectorError::NeedBuild(index)); + } + + // Hint to the kernel that we'll probably need some vectors in RAM. + Self::prefetch_graph(rtxn, &database, index, &metadata)?; + + Ok(Reader { + database: database.remap_data_type(), + index, + entry_points: Vec::from_iter(metadata.entry_points.iter()), + max_level: metadata.max_level as usize, + dimensions: metadata.dimensions.try_into().unwrap(), + items: metadata.items, + version, + _marker: marker::PhantomData, + }) + } + + #[cfg(windows)] + fn prefetch_graph( + _rtxn: &RoTxn, + _database: &CoreDatabase, + _index: u16, + _metadata: &Metadata, + ) -> Result<()> { + // madvise crate does not support windows. + Ok(()) + } + + /// Instructs kernel to fetch nodes based on a fixed memory budget. It's OK for this operation + /// to fail, it's not integral for search to work. + #[cfg(not(windows))] + fn prefetch_graph( + rtxn: &RoTxn, + database: &CoreDatabase, + index: u16, + metadata: &Metadata, + ) -> VectorCoreResult<()> { + use std::{collections::VecDeque, sync::atomic::AtomicUsize}; + + let page_size = page_size::get(); + let mut available_memory: usize = std::env::var(READER_AVAILABLE_MEMORY) + .ok() + .and_then(|num| num.parse::().ok()) + .unwrap_or(0); + + if available_memory < page_size { + return Ok(()); + } + + let largest_alloc = AtomicUsize::new(0); + + // adjusted length in memory of a vector + let item_length = (metadata.dimensions as usize).div_ceil(::word_size()); + + let madvise_page = |item: &[u8]| -> VectorCoreResult { + use std::sync::atomic::Ordering; + + let start_ptr = item.as_ptr() as usize; + let end_ptr = start_ptr + item_length; + let start_page = start_ptr - (start_ptr % page_size); + let end_page = end_ptr + ((end_ptr + page_size - 1) % page_size); + let advised_size = end_page - start_page; + + unsafe { + use madvise::AccessPattern; + + madvise::madvise( + start_page as *const u8, + advised_size, + AccessPattern::WillNeed, + )?; + } + + largest_alloc.fetch_max(advised_size, Ordering::Relaxed); + Ok(advised_size) + }; + + // Load links and vectors for layers > 0. + let mut added = RoaringBitmap::new(); + for lvl in (1..=metadata.max_level).rev() { + use heed3::types::Bytes; + + for result in database.remap_data_type::().iter(rtxn)? { + use std::sync::atomic::Ordering; + + if available_memory < largest_alloc.load(Ordering::Relaxed) { + return Ok(()); + } + let (key, item) = result?; + if key.node.layer != lvl { + continue; + } + match madvise_page(item) { + Ok(usage) => available_memory -= usage, + Err(e) => { + use tracing::warn; + + warn!(e=?e); + return Ok(()); + } + } + added.insert(key.node.item); + } + } + + // If we still have memory left over try fetching other nodes in layer zero. + let mut queue = VecDeque::from_iter(added.iter()); + while let Some(item) = queue.pop_front() { + use std::sync::atomic::Ordering; + + use crate::helix_engine::vector_core::node::Node; + + if available_memory < largest_alloc.load(Ordering::Relaxed) { + return Ok(()); + } + if let Some(Node::Links(links)) = database.get(rtxn, &Key::links(index, item, 0))? 
{ + for l in links.iter() { + if !added.insert(l) { + continue; + } + if let Some(bytes) = database + .remap_data_type::() + .get(rtxn, &Key::item(index, l))? + { + match madvise_page(bytes) { + Ok(usage) => available_memory -= usage, + Err(e) => { + warn!(e=?e); + return Ok(()); + } + } + queue.push_back(l); + } + } + } + } + + Ok(()) + } + + /// Returns the number of dimensions in the index. + pub fn dimensions(&self) -> usize { + self.dimensions + } + + /// Returns the number of entry points to the hnsw index. + pub fn n_entrypoints(&self) -> usize { + self.entry_points.len() + } + + /// Returns the number of vectors stored in the index. + pub fn n_items(&self) -> u64 { + self.items.len() + } + + /// Returns all the item ids contained in this index. + pub fn item_ids(&self) -> &RoaringBitmap { + &self.items + } + + /// Returns the index of this reader in the database. + pub fn index(&self) -> u16 { + self.index + } + + /// Returns the version of the database. + pub fn version(&self) -> Version { + self.version + } + + /// Returns the number of nodes in the index. Useful to run an exhaustive search. + pub fn n_nodes(&self, rtxn: &RoTxn) -> VectorCoreResult> { + Ok(NonZeroUsize::new(self.database.len(rtxn)? as usize)) + } + + /// Returns the vector for item `i` that was previously added. + pub fn item_vector<'arena>( + &self, + rtxn: &RoTxn, + item_id: ItemId, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult>> { + Ok( + get_item(self.database, self.index, rtxn, item_id)?.map(|item| { + let mut vec = item.vector.to_vec(&arena); + vec.truncate(self.dimensions()); + vec + }), + ) + } + + /// Returns `true` if the index is empty. + pub fn is_empty(&self, rtxn: &RoTxn, arena: &bumpalo::Bump) -> VectorCoreResult { + self.iter(rtxn, arena).map(|mut iter| iter.next().is_none()) + } + + /// Returns `true` if the database contains the given item. + pub fn contains_item(&self, rtxn: &RoTxn, item_id: ItemId) -> VectorCoreResult { + self.database + .remap_data_type::() + .get(rtxn, &Key::item(self.index, item_id)) + .map(|opt| opt.is_some()) + .map_err(Into::into) + } + + /// Returns an iterator over the items vector. + pub fn iter<'t>( + &self, + rtxn: &'t RoTxn, + arena: &'t bumpalo::Bump, + ) -> VectorCoreResult> { + ItemIter::new(self.database, self.index, self.dimensions, rtxn, arena).map_err(Into::into) + } + + /// Return a [`QueryBuilder`] that lets you configure and execute a search request. + /// + /// You must provide the number of items you want to receive. + pub fn nns(&self, count: usize) -> QueryBuilder<'_, D> { + QueryBuilder { + reader: self, + candidates: None, + count, + ef: DEFAULT_EF_SEARCH, + } + } + + fn nns_by_vec<'arena>( + &self, + rtxn: &RoTxn, + query: &Item, + opt: &QueryBuilder, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult> { + // If we will never find any candidates, return an empty vector + if opt + .candidates + .is_some_and(|c| self.item_ids().is_disjoint(c)) + { + return Ok(bumpalo::collections::Vec::new_in(&arena)); + } + + // If the number of candidates is less than a given threshold, perform linear search + if let Some(candidates) = opt.candidates.filter(|c| c.len() < LINEAR_SEARCH_THRESHOLD) { + return self.brute_force_search(query, rtxn, candidates, opt.count, &arena); + } + + // exhaustive search + self.hnsw_search(query, rtxn, opt, &arena) + } + + /// Directly retrieves items in the candidate list and ranks them by distance to the query. 
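+    /// Cost grows linearly: one distance evaluation per candidate, which is
+    /// why callers only take this path when the candidate set is smaller than
+    /// `LINEAR_SEARCH_THRESHOLD`.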
+ fn brute_force_search<'arena>( + &self, + query: &Item, + rtxn: &RoTxn, + candidates: &RoaringBitmap, + count: usize, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult> { + let mut item_distances = + bumpalo::collections::Vec::with_capacity_in(candidates.len() as usize, &arena); + + for item_id in candidates { + let Some(vector) = self.item_vector(rtxn, item_id, arena)? else { + continue; + }; + let vector = UnalignedVector::from_vec(vector); + let item = Item { + header: D::new_header(&vector), + vector, + }; + let distance = D::distance(&item, query); + item_distances.push((item_id, distance)); + } + item_distances.sort_by_key(|(_, dist)| OrderedFloat(*dist)); + item_distances.truncate(count); + + Ok(item_distances) + } + + /// Hnsw search according to arXiv:1603.09320. + /// + /// We perform greedy beam search from the top layer to the bottom, where the search frontier + /// is controlled by `opt.ef`. Since the graph is not necessarily acyclic, search may become + /// "trapped" in a local sub-graph with fewer elements than `opt.count` - to account for this + /// we run an expensive exhaustive search at the end if fewer nns were returned. + /// + /// To break out of search early, users may wish to provide a `cancel_fn` which terminates the + /// execution of the hnsw search and returns partial results so far. + fn hnsw_search<'arena>( + &self, + query: &Item, + rtxn: &RoTxn, + opt: &QueryBuilder, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult> { + let mut visitor = Visitor::new(self.entry_points.clone(), self.max_level, 1, None); + + let mut path = RoaringBitmap::new(); + for _ in (1..=self.max_level).rev() { + let neighbours = visitor.visit(query, self, rtxn, &mut path)?; + let closest = neighbours + .peek_min() + .map(|(_, n)| n) + .expect("No neighbor was found"); + + visitor.eps = vec![*closest]; + visitor.level -= 1; + } + + // clear visited set as we only care about level 0 + path.clear(); + debug_assert!(visitor.level == 0); + + visitor.ef = opt.ef.max(opt.count); + visitor.candidates = opt.candidates; + + let mut neighbours = visitor.visit(query, self, rtxn, &mut path)?; + + // If we still don't have enough nns (e.g. search encountered cyclic subgraphs) then do exhaustive + // search over remaining unseen items. + if neighbours.len() < opt.count { + let mut cursor = self + .database + .remap_types::() + .prefix_iter(rtxn, &Prefix::item(self.index))? + .remap_key_type::(); + + while let Some((key, _)) = cursor.next().transpose()? { + let id = key.node.item; + if path.contains(id) { + continue; + } + + visitor.eps = vec![id]; + visitor.ef = opt.count - neighbours.len(); + + let more_nns = visitor.visit(query, self, rtxn, &mut path)?; + + neighbours.extend(more_nns.into_iter()); + if neighbours.len() >= opt.count { + break; + } + } + } + + Ok(neighbours + .drain_asc() + .map(|(OrderedFloat(f), i)| (i, f)) + .take(opt.count) + .collect_in(arena)) + } + + /// Returns the nearest points to the item id, not including the point itself. + /// + /// Nearly identical behaviour to `Reader.nns_by_vec` except we only search layer 0 and use the + /// `&[item]` instead of the hnsw entrypoints. Since search starts in the true neighbourhood of + /// the item fewer comparisons are needed to retrieve the nearest neighbours, making it more + /// efficient than simply calling `Reader.nns_by_vec` with the associated vector. 
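+    /// A usage sketch through the public query builder (assuming an open
+    /// `Reader`, a read transaction and a bumpalo arena):
+    ///
+    /// ```ignore
+    /// let arena = bumpalo::Bump::new();
+    /// if let Some(searched) = reader.nns(10).by_item(&rtxn, item_id, &arena)? {
+    ///     for (id, distance) in searched.into_nns() {
+    ///         // `id` is an ItemId, `distance` its D::distance to `item_id`
+    ///     }
+    /// }
+    /// ```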
+ #[allow(clippy::type_complexity)] + fn nns_by_item<'arena>( + &self, + rtxn: &RoTxn, + item: ItemId, + opt: &QueryBuilder, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult>> { + // If we will never find any candidates, return none + if opt + .candidates + .is_some_and(|c| self.item_ids().is_disjoint(c)) + { + return Ok(None); + } + + let Some(vector) = self.item_vector(rtxn, item, arena)? else { + return Ok(None); + }; + let vector = UnalignedVector::from_vec(vector); + let query = Item { + header: D::new_header(&vector), + vector, + }; + + // If the number of candidates is less than a given threshold, perform linear search + if let Some(candidates) = opt.candidates.filter(|c| c.len() < LINEAR_SEARCH_THRESHOLD) { + let nns = self.brute_force_search(&query, rtxn, candidates, opt.count, arena)?; + return Ok(Some(nns)); + } + + // Search over all items except `item` + let ef = opt.ef.max(opt.count); + let mut path = RoaringBitmap::new(); + let mut candidates = opt.candidates.unwrap_or_else(|| self.item_ids()).clone(); + candidates.remove(item); + + let mut visitor = Visitor::new(vec![item], 0, ef, Some(&candidates)); + + let mut neighbours = visitor.visit(&query, self, rtxn, &mut path)?; + + // If we still don't have enough nns (e.g. search encountered cyclic subgraphs) then do exhaustive + // search over remaining unseen items. + if neighbours.len() < opt.count { + let mut cursor = self + .database + .remap_types::() + .prefix_iter(rtxn, &Prefix::item(self.index))? + .remap_key_type::(); + + while let Some((key, _)) = cursor.next().transpose()? { + let id = key.node.item; + if path.contains(id) { + continue; + } + + // update walker + visitor.eps = vec![id]; + visitor.ef = opt.count - neighbours.len(); + + let more_nns = visitor.visit(&query, self, rtxn, &mut path)?; + neighbours.extend(more_nns.into_iter()); + if neighbours.len() >= opt.count { + break; + } + } + } + + Ok(Some( + neighbours + .drain_asc() + .map(|(OrderedFloat(f), i)| (i, f)) + .take(opt.count) + .collect_in(arena), + )) + } +} + +pub fn get_item<'a, D: Distance>( + database: CoreDatabase, + index: u16, + rtxn: &'a RoTxn, + item: ItemId, +) -> VectorCoreResult>> { + match database.get(rtxn, &Key::item(index, item))? { + Some(Node::Item(item)) => Ok(Some(item)), + Some(Node::Links(_)) => Ok(None), + None => Ok(None), + } +} + +pub fn get_links<'a, D: Distance>( + rtxn: &'a RoTxn, + database: CoreDatabase, + index: u16, + item_id: ItemId, + level: usize, +) -> VectorCoreResult>> { + match database.get(rtxn, &Key::links(index, item_id, level as u8))? 
{ + Some(Node::Links(links)) => Ok(Some(links)), + Some(Node::Item(_)) => Ok(None), + None => Ok(None), + } +} diff --git a/helix-db/src/helix_engine/vector_core/spaces/mod.rs b/helix-db/src/helix_engine/vector_core/spaces/mod.rs new file mode 100644 index 000000000..15009da1d --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/spaces/mod.rs @@ -0,0 +1,10 @@ +pub mod simple; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod simple_sse; + +#[cfg(target_arch = "x86_64")] +mod simple_avx; + +#[cfg(target_arch = "aarch64")] +mod simple_neon; diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple.rs b/helix-db/src/helix_engine/vector_core/spaces/simple.rs new file mode 100644 index 000000000..98c78e218 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/spaces/simple.rs @@ -0,0 +1,84 @@ +use crate::helix_engine::vector_core::unaligned_vector::UnalignedVector; + +#[cfg(target_arch = "x86_64")] +use super::simple_avx::*; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +use super::simple_neon::*; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use super::simple_sse::*; + +#[cfg(target_arch = "x86_64")] +const MIN_DIM_SIZE_AVX: usize = 32; + +#[cfg(any( + target_arch = "x86", + target_arch = "x86_64", + all(target_arch = "aarch64", target_feature = "neon") +))] +const MIN_DIM_SIZE_SIMD: usize = 16; + +pub fn euclidean_distance(u: &UnalignedVector, v: &UnalignedVector) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && u.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { euclid_similarity_avx(u, v) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && u.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { euclid_similarity_sse(u, v) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && u.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { euclid_similarity_neon(u, v) }; + } + } + + euclidean_distance_non_optimized(u, v) +} + +// Don't use dot-product: avoid catastrophic cancellation in +// https://github.com/spotify/annoy/issues/314. 
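+// The fallback below accumulates sum_i (u_i - v_i)^2 term by term, which is
+// numerically safer than expanding into |u|^2 + |v|^2 - 2*dot(u, v).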
+pub fn euclidean_distance_non_optimized(u: &UnalignedVector, v: &UnalignedVector) -> f32 { + u.iter().zip(v.iter()).map(|(u, v)| (u - v) * (u - v)).sum() +} + +pub fn dot_product(u: &UnalignedVector, v: &UnalignedVector) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && u.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { dot_similarity_avx(u, v) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && u.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { dot_similarity_sse(u, v) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && u.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { dot_similarity_neon(u, v) }; + } + } + + dot_product_non_optimized(u, v) +} + +pub fn dot_product_non_optimized(u: &UnalignedVector, v: &UnalignedVector) -> f32 { + u.iter().zip(v.iter()).map(|(a, b)| a * b).sum() +} diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs new file mode 100644 index 000000000..720b1211e --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs @@ -0,0 +1,163 @@ +use std::arch::x86_64::*; +use std::ptr::read_unaligned; + +use crate::helix_engine::vector_core::unaligned_vector::UnalignedVector; + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +unsafe fn hsum256_ps_avx(x: __m256) -> f32 { + let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + _mm_cvtss_f32(x32) +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn euclid_similarity_avx( + v1: &UnalignedVector, + v2: &UnalignedVector, +) -> f32 { + // It is safe to load unaligned floats from a pointer. 
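+    // The main loop is unrolled into four independent 8-lane AVX/FMA
+    // accumulators (32 floats per iteration); the scalar loop after it
+    // handles the remaining n % 32 elements.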
+ // + + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub256_1: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0))); + sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1); + + let sub256_2: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8))); + sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2); + + let sub256_3: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16))); + sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3); + + let sub256_4: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24))); + sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } + + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += (a - b).powi(2); + } + result +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn dot_similarity_avx( + v1: &UnalignedVector, + v2: &UnalignedVector, +) -> f32 { + // It is safe to load unaligned floats from a pointer. + // + + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1); + sum256_2 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(8)), + _mm256_loadu_ps(ptr2.add(8)), + sum256_2, + ); + sum256_3 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(16)), + _mm256_loadu_ps(ptr2.add(16)), + sum256_3, + ); + sum256_4 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(24)), + _mm256_loadu_ps(ptr2.add(24)), + sum256_4, + ); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } + + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += a * b; + } + result +} + +#[cfg(test)] +mod tests { + use crate::helix_engine::vector_core::spaces::simple::*; + + #[test] + fn test_spaces_avx() { + use super::*; + + if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 
13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 56., 57., 58., 59., 60., 61., + ]; + + let v1 = UnalignedVector::from_slice(&v1[..]); + let v2 = UnalignedVector::from_slice(&v2[..]); + + let euclid_simd = unsafe { euclid_similarity_avx(&v1, &v2) }; + let euclid = euclidean_distance_non_optimized(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) }; + let dot = dot_product_non_optimized(&v1, &v2); + assert_eq!(dot_simd, dot); + + // let cosine_simd = unsafe { cosine_preprocess_avx(v1.clone()) }; + // let cosine = cosine_preprocess(v1); + // assert_eq!(cosine_simd, cosine); + } else { + println!("avx test skipped"); + } + } +} diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs new file mode 100644 index 000000000..a176ce11e --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs @@ -0,0 +1,154 @@ +#[cfg(target_feature = "neon")] +use crate::unaligned_vector::UnalignedVector; +use std::arch::aarch64::*; +use std::ptr::read_unaligned; + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn euclid_similarity_neon( + v1: &UnalignedVector, + v2: &UnalignedVector, +) -> f32 { + // We use the unaligned_float32x4_t helper function to read f32x4 NEON SIMD types + // from potentially unaligned memory locations safely. + // https://github.com/meilisearch/arroy/pull/13 + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + let sub1 = vsubq_f32(unaligned_float32x4_t(ptr1), unaligned_float32x4_t(ptr2)); + sum1 = vfmaq_f32(sum1, sub1, sub1); + + let sub2 = vsubq_f32( + unaligned_float32x4_t(ptr1.add(4)), + unaligned_float32x4_t(ptr2.add(4)), + ); + sum2 = vfmaq_f32(sum2, sub2, sub2); + + let sub3 = vsubq_f32( + unaligned_float32x4_t(ptr1.add(8)), + unaligned_float32x4_t(ptr2.add(8)), + ); + sum3 = vfmaq_f32(sum3, sub3, sub3); + + let sub4 = vsubq_f32( + unaligned_float32x4_t(ptr1.add(12)), + unaligned_float32x4_t(ptr2.add(12)), + ); + sum4 = vfmaq_f32(sum4, sub4, sub4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += (a - b).powi(2); + } + result +} + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn dot_similarity_neon( + v1: &UnalignedVector, + v2: &UnalignedVector, +) -> f32 { + // We use the unaligned_float32x4_t helper function to read f32x4 NEON SIMD types + // from potentially unaligned memory locations safely. 
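+    // As in the euclidean kernel above, four f32x4 FMA accumulators process
+    // 16 floats per iteration and a scalar tail covers the rest.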
+ // https://github.com/meilisearch/arroy/pull/13 + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + sum1 = vfmaq_f32( + sum1, + unaligned_float32x4_t(ptr1), + unaligned_float32x4_t(ptr2), + ); + sum2 = vfmaq_f32( + sum2, + unaligned_float32x4_t(ptr1.add(4)), + unaligned_float32x4_t(ptr2.add(4)), + ); + sum3 = vfmaq_f32( + sum3, + unaligned_float32x4_t(ptr1.add(8)), + unaligned_float32x4_t(ptr2.add(8)), + ); + sum4 = vfmaq_f32( + sum4, + unaligned_float32x4_t(ptr1.add(12)), + unaligned_float32x4_t(ptr2.add(12)), + ); + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += a * b; + } + result +} + +/// Reads 4xf32 in a stack-located array aligned on a f32 and reads a `float32x4_t` from it. +unsafe fn unaligned_float32x4_t(ptr: *const f32) -> float32x4_t { + vld1q_f32(read_unaligned(ptr as *const [f32; 4]).as_ptr()) +} + +#[cfg(test)] +mod tests { + #[cfg(target_feature = "neon")] + #[test] + fn test_spaces_neon() { + use super::*; + use crate::spaces::simple::*; + + if std::arch::is_aarch64_feature_detected!("neon") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., + ]; + + let v1 = UnalignedVector::from_slice(&v1[..]); + let v2 = UnalignedVector::from_slice(&v2[..]); + + let euclid_simd = unsafe { euclid_similarity_neon(&v1, &v2) }; + let euclid = euclidean_distance_non_optimized(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) }; + let dot = dot_product_non_optimized(&v1, &v2); + assert_eq!(dot_simd, dot); + + // let cosine_simd = unsafe { cosine_preprocess_neon(v1.clone()) }; + // let cosine = cosine_preprocess(v1); + // assert_eq!(cosine_simd, cosine); + } else { + println!("neon test skipped"); + } + } +} diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs new file mode 100644 index 000000000..06ad0fa09 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs @@ -0,0 +1,158 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; +use std::ptr::read_unaligned; + +use crate::helix_engine::vector_core::unaligned_vector::UnalignedVector; + +#[target_feature(enable = "sse")] +unsafe fn hsum128_ps_sse(x: __m128) -> f32 { + let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x)); + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + _mm_cvtss_f32(x32) +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn euclid_similarity_sse( + v1: &UnalignedVector, + v2: &UnalignedVector, +) -> f32 { + // It is safe to load unaligned floats from a pointer. 
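+    // Without FMA at plain-SSE level, each step subtracts, squares via
+    // _mm_mul_ps and accumulates into one of four f32x4 sums
+    // (16 floats per iteration); a scalar tail handles the remainder.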
+ // + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)); + sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1); + + let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))); + sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2); + + let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))); + sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3); + + let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))); + sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += (a - b).powi(2); + } + result +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn dot_similarity_sse( + v1: &UnalignedVector, + v2: &UnalignedVector, +) -> f32 { + // It is safe to load unaligned floats from a pointer. + // + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + + let mut i: usize = 0; + while i < m { + sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1); + + sum128_2 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))), + sum128_2, + ); + + sum128_3 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))), + sum128_3, + ); + + sum128_4 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))), + sum128_4, + ); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += a * b; + } + result +} + +#[cfg(test)] +mod tests { + use crate::helix_engine::vector_core::spaces::simple::{ + dot_product_non_optimized, euclidean_distance_non_optimized, + }; + + #[test] + fn test_spaces_sse() { + use super::*; + + if is_x86_feature_detected!("sse") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 56., 57., 58., 59., 60., 61., + ]; + + let v1 = UnalignedVector::from_slice(&v1[..]); + let v2 = 
UnalignedVector::from_slice(&v2[..]); + + let euclid_simd = unsafe { euclid_similarity_sse(&v1, &v2) }; + let euclid = euclidean_distance_non_optimized(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_sse(&v1, &v2) }; + let dot = dot_product_non_optimized(&v1, &v2); + assert_eq!(dot_simd, dot); + + // let cosine_simd = unsafe { cosine_preprocess_sse(v1.clone()) }; + // let cosine = cosine_preprocess(v1); + // assert_eq!(cosine_simd, cosine); + } else { + println!("sse test skipped"); + } + } +} diff --git a/helix-db/src/helix_engine/vector_core/stats.rs b/helix-db/src/helix_engine/vector_core/stats.rs new file mode 100644 index 000000000..ef652ed98 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/stats.rs @@ -0,0 +1,84 @@ +use std::marker::PhantomData; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use hashbrown::HashMap; +use heed3::{Result, RoTxn}; + +use crate::helix_engine::vector_core::CoreDatabase; +use crate::helix_engine::vector_core::distance::Distance; +use crate::helix_engine::vector_core::key::{KeyCodec, Prefix, PrefixCodec}; +use crate::helix_engine::vector_core::node::{Links, Node}; + +// TODO: ignore the phantom +#[derive(Debug)] +pub(crate) struct BuildStats { + /// a counter to see how many times `HnswBuilder.add_link` is invoked + pub n_links_added: AtomicUsize, + /// a counter tracking how many times we hit lmdb + pub lmdb_hits: AtomicUsize, + /// average rank of a node, calculated at the end of build + pub mean_degree: f32, + /// number of elements per layer + pub layer_dist: HashMap, + /// track some race condition violations + pub link_misses: AtomicUsize, + + _phantom: PhantomData, +} + +impl BuildStats { + pub fn new() -> BuildStats { + BuildStats { + n_links_added: AtomicUsize::new(0), + lmdb_hits: AtomicUsize::new(0), + mean_degree: 0.0, + layer_dist: HashMap::default(), + link_misses: AtomicUsize::new(0), + _phantom: PhantomData, + } + } + + pub fn incr_link_count(&self, val: usize) { + self.n_links_added.fetch_add(val, Ordering::Relaxed); + } + + pub fn incr_lmdb_hits(&self) { + self.lmdb_hits.fetch_add(1, Ordering::Relaxed); + } + + pub fn incr_link_misses(&self) { + self.link_misses.fetch_add(1, Ordering::Relaxed); + } + + /// iterate over all links in db and average out node rank + pub fn compute_mean_degree( + &mut self, + rtxn: &RoTxn, + db: &CoreDatabase, + index: u16, + ) -> Result<()> { + let iter = db + .remap_key_type::() + .prefix_iter(rtxn, &Prefix::links(index))? + .remap_key_type::(); + + let mut n_links = 0; + let mut total_links = 0; + + for res in iter { + let (_key, node) = res?; + + let links = match node { + Node::Links(Links { links }) => links, + Node::Item(_) => unreachable!("Node must not be an item"), + }; + + total_links += links.len(); + n_links += 1; + } + + self.mean_degree = (total_links as f32) / (n_links as f32); + + Ok(()) + } +} diff --git a/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs b/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs new file mode 100644 index 000000000..4ffbbe910 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs @@ -0,0 +1,70 @@ +use std::{ + borrow::Cow, + mem::{size_of, transmute}, +}; + +use bytemuck::cast_slice; +use byteorder::{ByteOrder, NativeEndian}; + +use super::{SizeMismatch, UnalignedVector, VectorCodec}; + +impl VectorCodec for f32 { + /// Creates an unaligned slice of f32 wrapper from a slice of bytes. 
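+    /// A behavioural sketch (assuming native-endian bytes of a whole number
+    /// of `f32`s):
+    ///
+    /// ```ignore
+    /// let bytes = 1.5f32.to_ne_bytes();
+    /// let vector = <f32 as VectorCodec>::from_bytes(&bytes).unwrap();
+    /// assert_eq!(vector.len(), 1);
+    /// ```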
+ fn from_bytes(bytes: &[u8]) -> Result>, SizeMismatch> { + let rem = bytes.len() % size_of::(); + if rem == 0 { + // safety: `UnalignedF32Slice` is transparent + Ok(Cow::Borrowed(unsafe { + transmute::<&[u8], &UnalignedVector>(bytes) + })) + } else { + Err(SizeMismatch { + vector_codec: "f32", + rem, + }) + } + } + + /// Creates an unaligned slice of f32 wrapper from a slice of f32. + /// The slice is already known to be of the right length. + fn from_slice(slice: &[f32]) -> Cow<'_, UnalignedVector> { + Self::from_bytes(cast_slice(slice)).unwrap() + } + + /// Creates an unaligned slice of f32 wrapper from a slice of f32. + /// The slice is already known to be of the right length. + fn from_vec<'arena>( + vec: bumpalo::collections::Vec<'arena, f32>, + ) -> Cow<'static, UnalignedVector> { + let bytes = vec.into_iter().flat_map(|f| f.to_ne_bytes()).collect(); + Cow::Owned(bytes) + } + + // todo: add arena + fn to_vec<'arena>( + vec: &UnalignedVector, + arena: &'arena bumpalo::Bump, + ) -> bumpalo::collections::Vec<'arena, f32> { + let iter = vec.iter(); + let mut ret = bumpalo::collections::Vec::with_capacity_in(iter.len(), &arena); + ret.extend(iter); + ret + } + + /// Returns an iterator of f32 that are read from the slice. + /// The f32 are copied in memory and are therefore, aligned. + fn iter(vec: &UnalignedVector) -> impl ExactSizeIterator + '_ { + vec.vector + .chunks_exact(size_of::()) + .map(NativeEndian::read_f32) + } + + /// Return the number of f32 that fits into this slice. + fn len(vec: &UnalignedVector) -> usize { + vec.vector.len() / size_of::() + } + + fn is_zero(vec: &UnalignedVector) -> bool { + vec.iter().all(|v| v == 0.0) + } +} diff --git a/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs b/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs new file mode 100644 index 000000000..55688983a --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs @@ -0,0 +1,182 @@ +use std::{ + borrow::{Borrow, Cow}, + fmt, + marker::PhantomData, + mem::transmute, +}; + +use bytemuck::pod_collect_to_vec; +use serde::Serialize; + +mod f32; + +/// Determine the way the vectors should be read and written from the database +pub trait VectorCodec: std::borrow::ToOwned + Sized { + /// Creates an unaligned vector from a slice of bytes. + /// Don't allocate. + fn from_bytes(bytes: &[u8]) -> Result>, SizeMismatch>; + + /// Creates an unaligned vector from a slice of f32. + /// May allocate depending on the codec. + fn from_slice(slice: &[f32]) -> Cow<'_, UnalignedVector>; + + /// Creates an unaligned slice of f32 wrapper from a slice of f32. + /// The slice is already known to be of the right length. + fn from_vec<'arena>( + vec: bumpalo::collections::Vec<'arena, f32>, + ) -> Cow<'static, UnalignedVector>; + + /// Converts the `UnalignedVector` to an aligned vector of `f32`. + /// It's strictly equivalent to `.iter().collect()` but the performances + /// are better. + fn to_vec<'arena>( + vec: &UnalignedVector, + arena: &'arena bumpalo::Bump, + ) -> bumpalo::collections::Vec<'arena, f32>; + + /// Returns an iterator of f32 that are read from the vector. + /// The f32 are copied in memory and are therefore, aligned. + fn iter(vec: &UnalignedVector) -> impl ExactSizeIterator + '_; + + /// Returns the len of the vector in terms of elements. + fn len(vec: &UnalignedVector) -> usize; + + /// Returns true if all the elements in the vector are equal to 0. 
+ fn is_zero(vec: &UnalignedVector<Self>) -> bool; + + /// Returns the bit-packing size if quantized + fn word_size() -> usize { + 1 + } +} + +/// A wrapper struct that is used to read unaligned vectors directly from memory. +#[repr(transparent)] +#[derive(Serialize)] +pub struct UnalignedVector<Codec: VectorCodec> { + format: PhantomData<fn() -> Codec>, + vector: [u8], +} + +impl<Codec: VectorCodec> UnalignedVector<Codec> { + /// Creates an unaligned vector from a slice of bytes. + /// Doesn't allocate. + pub fn from_bytes(bytes: &[u8]) -> Result<Cow<'_, UnalignedVector<Codec>>, SizeMismatch> { + Codec::from_bytes(bytes) + } + + /// Creates an unaligned vector from a slice of f32. + /// May allocate depending on the codec. + pub fn from_slice(slice: &[f32]) -> Cow<'_, UnalignedVector<Codec>> { + Codec::from_slice(slice) + } + + /// Creates an unaligned slice of f32 wrapper from a vector of f32. + /// The vector is already known to be of the right length. + pub fn from_vec<'arena>( + vec: bumpalo::collections::Vec<'arena, f32>, + ) -> Cow<'static, UnalignedVector<Codec>> { + Codec::from_vec(vec) + } + + /// Returns an iterator of f32 that are read from the vector. + /// The f32 are copied in memory and are therefore aligned. + pub fn iter(&self) -> impl ExactSizeIterator<Item = f32> + '_ { + Codec::iter(self) + } + + /// Returns true if all the elements in the vector are equal to 0. + pub fn is_zero(&self) -> bool { + Codec::is_zero(self) + } + + /// Returns an allocated and aligned `Vec<f32>`. + pub fn to_vec<'arena>( + &self, + arena: &'arena bumpalo::Bump, + ) -> bumpalo::collections::Vec<'arena, f32> { + Codec::to_vec(self, arena) + } + + /// Returns the length of the vector in terms of elements. + pub fn len(&self) -> usize { + Codec::len(self) + } + + /// Creates an unaligned slice of something. It's up to the caller to ensure + /// it will be used with the same type it was created with initially. + pub(crate) fn from_bytes_unchecked(bytes: &[u8]) -> &Self { + unsafe { transmute(bytes) } + } + + /// Returns the original raw slice of bytes. + pub(crate) fn as_bytes(&self) -> &[u8] { + &self.vector + } + + /// Returns whether it is empty or not. + pub fn is_empty(&self) -> bool { + self.vector.is_empty() + } + + /// Returns the raw pointer to the start of this slice. + pub(crate) fn as_ptr(&self) -> *const u8 { + self.vector.as_ptr() + } +} + +/// Returned in case you tried to make an unaligned vector from a slice of bytes that doesn't have the right number of elements. +#[derive(Debug, thiserror::Error)] +#[error( + "Slice of bytes contains {rem} too many bytes to be decoded with the {vector_codec} codec." +)] +pub struct SizeMismatch { + /// The name of the codec used. + vector_codec: &'static str, + /// The number of bytes remaining after decoding as many words as possible.
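
For readers unfamiliar with `thiserror`, the `#[error(...)]` attribute above expands into a `Display` impl that interpolates the struct fields. A self-contained sketch of the same error shape (the type name here is hypothetical):

```rust
use thiserror::Error;

// Mirrors the field layout of `SizeMismatch` above.
#[derive(Debug, Error)]
#[error("Slice of bytes contains {rem} too many bytes to be decoded with the {vector_codec} codec.")]
struct SizeMismatchSketch {
    vector_codec: &'static str,
    rem: usize,
}

fn main() {
    let err = SizeMismatchSketch { vector_codec: "f32", rem: 3 };
    // Prints: Slice of bytes contains 3 too many bytes to be decoded with the f32 codec.
    println!("{err}");
}
```
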
+ rem: usize, +} + +impl ToOwned for UnalignedVector { + type Owned = Vec; + + fn to_owned(&self) -> Self::Owned { + pod_collect_to_vec(&self.vector) + } +} + +impl Borrow> for Vec { + fn borrow(&self) -> &UnalignedVector { + UnalignedVector::from_bytes_unchecked(self) + } +} + +impl fmt::Debug for UnalignedVector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut list = f.debug_list(); + + struct Number(f32); + impl fmt::Debug for Number { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:0.4}", self.0) + } + } + + let arena = bumpalo::Bump::new(); + let vec = self.to_vec(&arena); + for v in vec.iter().take(10) { + list.entry(&Number(*v)); + } + if vec.len() < 10 { + return list.finish(); + } + + // With binary quantization we may be padding with a lot of zeros + if vec[10..].iter().all(|v| *v == 0.0) { + list.entry(&"0.0, ..."); + } else { + list.entry(&"other ..."); + } + + list.finish() + } +} diff --git a/helix-db/src/helix_engine/vector_core/utils.rs b/helix-db/src/helix_engine/vector_core/utils.rs deleted file mode 100644 index f7704f512..000000000 --- a/helix-db/src/helix_engine/vector_core/utils.rs +++ /dev/null @@ -1,167 +0,0 @@ -use super::binary_heap::BinaryHeap; -use crate::helix_engine::{ - traversal_core::LMDB_STRING_HEADER_LENGTH, - types::VectorError, - vector_core::{vector::HVector, vector_without_data::VectorWithoutData}, -}; -use heed3::{ - Database, RoTxn, - byteorder::BE, - types::{Bytes, U128}, -}; -use std::cmp::Ordering; - -#[derive(PartialEq)] -pub(super) struct Candidate { - pub id: u128, - pub distance: f64, -} - -impl Eq for Candidate {} - -impl PartialOrd for Candidate { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Candidate { - fn cmp(&self, other: &Self) -> Ordering { - other - .distance - .partial_cmp(&self.distance) - .unwrap_or(Ordering::Equal) - } -} - -pub(super) trait HeapOps<'a, T> { - /// Take the top k elements from the heap - /// Used because using `.iter()` does not keep the order - fn take_inord(&mut self, k: usize) -> BinaryHeap<'a, T> - where - T: Ord; - - /// Get the maximum element from the heap - fn get_max<'q>(&'q self) -> Option<&'a T> - where - T: Ord, - 'q: 'a; -} - -impl<'a, T> HeapOps<'a, T> for BinaryHeap<'a, T> { - #[inline(always)] - fn take_inord(&mut self, k: usize) -> BinaryHeap<'a, T> - where - T: Ord, - { - let mut result = BinaryHeap::with_capacity(self.arena, k); - for _ in 0..k { - if let Some(item) = self.pop() { - result.push(item); - } else { - break; - } - } - result - } - - #[inline(always)] - fn get_max<'q>(&'q self) -> Option<&'a T> - where - T: Ord, - 'q: 'a, - { - self.iter().max() - } -} - -pub trait VectorFilter<'db, 'arena, 'txn, 'q> { - fn to_vec_with_filter( - self, - k: usize, - filter: Option<&'arena [F]>, - label: &'arena str, - txn: &'txn RoTxn<'db>, - db: Database, Bytes>, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> - where - F: Fn(&HVector<'arena>, &'txn RoTxn<'db>) -> bool; -} - -impl<'db, 'arena, 'txn, 'q> VectorFilter<'db, 'arena, 'txn, 'q> - for BinaryHeap<'arena, HVector<'arena>> -{ - #[inline(always)] - fn to_vec_with_filter( - mut self, - k: usize, - filter: Option<&'arena [F]>, - label: &'arena str, - txn: &'txn RoTxn<'db>, - db: Database, Bytes>, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> - where - F: Fn(&HVector<'arena>, &'txn RoTxn<'db>) -> bool, - { - let mut result = bumpalo::collections::Vec::with_capacity_in(k, arena); - for _ in 0..k { - // while pop check 
filters and pop until one passes - while let Some(mut item) = self.pop() { - let properties = match db.get(txn, &item.id)? { - Some(bytes) => { - // println!("decoding"); - let res = Some(VectorWithoutData::from_bincode_bytes( - arena, bytes, item.id, - )?); - // println!("decoded: {res:?}"); - res - } - None => None, // TODO: maybe should be an error? - }; - - if let Some(properties) = properties - && SHOULD_CHECK_DELETED - && properties.deleted - { - continue; - } - - if item.label() == label - && (filter.is_none() || filter.unwrap().iter().all(|f| f(&item, txn))) - { - assert!( - properties.is_some(), - "properties should be some, otherwise there has been an error on vector insertion as properties are always inserted" - ); - item.expand_from_vector_without_data(properties.unwrap()); - result.push(item); - break; - } - } - } - - Ok(result) - } -} - -pub fn check_deleted(data: &[u8]) -> bool { - assert!( - data.len() >= LMDB_STRING_HEADER_LENGTH, - "value length does not contain header which means the `label` field was missing from the node on insertion" - ); - let length_of_label_in_lmdb = - u64::from_le_bytes(data[..LMDB_STRING_HEADER_LENGTH].try_into().unwrap()) as usize; - - let length_of_version_in_lmdb = 1; - - let deleted_index = - LMDB_STRING_HEADER_LENGTH + length_of_label_in_lmdb + length_of_version_in_lmdb; - - assert!( - data.len() >= deleted_index, - "data length is not at least the deleted index plus the length of the deleted field meaning there has been a corruption on node insertion" - ); - data[deleted_index] == 1 -} diff --git a/helix-db/src/helix_engine/vector_core/vector.rs b/helix-db/src/helix_engine/vector_core/vector.rs deleted file mode 100644 index 30c3223c5..000000000 --- a/helix-db/src/helix_engine/vector_core/vector.rs +++ /dev/null @@ -1,305 +0,0 @@ -use crate::{ - helix_engine::{ - types::VectorError, - vector_core::{vector_distance::DistanceCalc, vector_without_data::VectorWithoutData}, - }, - protocol::{custom_serde::vector_serde::VectorDeSeed, value::Value}, - utils::{ - id::{uuid_str_from_buf, v6_uuid}, - properties::ImmutablePropertiesMap, - }, -}; -use bincode::Options; -use core::fmt; -use serde::{Serialize, Serializer, ser::SerializeMap}; -use std::{alloc, cmp::Ordering, fmt::Debug, mem, ptr, slice}; - -// TODO: make this generic over the type of encoding (f32, f64, etc) -// TODO: use const param to set dimension -// TODO: set level as u8 - -#[repr(C, align(16))] // TODO: see performance impact of repr(C) and align(16) -#[derive(Clone, Copy)] -pub struct HVector<'arena> { - /// The id of the HVector - pub id: u128, - /// The label of the HVector - pub label: &'arena str, - /// the version of the vector - pub version: u8, - /// whether the vector is deleted - pub deleted: bool, - /// The level of the HVector - pub level: usize, - /// The distance of the HVector - pub distance: Option, - /// The actual vector - pub data: &'arena [f64], - /// The properties of the HVector - pub properties: Option>, -} - -impl<'arena> Serialize for HVector<'arena> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - use serde::ser::SerializeStruct; - - // Check if this is a human-readable format (like JSON) - if serializer.is_human_readable() { - // Include id for JSON serialization - let mut buffer = [0u8; 36]; - let mut state = serializer.serialize_map(Some( - 5 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0), - ))?; - state.serialize_entry("id", uuid_str_from_buf(self.id, &mut buffer))?; - state.serialize_entry("label", 
&self.label)?; - state.serialize_entry("version", &self.version)?; - state.serialize_entry("deleted", &self.deleted)?; - if let Some(properties) = &self.properties { - for (key, value) in properties.iter() { - state.serialize_entry(key, value)?; - } - } - state.end() - } else { - // Skip id, level, distance, and data for bincode serialization - let mut state = serializer.serialize_struct("HVector", 4)?; - state.serialize_field("label", &self.label)?; - state.serialize_field("version", &self.version)?; - state.serialize_field("deleted", &self.deleted)?; - state.serialize_field("properties", &self.properties)?; - state.end() - } - } -} - -impl PartialEq for HVector<'_> { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} -impl Eq for HVector<'_> {} -impl PartialOrd for HVector<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl Ord for HVector<'_> { - fn cmp(&self, other: &Self) -> Ordering { - other - .distance - .partial_cmp(&self.distance) - .unwrap_or(Ordering::Equal) - } -} - -impl Debug for HVector<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{{ \nid: {},\nlevel: {},\ndistance: {:?},\ndata: {:?}, }}", - uuid::Uuid::from_u128(self.id), - // self.is_deleted, - self.level, - self.distance, - self.data, - ) - } -} - -impl<'arena> HVector<'arena> { - #[inline(always)] - pub fn from_slice(label: &'arena str, level: usize, data: &'arena [f64]) -> Self { - let id = v6_uuid(); - HVector { - id, - // is_deleted: false, - version: 1, - level, - label, - data, - distance: None, - properties: None, - deleted: false, - } - } - - /// Converts the HVector to an vec of bytes by accessing the data field directly - /// and converting each f64 to a byte slice - #[inline(always)] - pub fn vector_data_to_bytes(&self) -> Result<&[u8], VectorError> { - bytemuck::try_cast_slice(self.data).map_err(|_| { - VectorError::ConversionError("Invalid vector data: vector data".to_string()) - }) - } - - /// Deserializes bytes into an vector using a custom deserializer that allocates into the provided arena - /// - /// Both the properties bytes (if present) and the raw vector data are combined to generate the final vector struct - /// - /// NOTE: in this method, fixint encoding is used - #[inline] - pub fn from_bincode_bytes<'txn>( - arena: &'arena bumpalo::Bump, - properties: Option<&'txn [u8]>, - raw_vector_data: &'txn [u8], - id: u128, - ) -> Result { - bincode::options() - .with_fixint_encoding() - .allow_trailing_bytes() - .deserialize_seed( - VectorDeSeed { - arena, - id, - raw_vector_data, - }, - properties.unwrap_or(&[]), - ) - .map_err(|e| VectorError::ConversionError(format!("Error deserializing vector: {e}"))) - } - - #[inline(always)] - pub fn to_bincode_bytes(&self) -> Result, bincode::Error> { - bincode::serialize(self) - } - - /// Casts the raw bytes to a f64 slice by copying them once into the arena - #[inline] - pub fn cast_raw_vector_data<'txn>( - arena: &'arena bumpalo::Bump, - raw_vector_data: &'txn [u8], - ) -> &'arena [f64] { - assert!(!raw_vector_data.is_empty(), "raw_vector_data.len() == 0"); - assert!( - raw_vector_data.len().is_multiple_of(mem::size_of::()), - "raw_vector_data bytes len is not a multiple of size_of::()" - ); - let dimensions = raw_vector_data.len() / mem::size_of::(); - - assert!( - raw_vector_data.len().is_multiple_of(dimensions), - "raw_vector_data does not have the exact required number of dimensions" - ); - - let layout = alloc::Layout::array::(dimensions) - .expect("vector_data 
array arithmetic overflow or total size exceeds isize::MAX"); - - let vector_data: ptr::NonNull = arena.alloc_layout(layout); - - // 'arena because the destination pointer is allocated in the arena - let data: &'arena [f64] = unsafe { - // SAFETY: - // - We assert data is present and that we are within bounds in asserts above - ptr::copy_nonoverlapping( - raw_vector_data.as_ptr(), - vector_data.as_ptr(), - raw_vector_data.len(), - ); - - // We allocated with the layout of an f64 array - let vector_data: ptr::NonNull = vector_data.cast(); - - // SAFETY: - // - `vector_data`` is guaranteed to be valid by being NonNull - // - the asserts above guarantee that there are enough valid bytes to be read - slice::from_raw_parts(vector_data.as_ptr(), dimensions) - }; - - data - } - - /// Uses just the vector data to generate a HVector struct - pub fn from_raw_vector_data<'txn>( - arena: &'arena bumpalo::Bump, - raw_vector_data: &'txn [u8], - label: &'arena str, - id: u128, - ) -> Result { - let data = Self::cast_raw_vector_data(arena, raw_vector_data); - Ok(HVector { - id, - label, - data, - version: 1, - level: 0, - distance: None, - properties: None, - deleted: false, - }) - } - - #[inline(always)] - pub fn len(&self) -> usize { - self.data.len() - } - - #[inline(always)] - pub fn is_empty(&self) -> bool { - self.data.is_empty() - } - - #[inline(always)] - pub fn distance_to(&self, other: &HVector) -> Result { - HVector::<'arena>::distance(self, other) - } - - #[inline(always)] - pub fn set_distance(&mut self, distance: f64) { - self.distance = Some(distance); - } - - #[inline(always)] - pub fn get_distance(&self) -> f64 { - self.distance.unwrap_or(2.0) - } - - #[inline(always)] - pub fn get_label(&self) -> Option<&Value> { - match &self.properties { - Some(p) => p.get("label"), - None => None, - } - } - - #[inline(always)] - pub fn get_property(&self, key: &str) -> Option<&'arena Value> { - self.properties.as_ref().and_then(|value| value.get(key)) - } - - pub fn id(&self) -> &u128 { - &self.id - } - - pub fn label(&self) -> &'arena str { - self.label - } - - pub fn score(&self) -> f64 { - self.distance.unwrap_or(2.0) - } - - pub fn expand_from_vector_without_data(&mut self, vector: VectorWithoutData<'arena>) { - self.label = vector.label; - self.version = vector.version; - self.level = vector.level; - self.properties = vector.properties; - } -} - -impl<'arena> From> for HVector<'arena> { - fn from(value: VectorWithoutData<'arena>) -> Self { - HVector { - id: value.id, - label: value.label, - version: value.version, - level: value.level, - distance: None, - data: &[], - properties: value.properties, - deleted: value.deleted, - } - } -} diff --git a/helix-db/src/helix_engine/vector_core/vector_core.rs b/helix-db/src/helix_engine/vector_core/vector_core.rs deleted file mode 100644 index 495212583..000000000 --- a/helix-db/src/helix_engine/vector_core/vector_core.rs +++ /dev/null @@ -1,664 +0,0 @@ -use super::binary_heap::BinaryHeap; -use crate::{ - debug_println, - helix_engine::{ - types::VectorError, - vector_core::{ - hnsw::HNSW, - utils::{Candidate, HeapOps, VectorFilter}, - vector::HVector, - vector_without_data::VectorWithoutData, - }, - }, - utils::{id::uuid_str, properties::ImmutablePropertiesMap}, -}; -use heed3::{ - Database, Env, RoTxn, RwTxn, - byteorder::BE, - types::{Bytes, U128, Unit}, -}; -use rand::prelude::Rng; -use serde::{Deserialize, Serialize}; -use std::collections::HashSet; - -const DB_VECTORS: &str = "vectors"; // for vector data (v:) -const DB_VECTOR_DATA: &str = 
"vector_data"; // for vector data (v:) -const DB_HNSW_EDGES: &str = "hnsw_out_nodes"; // for hnsw out node data -const VECTOR_PREFIX: &[u8] = b"v:"; -pub const ENTRY_POINT_KEY: &[u8] = b"entry_point"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HNSWConfig { - pub m: usize, // max num of bi-directional links per element - pub m_max_0: usize, // max num of links for lower layers - pub ef_construct: usize, // size of the dynamic candidate list for construction - pub m_l: f64, // level generation factor - pub ef: usize, // search param, num of cands to search - pub min_neighbors: usize, // for get_neighbors, always 512 -} - -impl HNSWConfig { - /// Constructor for the configs of the HNSW vector similarity search algorithm - /// - m (5 <= m <= 48): max num of bi-directional links per element - /// - m_max_0 (2 * m): max num of links for level 0 (level that stores all vecs) - /// - ef_construct (40 <= ef_construct <= 512): size of the dynamic candidate list - /// for construction - /// - m_l (ln(1/m)): level generation factor (multiplied by a random number) - /// - ef (10 <= ef <= 512): num of candidates to search - pub fn new(m: Option, ef_construct: Option, ef: Option) -> Self { - let m = m.unwrap_or(16).clamp(5, 48); - let ef_construct = ef_construct.unwrap_or(128).clamp(40, 512); - let ef = ef.unwrap_or(768).clamp(10, 512); - - Self { - m, - m_max_0: 2 * m, - ef_construct, - m_l: 1.0 / (m as f64).ln(), - ef, - min_neighbors: 512, - } - } -} - -pub struct VectorCore { - pub vectors_db: Database, - pub vector_properties_db: Database, Bytes>, - pub edges_db: Database, - pub config: HNSWConfig, -} - -impl VectorCore { - pub fn new(env: &Env, txn: &mut RwTxn, config: HNSWConfig) -> Result { - let vectors_db = env.create_database(txn, Some(DB_VECTORS))?; - let vector_properties_db = env - .database_options() - .types::, Bytes>() - .name(DB_VECTOR_DATA) - .create(txn)?; - let edges_db = env.create_database(txn, Some(DB_HNSW_EDGES))?; - - Ok(Self { - vectors_db, - vector_properties_db, - edges_db, - config, - }) - } - - /// Vector key: [v, id, ] - #[inline(always)] - pub fn vector_key(id: u128, level: usize) -> Vec { - [VECTOR_PREFIX, &id.to_be_bytes(), &level.to_be_bytes()].concat() - } - - #[inline(always)] - pub fn out_edges_key(source_id: u128, level: usize, sink_id: Option) -> Vec { - match sink_id { - Some(sink_id) => [ - source_id.to_be_bytes().as_slice(), - level.to_be_bytes().as_slice(), - sink_id.to_be_bytes().as_slice(), - ] - .concat() - .to_vec(), - None => [ - source_id.to_be_bytes().as_slice(), - level.to_be_bytes().as_slice(), - ] - .concat() - .to_vec(), - } - } - - #[inline] - fn get_new_level(&self) -> usize { - let mut rng = rand::rng(); - let r: f64 = rng.random::(); - (-r.ln() * self.config.m_l).floor() as usize - } - - #[inline] - fn get_entry_point<'db: 'arena, 'arena: 'txn, 'txn>( - &self, - txn: &'txn RoTxn<'db>, - label: &'arena str, - arena: &'arena bumpalo::Bump, - ) -> Result, VectorError> { - let ep_id = self.vectors_db.get(txn, ENTRY_POINT_KEY)?; - if let Some(ep_id) = ep_id { - let mut arr = [0u8; 16]; - let len = std::cmp::min(ep_id.len(), 16); - arr[..len].copy_from_slice(&ep_id[..len]); - - let ep = self - .get_raw_vector_data(txn, u128::from_be_bytes(arr), label, arena) - .map_err(|_| VectorError::EntryPointNotFound)?; - Ok(ep) - } else { - Err(VectorError::EntryPointNotFound) - } - } - - #[inline] - fn set_entry_point(&self, txn: &mut RwTxn, entry: &HVector) -> Result<(), VectorError> { - self.vectors_db - .put(txn, ENTRY_POINT_KEY, 
&entry.id.to_be_bytes()) - .map_err(VectorError::from)?; - Ok(()) - } - - #[inline(always)] - pub fn put_vector<'arena>( - &self, - txn: &mut RwTxn, - vector: &HVector<'arena>, - ) -> Result<(), VectorError> { - self.vectors_db - .put( - txn, - &Self::vector_key(vector.id, vector.level), - vector.vector_data_to_bytes()?, - ) - .map_err(VectorError::from)?; - self.vector_properties_db - .put(txn, &vector.id, &bincode::serialize(&vector)?)?; - Ok(()) - } - - #[inline(always)] - fn get_neighbors<'db: 'arena, 'arena: 'txn, 'txn, F>( - &self, - txn: &'txn RoTxn<'db>, - label: &'arena str, - id: u128, - level: usize, - filter: Option<&[F]>, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - { - let out_key = Self::out_edges_key(id, level, None); - let mut neighbors = bumpalo::collections::Vec::with_capacity_in( - self.config.m_max_0.min(self.config.min_neighbors), - arena, - ); - - let iter = self - .edges_db - .lazily_decode_data() - .prefix_iter(txn, &out_key)?; - - let prefix_len = out_key.len(); - - for result in iter { - let (key, _) = result?; - - let mut arr = [0u8; 16]; - arr[..16].copy_from_slice(&key[prefix_len..(prefix_len + 16)]); - let neighbor_id = u128::from_be_bytes(arr); - - if neighbor_id == id { - continue; - } - let vector = self.get_raw_vector_data(txn, neighbor_id, label, arena)?; - - let passes_filters = match filter { - Some(filter_slice) => filter_slice.iter().all(|f| f(&vector, txn)), - None => true, - }; - - if passes_filters { - neighbors.push(vector); - } - } - neighbors.shrink_to_fit(); - - Ok(neighbors) - } - - #[inline(always)] - fn set_neighbours<'db: 'arena, 'arena: 'txn, 'txn, 's>( - &'db self, - txn: &'txn mut RwTxn<'db>, - id: u128, - neighbors: &BinaryHeap<'arena, HVector<'arena>>, - level: usize, - ) -> Result<(), VectorError> { - let prefix = Self::out_edges_key(id, level, None); - - let mut keys_to_delete: HashSet> = self - .edges_db - .prefix_iter(txn, prefix.as_ref())? - .filter_map(|result| result.ok().map(|(key, _)| key.to_vec())) - .collect(); - - neighbors - .iter() - .try_for_each(|neighbor| -> Result<(), VectorError> { - let neighbor_id = neighbor.id; - if neighbor_id == id { - return Ok(()); - } - - let out_key = Self::out_edges_key(id, level, Some(neighbor_id)); - keys_to_delete.remove(&out_key); - self.edges_db.put(txn, &out_key, &())?; - - let in_key = Self::out_edges_key(neighbor_id, level, Some(id)); - keys_to_delete.remove(&in_key); - self.edges_db.put(txn, &in_key, &())?; - - Ok(()) - })?; - - for key in keys_to_delete { - self.edges_db.delete(txn, &key)?; - } - - Ok(()) - } - - fn select_neighbors<'db: 'arena, 'arena: 'txn, 'txn, 's, F>( - &'db self, - txn: &'txn RoTxn<'db>, - label: &'arena str, - query: &'s HVector<'arena>, - mut cands: BinaryHeap<'arena, HVector<'arena>>, - level: usize, - should_extend: bool, - filter: Option<&[F]>, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - { - let m = self.config.m; - - if !should_extend { - return Ok(cands.take_inord(m)); - } - - let mut visited: HashSet = HashSet::new(); - let mut result = BinaryHeap::with_capacity(arena, m * cands.len()); - for candidate in cands.iter() { - for mut neighbor in - self.get_neighbors(txn, label, candidate.id, level, filter, arena)? 
- { - if !visited.insert(neighbor.id) { - continue; - } - - neighbor.set_distance(neighbor.distance_to(query)?); - - /* - let passes_filters = match filter { - Some(filter_slice) => filter_slice.iter().all(|f| f(&neighbor, txn)), - None => true, - }; - - if passes_filters { - result.push(neighbor); - } - */ - - if filter.is_none() || filter.unwrap().iter().all(|f| f(&neighbor, txn)) { - result.push(neighbor); - } - } - } - - result.extend(cands); - Ok(result.take_inord(m)) - } - - fn search_level<'db: 'arena, 'arena: 'txn, 'txn, 'q, F>( - &self, - txn: &'txn RoTxn<'db>, - label: &'arena str, - query: &'q HVector<'arena>, - entry_point: &'q mut HVector<'arena>, - ef: usize, - level: usize, - filter: Option<&[F]>, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - { - let mut visited: HashSet = HashSet::new(); - let mut candidates: BinaryHeap<'arena, Candidate> = - BinaryHeap::with_capacity(arena, self.config.ef_construct); - let mut results: BinaryHeap<'arena, HVector<'arena>> = BinaryHeap::new(arena); - - entry_point.set_distance(entry_point.distance_to(query)?); - candidates.push(Candidate { - id: entry_point.id, - distance: entry_point.get_distance(), - }); - results.push(*entry_point); - visited.insert(entry_point.id); - - while let Some(curr_cand) = candidates.pop() { - if results.len() >= ef - && results - .get_max() - .is_none_or(|f| curr_cand.distance > f.get_distance()) - { - break; - } - - let max_distance = if results.len() >= ef { - results.get_max().map(|f| f.get_distance()) - } else { - None - }; - - self.get_neighbors(txn, label, curr_cand.id, level, filter, arena)? - .into_iter() - .filter(|neighbor| visited.insert(neighbor.id)) - .filter_map(|mut neighbor| { - let distance = neighbor.distance_to(query).ok()?; - - if max_distance.is_none_or(|max| distance < max) { - neighbor.set_distance(distance); - Some((neighbor, distance)) - } else { - None - } - }) - .for_each(|(neighbor, distance)| { - candidates.push(Candidate { - id: neighbor.id, - distance, - }); - - results.push(neighbor); - - if results.len() > ef { - results = results.take_inord(ef); - } - }); - } - Ok(results) - } - - pub fn num_inserted_vectors(&self, txn: &RoTxn) -> Result { - Ok(self.vectors_db.len(txn)?) - } - - #[inline] - pub fn get_vector_properties<'db: 'arena, 'arena: 'txn, 'txn>( - &self, - txn: &'txn RoTxn<'db>, - id: u128, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> { - let vector: Option> = - match self.vector_properties_db.get(txn, &id)? { - Some(bytes) => Some(VectorWithoutData::from_bincode_bytes(arena, bytes, id)?), - None => None, - }; - - if let Some(vector) = vector - && vector.deleted - { - return Err(VectorError::VectorDeleted); - } - - Ok(vector) - } - - #[inline(always)] - pub fn get_full_vector<'arena>( - &self, - txn: &RoTxn, - id: u128, - arena: &'arena bumpalo::Bump, - ) -> Result, VectorError> { - let vector_data_bytes = self - .vectors_db - .get(txn, &Self::vector_key(id, 0))? 
- .ok_or(VectorError::VectorNotFound(uuid_str(id, arena).to_string()))?; - - let properties_bytes = self.vector_properties_db.get(txn, &id)?; - - let vector = HVector::from_bincode_bytes(arena, properties_bytes, vector_data_bytes, id)?; - if vector.deleted { - return Err(VectorError::VectorDeleted); - } - Ok(vector) - } - - #[inline(always)] - pub fn get_raw_vector_data<'db: 'arena, 'arena: 'txn, 'txn>( - &self, - txn: &'txn RoTxn<'db>, - id: u128, - label: &'arena str, - arena: &'arena bumpalo::Bump, - ) -> Result, VectorError> { - let vector_data_bytes = self - .vectors_db - .get(txn, &Self::vector_key(id, 0))? - .ok_or(VectorError::VectorNotFound(uuid_str(id, arena).to_string()))?; - HVector::from_raw_vector_data(arena, vector_data_bytes, label, id) - } - - /// Get all vectors from the database, optionally filtered by level - pub fn get_all_vectors<'db: 'arena, 'arena: 'txn, 'txn>( - &self, - txn: &'txn RoTxn<'db>, - level: Option, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> { - let mut vectors = bumpalo::collections::Vec::new_in(arena); - - // Iterate over all vectors in the database - let prefix_iter = self.vectors_db.prefix_iter(txn, VECTOR_PREFIX)?; - - for result in prefix_iter { - let (key, _) = result?; - - // Extract id from the key: v: (2 bytes) + id (16 bytes) + level (8 bytes) - if key.len() < VECTOR_PREFIX.len() + 16 { - continue; // Skip malformed keys - } - - let mut id_bytes = [0u8; 16]; - id_bytes.copy_from_slice(&key[VECTOR_PREFIX.len()..VECTOR_PREFIX.len() + 16]); - let id = u128::from_be_bytes(id_bytes); - - // Get the full vector using the existing method - match self.get_full_vector(txn, id, arena) { - Ok(vector) => { - // Filter by level if specified - if let Some(lvl) = level { - if vector.level == lvl { - vectors.push(vector); - } - } else { - vectors.push(vector); - } - } - Err(_) => { - // Skip vectors that can't be loaded (e.g., deleted) - continue; - } - } - } - - Ok(vectors) - } -} - -impl HNSW for VectorCore { - fn search<'db, 'arena, 'txn, F>( - &self, - txn: &'txn RoTxn<'db>, - query: &'arena [f64], - k: usize, - label: &'arena str, - filter: Option<&'arena [F]>, - should_trickle: bool, - arena: &'arena bumpalo::Bump, - ) -> Result>, VectorError> - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - 'db: 'arena, - 'arena: 'txn, - { - let query = HVector::from_slice(label, 0, query); - // let temp_arena = bumpalo::Bump::new(); - - let mut entry_point = self.get_entry_point(txn, label, arena)?; - - let ef = self.config.ef; - let curr_level = entry_point.level; - // println!("curr_level: {curr_level}"); - for level in (1..=curr_level).rev() { - let mut nearest = self.search_level( - txn, - label, - &query, - &mut entry_point, - ef, - level, - match should_trickle { - true => filter, - false => None, - }, - arena, - )?; - if let Some(closest) = nearest.pop() { - entry_point = closest; - } - } - // println!("entry_point: {entry_point:?}"); - let candidates = self.search_level( - txn, - label, - &query, - &mut entry_point, - ef, - 0, - match should_trickle { - true => filter, - false => None, - }, - arena, - )?; - // println!("candidates"); - let results = candidates.to_vec_with_filter::( - k, - filter, - label, - txn, - self.vector_properties_db, - arena, - )?; - - debug_println!("vector search found {} results", results.len()); - Ok(results) - } - - fn insert<'db, 'arena, 'txn, F>( - &'db self, - txn: &'txn mut RwTxn<'db>, - label: &'arena str, - data: &'arena [f64], - properties: Option>, - arena: &'arena bumpalo::Bump, - ) -> Result, 
VectorError> - where - F: Fn(&HVector<'arena>, &RoTxn<'db>) -> bool, - 'db: 'arena, - 'arena: 'txn, - { - let new_level = self.get_new_level(); - - let mut query = HVector::from_slice(label, 0, data); - query.properties = properties; - self.put_vector(txn, &query)?; - - query.level = new_level; - - let entry_point = match self.get_entry_point(txn, label, arena) { - Ok(ep) => ep, - Err(_) => { - // TODO: use proper error handling - self.set_entry_point(txn, &query)?; - query.set_distance(0.0); - - return Ok(query); - } - }; - - let l = entry_point.level; - let mut curr_ep = entry_point; - for level in (new_level + 1..=l).rev() { - let mut nearest = - self.search_level::(txn, label, &query, &mut curr_ep, 1, level, None, arena)?; - curr_ep = nearest.pop().ok_or(VectorError::VectorCoreError( - "emtpy search result".to_string(), - ))?; - } - - for level in (0..=l.min(new_level)).rev() { - let nearest = self.search_level::( - txn, - label, - &query, - &mut curr_ep, - self.config.ef_construct, - level, - None, - arena, - )?; - curr_ep = *nearest.peek().ok_or(VectorError::VectorCoreError( - "emtpy search result".to_string(), - ))?; - - let neighbors = - self.select_neighbors::(txn, label, &query, nearest, level, true, None, arena)?; - self.set_neighbours(txn, query.id, &neighbors, level)?; - - for e in neighbors { - let id = e.id; - let e_conns = BinaryHeap::from( - arena, - self.get_neighbors::(txn, label, id, level, None, arena)?, - ); - let e_new_conn = self - .select_neighbors::(txn, label, &query, e_conns, level, true, None, arena)?; - self.set_neighbours(txn, id, &e_new_conn, level)?; - } - } - - if new_level > l { - self.set_entry_point(txn, &query)?; - } - - debug_println!("vector inserted with id {}", query.id); - Ok(query) - } - - fn delete(&self, txn: &mut RwTxn, id: u128, arena: &bumpalo::Bump) -> Result<(), VectorError> { - match self.get_vector_properties(txn, id, arena)? { - Some(mut properties) => { - debug_println!("properties: {properties:?}"); - if properties.deleted { - return Err(VectorError::VectorAlreadyDeleted(id.to_string())); - } - - properties.deleted = true; - self.vector_properties_db - .put(txn, &id, &bincode::serialize(&properties)?)?; - debug_println!("vector deleted with id {}", &id); - Ok(()) - } - None => Err(VectorError::VectorNotFound(id.to_string())), - } - } -} diff --git a/helix-db/src/helix_engine/vector_core/vector_distance.rs b/helix-db/src/helix_engine/vector_core/vector_distance.rs deleted file mode 100644 index d92737e25..000000000 --- a/helix-db/src/helix_engine/vector_core/vector_distance.rs +++ /dev/null @@ -1,157 +0,0 @@ -use crate::helix_engine::{types::VectorError, vector_core::vector::HVector}; - -pub const MAX_DISTANCE: f64 = 2.0; -pub const ORTHOGONAL: f64 = 1.0; -pub const MIN_DISTANCE: f64 = 0.0; - -pub trait DistanceCalc { - fn distance(from: &HVector, to: &HVector) -> Result; -} -impl<'a> DistanceCalc for HVector<'a> { - /// Calculates the distance between two vectors. - /// - /// It normalizes the distance to be between 0 and 2. 
- /// - /// - 1.0 (most similar) → Distance 0.0 (closest) - /// - 0.0 (orthogonal) → Distance 1.0 - /// - -1.0 (most dissimilar) → Distance 2.0 (furthest) - #[inline(always)] - #[cfg(feature = "cosine")] - fn distance(from: &HVector, to: &HVector) -> Result { - cosine_similarity(from.data, to.data).map(|sim| 1.0 - sim) - } -} - -#[inline] -#[cfg(feature = "cosine")] -pub fn cosine_similarity(from: &[f64], to: &[f64]) -> Result { - let len = from.len(); - let other_len = to.len(); - - if len != other_len { - println!("mis-match in vector dimensions!\n{len} != {other_len}"); - return Err(VectorError::InvalidVectorLength); - } - //debug_assert_eq!(len, other.data.len(), "Vectors must have the same length"); - - #[cfg(target_feature = "avx2")] - { - return cosine_similarity_avx2(from, to); - } - - let mut dot_product = 0.0; - let mut magnitude_a = 0.0; - let mut magnitude_b = 0.0; - - const CHUNK_SIZE: usize = 8; - let chunks = len / CHUNK_SIZE; - let remainder = len % CHUNK_SIZE; - - for i in 0..chunks { - let offset = i * CHUNK_SIZE; - let a_chunk = &from[offset..offset + CHUNK_SIZE]; - let b_chunk = &to[offset..offset + CHUNK_SIZE]; - - let mut local_dot = 0.0; - let mut local_mag_a = 0.0; - let mut local_mag_b = 0.0; - - for j in 0..CHUNK_SIZE { - let a_val = a_chunk[j]; - let b_val = b_chunk[j]; - local_dot += a_val * b_val; - local_mag_a += a_val * a_val; - local_mag_b += b_val * b_val; - } - - dot_product += local_dot; - magnitude_a += local_mag_a; - magnitude_b += local_mag_b; - } - - let remainder_offset = chunks * CHUNK_SIZE; - for i in 0..remainder { - let a_val = from[remainder_offset + i]; - let b_val = to[remainder_offset + i]; - dot_product += a_val * b_val; - magnitude_a += a_val * a_val; - magnitude_b += b_val * b_val; - } - - if magnitude_a.abs() == 0.0 || magnitude_b.abs() == 0.0 { - return Ok(-1.0); - } - - Ok(dot_product / (magnitude_a.sqrt() * magnitude_b.sqrt())) -} - -// SIMD implementation using AVX2 (256-bit vectors) -#[cfg(target_feature = "avx2")] -#[inline(always)] -pub fn cosine_similarity_avx2(a: &[f64], b: &[f64]) -> f64 { - use std::arch::x86_64::*; - - let len = a.len(); - let chunks = len / 4; // AVX2 processes 4 f64 values at once - - unsafe { - let mut dot_product = _mm256_setzero_pd(); - let mut magnitude_a = _mm256_setzero_pd(); - let mut magnitude_b = _mm256_setzero_pd(); - - for i in 0..chunks { - let offset = i * 4; - - // Load data - handle unaligned data - let a_chunk = _mm256_loadu_pd(&a[offset]); - let b_chunk = _mm256_loadu_pd(&b[offset]); - - // Calculate dot product and magnitudes in parallel - dot_product = _mm256_add_pd(dot_product, _mm256_mul_pd(a_chunk, b_chunk)); - magnitude_a = _mm256_add_pd(magnitude_a, _mm256_mul_pd(a_chunk, a_chunk)); - magnitude_b = _mm256_add_pd(magnitude_b, _mm256_mul_pd(b_chunk, b_chunk)); - } - - // Horizontal sum of 4 doubles in each vector - let dot_sum = horizontal_sum_pd(dot_product); - let mag_a_sum = horizontal_sum_pd(magnitude_a); - let mag_b_sum = horizontal_sum_pd(magnitude_b); - - // Handle remainder elements - let mut dot_remainder = 0.0; - let mut mag_a_remainder = 0.0; - let mut mag_b_remainder = 0.0; - - let remainder_offset = chunks * 4; - for i in remainder_offset..len { - let a_val = a[i]; - let b_val = b[i]; - dot_remainder += a_val * b_val; - mag_a_remainder += a_val * a_val; - mag_b_remainder += b_val * b_val; - } - - // Combine SIMD and scalar results - let dot_product_total = dot_sum + dot_remainder; - let magnitude_a_total = (mag_a_sum + mag_a_remainder).sqrt(); - let magnitude_b_total = 
(mag_b_sum + mag_b_remainder).sqrt(); - - dot_product_total / (magnitude_a_total * magnitude_b_total) - } -} - -// Helper function to sum the 4 doubles in an AVX2 vector -#[cfg(target_feature = "avx2")] -#[inline(always)] -unsafe fn horizontal_sum_pd(__v: __m256d) -> f64 { - use std::arch::x86_64::*; - - // Extract the high 128 bits and add to the low 128 bits - let sum_hi_lo = _mm_add_pd(_mm256_castpd256_pd128(__v), _mm256_extractf128_pd(__v, 1)); - - // Add the high 64 bits to the low 64 bits - let sum = _mm_add_sd(sum_hi_lo, _mm_unpackhi_pd(sum_hi_lo, sum_hi_lo)); - - // Extract the low 64 bits as a scalar - _mm_cvtsd_f64(sum) -} diff --git a/helix-db/src/helix_engine/vector_core/vector_without_data.rs b/helix-db/src/helix_engine/vector_core/vector_without_data.rs deleted file mode 100644 index 8d756094b..000000000 --- a/helix-db/src/helix_engine/vector_core/vector_without_data.rs +++ /dev/null @@ -1,153 +0,0 @@ -use crate::{ - helix_engine::types::VectorError, - protocol::{custom_serde::vector_serde::VectoWithoutDataDeSeed, value::Value}, - utils::{id::uuid_str_from_buf, properties::ImmutablePropertiesMap}, -}; -use bincode::Options; -use core::fmt; -use serde::{Serialize, ser::SerializeMap}; -use std::fmt::Debug; -// TODO: make this generic over the type of encoding (f32, f64, etc) -// TODO: use const param to set dimension -// TODO: set level as u8 - -#[repr(C, align(16))] -#[derive(Clone, Copy)] -pub struct VectorWithoutData<'arena> { - /// The id of the HVector - pub id: u128, - /// The label of the HVector - pub label: &'arena str, - /// the version of the vector - pub version: u8, - /// whether the vector is deleted - pub deleted: bool, - /// The level of the HVector - pub level: usize, - - /// The properties of the HVector - pub properties: Option>, -} - -// Custom Serialize implementation to conditionally include id field -// For JSON serialization, the id field is included, but for bincode it is skipped -impl<'arena> Serialize for VectorWithoutData<'arena> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - - // Check if this is a human-readable format (like JSON) - if serializer.is_human_readable() { - // Include id for JSON serialization - let mut buffer = [0u8; 36]; - let mut state = serializer.serialize_map(Some( - 6 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0), - ))?; - state.serialize_entry("id", uuid_str_from_buf(self.id, &mut buffer))?; - state.serialize_entry("label", self.label)?; - state.serialize_entry("version", &self.version)?; - state.serialize_entry("deleted", &self.deleted)?; - state.serialize_entry("level", &self.level)?; - if let Some(properties) = &self.properties { - for (key, value) in properties.iter() { - state.serialize_entry(key, value)?; - } - } - state.end() - } else { - // Skip id for bincode serialization - let mut state = serializer.serialize_struct("VectorWithoutData", 5)?; - state.serialize_field("label", self.label)?; - state.serialize_field("version", &self.version)?; - state.serialize_field("deleted", &self.deleted)?; - state.serialize_field("level", &self.level)?; - state.serialize_field("properties", &self.properties)?; - state.end() - } - } -} - -impl Debug for VectorWithoutData<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{{ \nid: {},\nlevel: {} }}", - uuid::Uuid::from_u128(self.id), - self.level, - ) - } -} - -impl<'arena> VectorWithoutData<'arena> { - #[inline(always)] - pub fn from_properties( - id: u128, - label: 
&'arena str, - level: usize, - properties: ImmutablePropertiesMap<'arena>, - ) -> Self { - VectorWithoutData { - id, - label, - version: 1, - level, - properties: Some(properties), - deleted: false, - } - } - - pub fn from_bincode_bytes<'txn>( - arena: &'arena bumpalo::Bump, - properties: &'txn [u8], - id: u128, - ) -> Result { - bincode::options() - .with_fixint_encoding() - .allow_trailing_bytes() - .deserialize_seed(VectoWithoutDataDeSeed { arena, id }, properties) - .map_err(|e| VectorError::ConversionError(format!("Error deserializing vector: {e}"))) - } - - #[inline(always)] - pub fn to_bincode_bytes(&self) -> Result, bincode::Error> { - bincode::serialize(self) - } - /// Returns the id of the HVector - #[inline(always)] - pub fn get_id(&self) -> u128 { - self.id - } - - /// Returns the level of the HVector - #[inline(always)] - pub fn get_level(&self) -> usize { - self.level - } - - #[inline(always)] - pub fn get_label(&self) -> &'arena str { - self.label - } - - #[inline(always)] - pub fn get_property(&self, key: &str) -> Option<&'arena Value> { - self.properties.as_ref().and_then(|value| value.get(key)) - } - - pub fn id(&self) -> &u128 { - &self.id - } - - pub fn label(&self) -> &'arena str { - self.label - } -} - -impl PartialEq for VectorWithoutData<'_> { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} -impl Eq for VectorWithoutData<'_> {} diff --git a/helix-db/src/helix_engine/vector_core/version.rs b/helix-db/src/helix_engine/vector_core/version.rs new file mode 100644 index 000000000..322d44d11 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/version.rs @@ -0,0 +1,90 @@ +use std::mem::size_of; +use std::{borrow::Cow, fmt}; + +use byteorder::{BigEndian, ByteOrder}; +use heed3::BoxedError; + +#[derive(Debug, Clone, Copy)] +pub struct Version { + pub major: u32, + pub minor: u32, + pub patch: u32, +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "v{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl Version { + pub fn current() -> Self { + Version { + major: env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(), + minor: env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(), + patch: env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(), + } + } +} + +pub enum VersionCodec {} + +impl<'a> heed3::BytesEncode<'a> for VersionCodec { + type EItem = Version; + + fn bytes_encode(item: &'a Self::EItem) -> Result, BoxedError> { + let Version { + major, + minor, + patch, + } = item; + + let mut output = Vec::with_capacity(size_of::() * 3); + output.extend_from_slice(&major.to_be_bytes()); + output.extend_from_slice(&minor.to_be_bytes()); + output.extend_from_slice(&patch.to_be_bytes()); + + Ok(Cow::Owned(output)) + } +} + +impl heed3::BytesDecode<'_> for VersionCodec { + type DItem = Version; + + fn bytes_decode(bytes: &'_ [u8]) -> Result { + let major = BigEndian::read_u32(bytes); + let bytes = &bytes[size_of_val(&major)..]; + let minor = BigEndian::read_u32(bytes); + let bytes = &bytes[size_of_val(&minor)..]; + let patch = BigEndian::read_u32(bytes); + + Ok(Version { + major, + minor, + patch, + }) + } +} + +#[cfg(test)] +mod test { + use heed3::{BytesDecode, BytesEncode}; + + use super::*; + + #[test] + fn version_codec() { + let version = Version { + major: 0, + minor: 10, + patch: 100, + }; + + let encoded = VersionCodec::bytes_encode(&version).unwrap(); + let decoded = VersionCodec::bytes_decode(&encoded).unwrap(); + + assert_eq!(version.major, decoded.major); + assert_eq!(version.minor, 
decoded.minor); + assert_eq!(version.patch, decoded.patch); + } +} diff --git a/helix-db/src/helix_engine/vector_core/writer.rs b/helix-db/src/helix_engine/vector_core/writer.rs new file mode 100644 index 000000000..ad543fc92 --- /dev/null +++ b/helix-db/src/helix_engine/vector_core/writer.rs @@ -0,0 +1,430 @@ +use std::path::PathBuf; + +use heed3::{ + RoTxn, RwTxn, + types::{DecodeIgnore, Unit}, +}; +use rand::{Rng, SeedableRng}; +use roaring::RoaringBitmap; + +use crate::helix_engine::vector_core::{ + CoreDatabase, ItemId, VectorCoreResult, VectorError, + distance::Distance, + hnsw::HnswBuilder, + item_iter::ItemIter, + key::{Key, KeyCodec, Prefix, PrefixCodec}, + metadata::{Metadata, MetadataCodec}, + node::{Item, ItemIds, Links, Node}, + parallel::{ImmutableItems, ImmutableLinks}, + unaligned_vector::UnalignedVector, + version::{Version, VersionCodec}, +}; + +pub struct VectorBuilder<'a, D: Distance, R: Rng + SeedableRng> { + writer: &'a Writer, + rng: &'a mut R, + inner: BuildOption, +} + +pub(crate) struct BuildOption { + pub(crate) ef_construction: usize, + pub(crate) alpha: f32, + pub(crate) available_memory: Option, +} + +impl Default for BuildOption { + fn default() -> Self { + Self { + ef_construction: 100, + alpha: 1.0, + available_memory: None, + } + } +} + +impl<'a, D: Distance, R: Rng + SeedableRng> VectorBuilder<'a, D, R> { + /// Controls the search range when inserting a new item into the graph. This value must be + /// greater than or equal to the `M` used in [`Self::build`] + /// + /// Typical values range from 50 to 500, with larger `ef_construction` producing higher + /// quality hnsw graphs at the expense of longer builds. The default value used in hannoy is + /// 100. + /// + /// # Example + /// + /// ```no_run + /// # use hannoy::{Writer, distances::Euclidean}; + /// # let (writer, wtxn): (Writer, heed::RwTxn) = todo!(); + /// use rand::rngs::StdRng; + /// use rand::SeedableRng; + /// + /// let mut rng = StdRng::seed_from_u64(4729); + /// writer.builder(&mut rng).ef_construction(100).build::<16,32>(&mut wtxn); + /// ``` + pub fn ef_construction(&mut self, ef_construction: usize) -> &mut Self { + self.inner.ef_construction = ef_construction; + self + } + + /// Tunable hyperparameter for the graph building process. Alpha decreases the tolerance for + /// link creation during index time. Alpha = 1 is the normal HNSW build while alpha > 1 is + /// more similar to DiskANN. Increasing alpha increases indexing times as more neighbours are + /// considered per linking step, but results in higher recall. + /// + /// DiskANN authors suggest using alpha=1.1 or alpha=1.2. By default alpha=1.0 in hannoy. + /// + /// # Example + /// + /// ```no_run + /// # use hannoy::{Writer, distances::Euclidean}; + /// # let (writer, wtxn): (Writer, heed::RwTxn) = todo!(); + /// use rand::rngs::StdRng; + /// use rand::SeedableRng; + /// + /// let mut rng = StdRng::seed_from_u64(4729); + /// writer.builder(&mut rng).alpha(1.1).build::<16,32>(&mut wtxn); + /// ``` + pub fn alpha(&mut self, alpha: f32) -> &mut Self { + self.inner.alpha = alpha; + self + } + + /// Generates an HNSW graph with max `M` links per node in layers > 0 and max `M0` links in layer 0. + /// + /// A general rule of thumb is to take `M0`= 2*`M`, with `M` >=3. Some common choices for + /// `M` include : 8, 12, 16, 32. Note that increasing `M` produces a denser graph at the cost + /// of longer build times. + /// + /// This function is using rayon to spawn threads. 
It can be configured by using the + /// [`rayon::ThreadPoolBuilder`]. + /// + /// # Example + /// + /// ```no_run + /// # use hannoy::{Writer, distances::Euclidean}; + /// # let (writer, wtxn): (Writer, heed::RwTxn) = todo!(); + /// use rayon; + /// use rand::rngs::StdRng; + /// use rand::SeedableRng; + /// + /// // configure global threadpool if you want! + /// rayon::ThreadPoolBuilder::new().num_threads(4).build_global().unwrap(); + /// + /// let mut rng = StdRng::seed_from_u64(4729); + /// writer.builder(&mut rng).build::<16,32>(&mut wtxn); + /// ``` + pub fn build( + &mut self, + wtxn: &mut RwTxn, + ) -> VectorCoreResult<()> { + self.writer.build::(wtxn, self.rng, &self.inner) + } +} + +/// A writer to store new items, remove existing ones, and build the search +/// index to query the nearest neighbors to items or vectors. +#[derive(Debug)] +pub struct Writer { + database: CoreDatabase, + index: u16, + dimensions: usize, + /// The folder in which tempfile will write its temporary files. + tmpdir: Option, +} + +impl Writer { + /// Creates a new writer from a database, index and dimensions. + pub fn new(database: CoreDatabase, index: u16, dimensions: usize) -> Writer { + Writer { + database, + index, + dimensions, + tmpdir: None, + } + } + + /// Sets the path to the temporary directory where files are written. + pub fn set_tmpdir(&mut self, path: impl Into) { + self.tmpdir = Some(path.into()); + } + + /// Returns `true` if the index is empty. + pub fn is_empty(&self, rtxn: &RoTxn, arena: &bumpalo::Bump) -> VectorCoreResult { + self.iter(rtxn, arena).map(|mut iter| iter.next().is_none()) + } + + /// Returns `true` if the index needs to be built before being able to read in it. + pub fn need_build(&self, rtxn: &RoTxn) -> VectorCoreResult { + Ok(self + .database + .remap_types::() + .prefix_iter(rtxn, &Prefix::updated(self.index))? + .remap_key_type::() + .next() + .is_some() + || self + .database + .remap_data_type::() + .get(rtxn, &Key::metadata(self.index))? + .is_none()) + } + + /// Returns `true` if the database contains the given item. + pub fn contains_item(&self, rtxn: &RoTxn, item: ItemId) -> VectorCoreResult { + self.database + .remap_data_type::() + .get(rtxn, &Key::item(self.index, item)) + .map(|opt| opt.is_some()) + .map_err(Into::into) + } + + /// Returns an iterator over the items vector. + pub fn iter<'t>( + &self, + rtxn: &'t RoTxn, + arena: &'t bumpalo::Bump, + ) -> VectorCoreResult> { + Ok(ItemIter::new( + self.database, + self.index, + self.dimensions, + rtxn, + arena, + )?) + } + + /// Add an item associated to a vector in the database. + pub fn add_item(&self, wtxn: &mut RwTxn, item: ItemId, vector: &[f32]) -> VectorCoreResult<()> { + if vector.len() != self.dimensions { + return Err(VectorError::InvalidVecDimension { + expected: self.dimensions, + received: vector.len(), + }); + } + + let vector = UnalignedVector::from_slice(vector); + let db_item = Item { + header: D::new_header(&vector), + vector, + }; + self.database + .put(wtxn, &Key::item(self.index, item), &Node::Item(db_item))?; + self.database + .remap_data_type::() + .put(wtxn, &Key::updated(self.index, item), &())?; + + Ok(()) + } + + /// Deletes an item stored in this database and returns `true` if it existed. + pub fn del_item(&self, wtxn: &mut RwTxn, item: ItemId) -> VectorCoreResult { + if self.database.delete(wtxn, &Key::item(self.index, item))? 
{ + self.database.remap_data_type::().put( + wtxn, + &Key::updated(self.index, item), + &(), + )?; + + Ok(true) + } else { + Ok(false) + } + } + + /// Removes everything in the database, user items and internal graph links. + pub fn clear(&self, wtxn: &mut RwTxn) -> VectorCoreResult<()> { + let mut cursor = self + .database + .remap_key_type::() + .prefix_iter_mut(wtxn, &Prefix::all(self.index))? + .remap_types::(); + + while let Some((_id, _node)) = cursor.next().transpose()? { + // SAFETY: Safe because we don't keep any references to the entry + unsafe { cursor.del_current() }?; + } + + Ok(()) + } + + pub fn builder<'a, R>(&'a self, rng: &'a mut R) -> VectorBuilder<'a, D, R> + where + R: Rng + SeedableRng, + { + VectorBuilder { + writer: self, + rng, + inner: BuildOption::default(), + } + } + + fn build( + &self, + wtxn: &mut RwTxn, + rng: &mut R, + options: &BuildOption, + ) -> VectorCoreResult<()> + where + R: Rng + SeedableRng, + { + let item_indices = self.item_indices(wtxn)?; + // updated items can be an update, an addition or a removed item + let updated_items = self.reset_and_retrieve_updated_items(wtxn)?; + + let to_delete = updated_items.clone() - &item_indices; + let to_insert = &item_indices & &updated_items; + + let metadata = self + .database + .remap_data_type::() + .get(wtxn, &Key::metadata(self.index))?; + + let (entry_points, max_level) = metadata.as_ref().map_or_else( + || (Vec::new(), usize::MIN), + |metadata| { + ( + metadata.entry_points.iter().collect(), + metadata.max_level as usize, + ) + }, + ); + + // we should not keep a reference to the metadata since they're going to be moved by LMDB + drop(metadata); + + let mut hnsw = HnswBuilder::::new(options) + .with_entry_points(entry_points) + .with_max_level(max_level); + + let _ = hnsw.build(to_insert, &to_delete, self.database, self.index, wtxn, rng)?; + + // Remove deleted links from lmdb AFTER build; in DiskANN we use a deleted item's + // neighbours when filling in the "gaps" left in the graph from deletions. See + // [`HnswBuilder::maybe_patch_old_links`] for more details. + self.delete_links_from_db(to_delete, wtxn)?; + + let metadata = Metadata { + dimensions: self.dimensions.try_into().unwrap(), + items: item_indices, + entry_points: ItemIds::from_slice(&hnsw.entry_points), + max_level: hnsw.max_level as u8, + distance: D::name(), + }; + self.database.remap_data_type::().put( + wtxn, + &Key::metadata(self.index), + &metadata, + )?; + self.database.remap_data_type::().put( + wtxn, + &Key::version(self.index), + &Version::current(), + )?; + + Ok(()) + } + + fn reset_and_retrieve_updated_items( + &self, + wtxn: &mut RwTxn, + ) -> VectorCoreResult { + let mut updated_items = RoaringBitmap::new(); + let mut updated_iter = self + .database + .remap_types::() + .prefix_iter_mut(wtxn, &Prefix::updated(self.index))? + .remap_key_type::(); + + while let Some((key, _)) = updated_iter.next().transpose()? { + let inserted = updated_items.insert(key.node.item); + debug_assert!(inserted, "The keys should be sorted by LMDB"); + // SAFETY: Safe because we don't hold any reference to the database currently + unsafe { updated_iter.del_current()? }; + } + Ok(updated_items) + } + + // Fetches the item's ids, not the links. + fn item_indices(&self, wtxn: &mut RwTxn) -> VectorCoreResult { + let mut indices = RoaringBitmap::new(); + for (_, result) in self + .database + .remap_types::() + .prefix_iter(wtxn, &Prefix::item(self.index))? 
+ .remap_key_type::<KeyCodec>() + .enumerate() + { + let (i, _) = result?; + indices.insert(i.node.unwrap_item()); + } + + Ok(indices) + } + + // Iterates over links in lmdb and deletes those in `to_delete`. There can be several links + // with the same NodeId.item, each differing by their layer. + fn delete_links_from_db( + &self, + to_delete: RoaringBitmap, + wtxn: &mut RwTxn, + ) -> VectorCoreResult<()> { + let mut cursor = self + .database + .remap_key_type::<PrefixCodec>() + .prefix_iter_mut(wtxn, &Prefix::links(self.index))? + .remap_types::<KeyCodec, DecodeIgnore>(); + + while let Some((key, _)) = cursor.next().transpose()? { + if to_delete.contains(key.node.item) { + // SAFETY: Safe because we don't keep any references to the entry + unsafe { cursor.del_current() }?; + } + } + + Ok(()) + } +} + +#[derive(Clone)] +pub(crate) struct FrozenReader<'a, D: Distance> { + pub index: u16, + pub items: &'a ImmutableItems<'a, D>, + pub links: &'a ImmutableLinks<'a, D>, +} + +impl<'a, D: Distance> FrozenReader<'a, D> { + pub fn get_item(&self, item_id: ItemId) -> VectorCoreResult<Item<'a, D>> { + let key = Key::item(self.index, item_id); + // key is a `Key::item`, so the returned result must be a `Node::Item` + self.items + .get(item_id)? + .ok_or(VectorError::missing_key(key)) + } + + pub fn get_links(&self, item_id: ItemId, level: usize) -> VectorCoreResult<Links<'a>> { + let key = Key::links(self.index, item_id, level as u8); + // key is a `Key::links`, so the returned result must be a `Node::Links` + self.links + .get(item_id, level as u8)? + .ok_or(VectorError::missing_key(key)) + } +} + +/// Clears all the links. Starts from the last node and stops at the first item. +fn clear_links<D: Distance>( + wtxn: &mut RwTxn, + database: CoreDatabase<D>, + index: u16, +) -> VectorCoreResult<()> { + let mut cursor = database + .remap_types::<PrefixCodec, DecodeIgnore>() + .prefix_iter_mut(wtxn, &Prefix::links(index))? + .remap_key_type::<KeyCodec>(); + + while let Some((_id, _node)) = cursor.next().transpose()? { + // SAFETY: Safe because we don't keep any references to the entry + unsafe { cursor.del_current()?
}; + } + + Ok(()) +} diff --git a/helix-db/src/helix_gateway/mcp/mcp.rs b/helix-db/src/helix_gateway/mcp/mcp.rs index 239a8c18a..269495d8c 100644 --- a/helix-db/src/helix_gateway/mcp/mcp.rs +++ b/helix-db/src/helix_gateway/mcp/mcp.rs @@ -964,13 +964,13 @@ pub fn search_vector_text(input: &mut MCPToolInput) -> Result bool, _>( + .search_v:: bool, _>( query_vec_arena, k_value, label_arena, - None + None, ) - .collect::,_>>()?; + .collect::, _>>()?; tracing::debug!("[VECTOR_SEARCH] Search returned {} results", results.len()); diff --git a/helix-db/src/protocol/custom_serde/compatibility_tests.rs b/helix-db/src/protocol/custom_serde/compatibility_tests.rs index 8944628dd..27fcd2557 100644 --- a/helix-db/src/protocol/custom_serde/compatibility_tests.rs +++ b/helix-db/src/protocol/custom_serde/compatibility_tests.rs @@ -10,7 +10,7 @@ #[cfg(test)] mod compatibility_tests { use super::super::test_utils::*; - use crate::helix_engine::vector_core::vector::HVector; + use crate::helix_engine::vector_core::HVector; use crate::protocol::value::Value; use crate::utils::items::{Edge, Node}; use bumpalo::Bump; @@ -115,7 +115,10 @@ mod compatibility_tests { let props: Vec<(&str, Value)> = (0..50) .map(|i| { - (Box::leak(format!("key_{}", i).into_boxed_str()) as &str, Value::I32(i)) + ( + Box::leak(format!("key_{}", i).into_boxed_str()) as &str, + Value::I32(i), + ) }) .collect(); @@ -184,22 +187,23 @@ mod compatibility_tests { fn test_old_edge_with_nested_values() { let id = 77777u128; - let props = vec![ - ( - "metadata", - Value::Object( - vec![ - ("created".to_string(), Value::I64(1234567890)), - ("tags".to_string(), Value::Array(vec![ + let props = vec![( + "metadata", + Value::Object( + vec![ + ("created".to_string(), Value::I64(1234567890)), + ( + "tags".to_string(), + Value::Array(vec![ Value::String("tag1".to_string()), Value::String("tag2".to_string()), - ])), - ] - .into_iter() - .collect(), - ), + ]), + ), + ] + .into_iter() + .collect(), ), - ]; + )]; let old_edge = create_old_edge(id, "NestedEdge", 0, 10, 20, props); let old_bytes = bincode::serialize(&old_edge).unwrap(); @@ -245,12 +249,8 @@ mod compatibility_tests { let data_bytes = create_vector_bytes(&data); let arena = Bump::new(); - let new_vector = HVector::from_bincode_bytes( - &arena, - Some(&old_bytes), - &data_bytes, - id, - ); + let new_vector = + HVector::from_bincode_bytes(&arena, Some(&old_bytes), &data_bytes, id, true); assert!(new_vector.is_ok(), "Should deserialize old vector format"); let restored = new_vector.unwrap(); @@ -271,12 +271,8 @@ mod compatibility_tests { let data_bytes = create_vector_bytes(&[0.0]); let arena = Bump::new(); - let new_vector = HVector::from_bincode_bytes( - &arena, - Some(&old_bytes), - &data_bytes, - id, - ).unwrap(); + let new_vector = + HVector::from_bincode_bytes(&arena, Some(&old_bytes), &data_bytes, id, true).unwrap(); assert_eq!(new_vector.deleted, true); } @@ -296,12 +292,8 @@ mod compatibility_tests { let data_bytes = create_vector_bytes(&vec![0.0; 1536]); let arena = Bump::new(); - let new_vector = HVector::from_bincode_bytes( - &arena, - Some(&old_bytes), - &data_bytes, - id, - ).unwrap(); + let new_vector = + HVector::from_bincode_bytes(&arena, Some(&old_bytes), &data_bytes, id, true).unwrap(); assert!(new_vector.properties.is_some()); let props = new_vector.properties.unwrap(); @@ -371,18 +363,10 @@ mod compatibility_tests { let data_bytes = create_vector_bytes(&data); let arena2 = Bump::new(); - let restored_v1 = HVector::from_bincode_bytes( - &arena2, - Some(&props_v1), - 
&data_bytes, - id, - ).unwrap(); - let restored_v2 = HVector::from_bincode_bytes( - &arena2, - Some(&props_v2), - &data_bytes, - id, - ).unwrap(); + let restored_v1 = + HVector::from_bincode_bytes(&arena2, Some(&props_v1), &data_bytes, id, true).unwrap(); + let restored_v2 = + HVector::from_bincode_bytes(&arena2, Some(&props_v2), &data_bytes, id, true).unwrap(); assert_eq!(restored_v1.version, 1); assert_eq!(restored_v2.version, 2); diff --git a/helix-db/src/protocol/custom_serde/edge_case_tests.rs b/helix-db/src/protocol/custom_serde/edge_case_tests.rs index 85463e6d2..5428c6b9e 100644 --- a/helix-db/src/protocol/custom_serde/edge_case_tests.rs +++ b/helix-db/src/protocol/custom_serde/edge_case_tests.rs @@ -13,7 +13,7 @@ #[cfg(test)] mod edge_case_tests { use super::super::test_utils::*; - use crate::helix_engine::vector_core::vector::HVector; + use crate::helix_engine::vector_core::HVector; use crate::protocol::value::Value; use crate::utils::items::{Edge, Node}; use bumpalo::Bump; @@ -84,7 +84,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); assert_eq!(result.unwrap().properties.unwrap().len(), 500); } @@ -156,7 +156,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); assert!(result.unwrap().label.len() > 2000); } @@ -245,7 +245,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -354,7 +354,10 @@ mod edge_case_tests { Value::I32(1), Value::Object({ let mut inner = HashMap::new(); - inner.insert("inner_key".to_string(), Value::String("inner_value".to_string())); + inner.insert( + "inner_key".to_string(), + Value::String("inner_value".to_string()), + ); inner }), ]), @@ -366,7 +369,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -451,7 +454,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -516,19 +519,14 @@ mod edge_case_tests { let id = 800800u128; // Subnormal (denormalized) numbers - let data = vec![ - f64::MIN_POSITIVE, - f64::MIN_POSITIVE / 2.0, - 1e-308, - 1e-320, - ]; + let data = vec![f64::MIN_POSITIVE, f64::MIN_POSITIVE / 2.0, 1e-308, 1e-320]; let vector = create_simple_vector(&arena, id, "subnormal", &data); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); 
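        // Editorial note: every call updated below (and throughout this patch)
        // passes a new fifth argument to `HVector::from_bincode_bytes`, always
        // `true` in the tests. The parameter's declaration is not part of this
        // diff, so its name is an assumption; given the removal of
        // `VectorWithoutData` elsewhere in the series, it plausibly toggles
        // whether the raw vector data is decoded alongside the properties:
        //
        //     // hypothetical signature, inferred from call sites only:
        //     // fn from_bincode_bytes<'arena>(
        //     //     arena: &'arena Bump,
        //     //     props_bytes: Option<&[u8]>,
        //     //     data_bytes: &[u8],
        //     //     id: u128,
        //     //     with_data: bool,   // assumed name, not shown in this diff
        //     // ) -> Result<HVector<'arena>, VectorError>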
- let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -543,7 +541,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -578,9 +576,7 @@ mod edge_case_tests { let long_key = "property_key_".repeat(100); // ~1.3KB key let key_ref: &str = arena.alloc_str(&long_key); - let props = vec![ - (key_ref, Value::String("value".to_string())), - ]; + let props = vec![(key_ref, Value::String("value".to_string()))]; let edge = create_arena_edge(&arena, id, "test", 0, 1, 2, props); let bytes = bincode::serialize(&edge).unwrap(); @@ -607,7 +603,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -620,9 +616,7 @@ mod edge_case_tests { let arena = Bump::new(); let id = 404404u128; - let large_array = Value::Array( - (0..1000).map(|i| Value::I32(i)).collect() - ); + let large_array = Value::Array((0..1000).map(|i| Value::I32(i)).collect()); let props = vec![("big_array", large_array)]; let node = create_arena_node(&arena, id, "test", 0, props); @@ -641,7 +635,7 @@ mod edge_case_tests { let string_array = Value::Array( (0..100) .map(|i| Value::String(format!("string_{}", i))) - .collect() + .collect(), ); let props = vec![("strings", string_array)]; @@ -674,7 +668,7 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } @@ -693,9 +687,9 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); - assert_eq!(result.unwrap().data.len(), 8192); + assert_eq!(result.unwrap().len(), 8192); } #[test] @@ -709,10 +703,10 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); let deserialized = result.unwrap(); - assert!(deserialized.data.iter().all(|&v| v == 0.0)); + assert!(deserialized.data(&arena).iter().all(|&v| v == 0.0)); } #[test] @@ -726,10 +720,15 @@ mod edge_case_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); let deserialized = result.unwrap(); - assert!(deserialized.data.iter().all(|&v| (v - 
42.42).abs() < 1e-10)); + assert!( + deserialized + .data(&arena) + .iter() + .all(|&v| (v - 42.42).abs() < 1e-10) + ); } // ======================================================================== @@ -797,12 +796,13 @@ mod edge_case_tests { }) .collect(); - let vector = create_arena_vector(&arena, id, &"Vec".repeat(200), 255, true, 0, &data, props); + let vector = + create_arena_vector(&arena, id, &"Vec".repeat(200), 255, true, 0, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); } } diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs index fda9616ef..e1bc3b411 100644 --- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs +++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs @@ -13,7 +13,7 @@ #[cfg(test)] mod error_handling_tests { use super::super::test_utils::*; - use crate::helix_engine::vector_core::vector::HVector; + use crate::helix_engine::vector_core::HVector; use crate::protocol::value::Value; use crate::utils::items::{Edge, Node}; use bumpalo::Bump; @@ -204,7 +204,7 @@ mod error_handling_tests { let valid_data = vec![1.0, 2.0, 3.0]; let data_bytes = create_vector_bytes(&valid_data); - let result = HVector::from_bincode_bytes(&arena, Some(empty_bytes), &data_bytes, id); + let result = HVector::from_bincode_bytes(&arena, Some(empty_bytes), &data_bytes, id, true); assert!(result.is_err(), "Should fail on empty property bytes"); } @@ -218,7 +218,8 @@ mod error_handling_tests { let empty_data: &[u8] = &[]; let arena2 = Bump::new(); - let _result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), empty_data, id); + let _result = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), empty_data, id, true); // Should panic due to assertion in cast_raw_vector_data } @@ -242,7 +243,8 @@ mod error_handling_tests { let truncated_props = &props_bytes[..props_bytes.len() / 2]; let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(truncated_props), data_bytes, id); + let result = + HVector::from_bincode_bytes(&arena2, Some(truncated_props), data_bytes, id, true); assert!(result.is_err(), "Should fail on truncated properties"); } @@ -253,7 +255,7 @@ mod error_handling_tests { let garbage: Vec = vec![0xFF; 50]; let data_bytes = create_vector_bytes(&[1.0, 2.0, 3.0]); - let result = HVector::from_bincode_bytes(&arena, Some(&garbage), &data_bytes, id); + let result = HVector::from_bincode_bytes(&arena, Some(&garbage), &data_bytes, id, true); assert!(result.is_err(), "Should fail on garbage property bytes"); } @@ -384,7 +386,8 @@ mod error_handling_tests { } let arena2 = Bump::new(); - let _result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let _result = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); // Should fail on UTF-8 validation } @@ -472,7 +475,7 @@ mod error_handling_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok(), "Should handle 
u8::MAX version"); assert_eq!(result.unwrap().version, 255); } @@ -519,7 +522,7 @@ mod error_handling_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok(), "Should handle u128::MAX ID"); assert_eq!(result.unwrap().id, u128::MAX); } @@ -589,7 +592,7 @@ mod error_handling_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!( result.is_ok(), "Should handle special float values in vector data" diff --git a/helix-db/src/protocol/custom_serde/integration_tests.rs b/helix-db/src/protocol/custom_serde/integration_tests.rs index 8f84fae04..db6dd19af 100644 --- a/helix-db/src/protocol/custom_serde/integration_tests.rs +++ b/helix-db/src/protocol/custom_serde/integration_tests.rs @@ -11,7 +11,7 @@ #[cfg(test)] mod integration_tests { use super::super::test_utils::*; - use crate::helix_engine::vector_core::vector::HVector; + use crate::helix_engine::vector_core::HVector; use crate::protocol::value::Value; use crate::utils::items::{Edge, Node}; use bincode::Options; @@ -203,9 +203,7 @@ mod integration_tests { let arena = Bump::new(); let edges: Vec = (0..20) - .map(|i| { - create_simple_edge(&arena, i as u128, "LINK", i as u128, (i + 1) as u128) - }) + .map(|i| create_simple_edge(&arena, i as u128, "LINK", i as u128, (i + 1) as u128)) .collect(); let serialized: Vec> = edges @@ -251,12 +249,8 @@ mod integration_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_vectors_semantically_equal(&vector, &deserialized); } @@ -277,12 +271,8 @@ mod integration_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_vectors_semantically_equal(&vector, &deserialized); } @@ -314,23 +304,17 @@ mod integration_tests { let props_bytes1 = bincode::serialize(&vector).unwrap(); let data_bytes1 = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let vector2 = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes1), - data_bytes1, - id, - ).unwrap(); + let vector2 = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes1), data_bytes1, id, true) + .unwrap(); // Second roundtrip let props_bytes2 = bincode::serialize(&vector2).unwrap(); let data_bytes2 = vector2.vector_data_to_bytes().unwrap(); let arena3 = Bump::new(); - let vector3 = HVector::from_bincode_bytes( - &arena3, - Some(&props_bytes2), - data_bytes2, - id, - ).unwrap(); + let vector3 = + HVector::from_bincode_bytes(&arena3, Some(&props_bytes2), data_bytes2, id, true) + .unwrap(); assert_vectors_semantically_equal(&vector, &vector2); assert_vectors_semantically_equal(&vector2, &vector3); @@ -367,6 +351,7 @@ mod integration_tests { 
Some(props_bytes), data_bytes, i as u128, + true, ); assert!(result.is_ok()); } @@ -428,6 +413,7 @@ mod integration_tests { Some(&vector_props_bytes), vector_data_bytes, 3, + true, ); assert!(node_restored.is_ok()); @@ -483,9 +469,7 @@ mod integration_tests { let restored: Vec = serialized .iter() .enumerate() - .map(|(i, bytes)| { - Node::from_bincode_bytes(i as u128, bytes, &shared_arena).unwrap() - }) + .map(|(i, bytes)| Node::from_bincode_bytes(i as u128, bytes, &shared_arena).unwrap()) .collect(); assert_eq!(restored.len(), 100); @@ -613,7 +597,11 @@ mod integration_tests { let bytes = bincode::serialize(&node).unwrap(); // Should be relatively small (label + version + empty props indicator) - assert!(bytes.len() < 100, "Empty node should be small, got {} bytes", bytes.len()); + assert!( + bytes.len() < 100, + "Empty node should be small, got {} bytes", + bytes.len() + ); } #[test] diff --git a/helix-db/src/protocol/custom_serde/property_based_tests.rs b/helix-db/src/protocol/custom_serde/property_based_tests.rs index 81f5c7d65..cbb1bbbc1 100644 --- a/helix-db/src/protocol/custom_serde/property_based_tests.rs +++ b/helix-db/src/protocol/custom_serde/property_based_tests.rs @@ -9,7 +9,7 @@ #[cfg(test)] mod property_based_tests { use super::super::test_utils::*; - use crate::helix_engine::vector_core::vector::HVector; + use crate::helix_engine::vector_core::HVector; use crate::protocol::value::Value; use crate::utils::items::{Edge, Node}; use bumpalo::Bump; @@ -36,7 +36,9 @@ mod property_based_tests { any::().prop_map(Value::I64), any::().prop_map(Value::U32), any::().prop_map(Value::U64), - any::().prop_filter("Not NaN", |f| !f.is_nan()).prop_map(Value::F64), + any::() + .prop_filter("Not NaN", |f| !f.is_nan()) + .prop_map(Value::F64), any::().prop_map(Value::Boolean), arb_long_string().prop_map(Value::String), Just(Value::Empty), @@ -288,14 +290,15 @@ mod property_based_tests { Some(&props_bytes), data_bytes, id, + true, ).unwrap(); prop_assert_eq!(deserialized.label, label.as_str()); prop_assert_eq!(deserialized.id, id); - prop_assert_eq!(deserialized.data.len(), data.len()); + prop_assert_eq!(deserialized.len(), data.len()); // Check each data point (with floating point tolerance) - for (i, (&orig, &deser)) in data.iter().zip(deserialized.data.iter()).enumerate() { + for (i, (&orig, &deser)) in data.iter().zip(deserialized.data(&arena).iter()).enumerate() { let diff = (orig - deser).abs(); prop_assert!(diff < 1e-10, "Data mismatch at index {}: {} vs {}", i, orig, deser); } @@ -326,10 +329,11 @@ mod property_based_tests { Some(&props_bytes), data_bytes, id, + true, ).unwrap(); prop_assert_eq!(deserialized.deleted, deleted); - prop_assert_eq!(deserialized.data.len(), data.len()); + prop_assert_eq!(deserialized.len(), data.len()); } #[test] @@ -387,6 +391,7 @@ mod property_based_tests { Some(&props_bytes1), data_bytes1, id, + true, ).unwrap(); // Second roundtrip @@ -436,6 +441,7 @@ mod property_based_tests { Some(&props_bytes), data_bytes, id, + true, ).unwrap(); prop_assert_eq!(vector_restored.id, id); } @@ -472,6 +478,7 @@ mod property_based_tests { Some(&props_bytes), data_bytes, id, + true, ).unwrap(); prop_assert_eq!(vector_restored.label, label.as_str()); } diff --git a/helix-db/src/protocol/custom_serde/test_utils.rs b/helix-db/src/protocol/custom_serde/test_utils.rs index 4197744ad..5da31ad64 100644 --- a/helix-db/src/protocol/custom_serde/test_utils.rs +++ b/helix-db/src/protocol/custom_serde/test_utils.rs @@ -5,7 +5,9 @@ #![cfg(test)] -use 
crate::helix_engine::vector_core::vector::HVector; +use crate::helix_engine::vector_core::HVector; +use crate::helix_engine::vector_core::distance::Cosine; +use crate::helix_engine::vector_core::node::Item; use crate::protocol::value::Value; use crate::utils::items::{Edge, Node}; use crate::utils::properties::ImmutablePropertiesMap; @@ -101,12 +103,7 @@ pub fn create_simple_node<'arena>(arena: &'arena Bump, id: u128, label: &str) -> } /// Creates an old-style Node for compatibility testing -pub fn create_old_node( - id: u128, - label: &str, - version: u8, - props: Vec<(&str, Value)>, -) -> OldNode { +pub fn create_old_node(id: u128, label: &str, version: u8, props: Vec<(&str, Value)>) -> OldNode { if props.is_empty() { OldNode { id, @@ -245,7 +242,7 @@ pub fn create_arena_vector<'arena>( deleted, level, distance: None, - data: data_ref, + data: Some(Item::::from(data_ref, arena)), properties: None, } } else { @@ -263,7 +260,7 @@ pub fn create_arena_vector<'arena>( deleted, level, distance: None, - data: data_ref, + data: Some(Item::::from(data_ref, arena)), properties: Some(props_map), } } @@ -334,7 +331,10 @@ pub fn all_value_types_props() -> Vec<(&'static str, Value)> { ("u16_val", Value::U16(65000)), ("u32_val", Value::U32(4000000)), ("u64_val", Value::U64(18000000000)), - ("u128_val", Value::U128(340282366920938463463374607431768211455)), + ( + "u128_val", + Value::U128(340282366920938463463374607431768211455), + ), ("bool_val", Value::Boolean(true)), ("empty_val", Value::Empty), ] @@ -345,17 +345,16 @@ pub fn nested_value_props() -> Vec<(&'static str, Value)> { vec![ ( "array_val", - Value::Array(vec![ - Value::I32(1), - Value::I32(2), - Value::I32(3), - ]), + Value::Array(vec![Value::I32(1), Value::I32(2), Value::I32(3)]), ), ( "object_val", Value::Object( vec![ - ("nested_key".to_string(), Value::String("nested_value".to_string())), + ( + "nested_key".to_string(), + Value::String("nested_value".to_string()), + ), ("nested_num".to_string(), Value::I32(42)), ] .into_iter() @@ -364,15 +363,14 @@ pub fn nested_value_props() -> Vec<(&'static str, Value)> { ), ( "deeply_nested", - Value::Array(vec![ - Value::Object( - vec![ - ("inner".to_string(), Value::Array(vec![Value::I32(1), Value::I32(2)])), - ] - .into_iter() - .collect(), - ), - ]), + Value::Array(vec![Value::Object( + vec![( + "inner".to_string(), + Value::Array(vec![Value::I32(1), Value::I32(2)]), + )] + .into_iter() + .collect(), + )]), ), ] } @@ -398,11 +396,7 @@ pub fn assert_nodes_semantically_equal(node1: &Node, node2: &Node) { match (&node1.properties, &node2.properties) { (None, None) => {} (Some(props1), Some(props2)) => { - assert_eq!( - props1.len(), - props2.len(), - "Node property counts differ" - ); + assert_eq!(props1.len(), props2.len(), "Node property counts differ"); // Check each property exists and has the same value for (key1, val1) in props1.iter() { if let Some(val2) = props2.get(key1) { @@ -427,11 +421,7 @@ pub fn assert_edges_semantically_equal(edge1: &Edge, edge2: &Edge) { match (&edge1.properties, &edge2.properties) { (None, None) => {} (Some(props1), Some(props2)) => { - assert_eq!( - props1.len(), - props2.len(), - "Edge property counts differ" - ); + assert_eq!(props1.len(), props2.len(), "Edge property counts differ"); for (key1, val1) in props1.iter() { if let Some(val2) = props2.get(key1) { assert_eq!(val1, val2, "Property value differs for key: {}", key1); @@ -450,10 +440,15 @@ pub fn assert_vectors_semantically_equal(vec1: &HVector, vec2: &HVector) { assert_eq!(vec1.label, vec2.label, "Vector 
labels differ"); assert_eq!(vec1.version, vec2.version, "Vector versions differ"); assert_eq!(vec1.deleted, vec2.deleted, "Vector deleted flags differ"); - assert_eq!(vec1.data.len(), vec2.data.len(), "Vector dimensions differ"); + assert_eq!(vec1.len(), vec2.len(), "Vector dimensions differ"); // Compare vector data with floating point tolerance - for (i, (v1, v2)) in vec1.data.iter().zip(vec2.data.iter()).enumerate() { + for (i, (v1, v2)) in vec1 + .data_borrowed() + .iter() + .zip(vec2.data_borrowed().iter()) + .enumerate() + { assert!( (v1 - v2).abs() < 1e-10, "Vector data differs at index {}: {} vs {}", @@ -466,11 +461,7 @@ pub fn assert_vectors_semantically_equal(vec1: &HVector, vec2: &HVector) { match (&vec1.properties, &vec2.properties) { (None, None) => {} (Some(props1), Some(props2)) => { - assert_eq!( - props1.len(), - props2.len(), - "Vector property counts differ" - ); + assert_eq!(props1.len(), props2.len(), "Vector property counts differ"); for (key1, val1) in props1.iter() { if let Some(val2) = props2.get(key1) { assert_eq!(val1, val2, "Property value differs for key: {}", key1); @@ -502,7 +493,10 @@ pub fn print_byte_comparison(label: &str, bytes1: &[u8], bytes2: &[u8]) { let min_len = bytes1.len().min(bytes2.len()); for (i, (b1, b2)) in bytes1.iter().zip(bytes2.iter()).take(min_len).enumerate() { if b1 != b2 { - println!(" Index {}: bytes1={:02x} ({}), bytes2={:02x} ({})", i, b1, b1, b2, b2); + println!( + " Index {}: bytes1={:02x} ({}), bytes2={:02x} ({})", + i, b1, b1, b2, b2 + ); } } @@ -561,7 +555,9 @@ pub fn random_utf8_string(len: usize) -> String { pub fn random_f64_vector(dimensions: usize) -> Vec { use rand::Rng; let mut rng = rand::rng(); - (0..dimensions).map(|_| rng.random_range(-1.0..1.0)).collect() + (0..dimensions) + .map(|_| rng.random_range(-1.0..1.0)) + .collect() } /// Generates a random Value for property testing diff --git a/helix-db/src/protocol/custom_serde/vector_serde.rs b/helix-db/src/protocol/custom_serde/vector_serde.rs index 000a9e893..0ab2ef6a4 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde.rs @@ -1,5 +1,5 @@ use crate::{ - helix_engine::vector_core::{vector::HVector, vector_without_data::VectorWithoutData}, + helix_engine::vector_core::{HVector, distance::Cosine, node::Item}, utils::properties::{ImmutablePropertiesMap, ImmutablePropertiesMapDeSeed}, }; use serde::de::{DeserializeSeed, Visitor}; @@ -103,7 +103,7 @@ impl<'de, 'txn, 'arena> serde::de::DeserializeSeed<'de> for VectorDeSeed<'txn, ' version, level: 0, distance: None, - data, + data: Some(Item::::from(data, &self.arena)), properties, }) } @@ -128,7 +128,7 @@ pub struct VectoWithoutDataDeSeed<'arena> { } impl<'de, 'arena> serde::de::DeserializeSeed<'de> for VectoWithoutDataDeSeed<'arena> { - type Value = VectorWithoutData<'arena>; + type Value = HVector<'arena>; fn deserialize(self, deserializer: D) -> Result where @@ -140,7 +140,7 @@ impl<'de, 'arena> serde::de::DeserializeSeed<'de> for VectoWithoutDataDeSeed<'ar } impl<'de, 'arena> serde::de::Visitor<'de> for VectorVisitor<'arena> { - type Value = VectorWithoutData<'arena>; + type Value = HVector<'arena>; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("struct VectorWithoutData") @@ -164,13 +164,15 @@ impl<'de, 'arena> serde::de::DeserializeSeed<'de> for VectoWithoutDataDeSeed<'ar .next_element_seed(OptionPropertiesMapDeSeed { arena: self.arena })? 
.ok_or_else(|| serde::de::Error::custom("Expected properties field"))?; - Ok(VectorWithoutData { + Ok(HVector { id: self.id, label, version, deleted, level: 0, properties, + distance: None, + data: None, }) } } diff --git a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs index a1c6c0dcc..8b35f1496 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs @@ -13,10 +13,9 @@ #[cfg(test)] mod vector_serialization_tests { use super::super::test_utils::*; - use crate::helix_engine::vector_core::vector::HVector; - use crate::helix_engine::vector_core::vector_without_data::VectorWithoutData; + use crate::helix_engine::vector_core::HVector; use crate::protocol::value::Value; - + use bumpalo::Bump; // ======================================================================== @@ -39,12 +38,8 @@ mod vector_serialization_tests { // Deserialize let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_vectors_semantically_equal(&vector, &deserialized); } @@ -62,12 +57,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_vectors_semantically_equal(&vector, &deserialized); } @@ -89,12 +80,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_vectors_semantically_equal(&vector, &deserialized); } @@ -112,12 +99,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_vectors_semantically_equal(&vector, &deserialized); } @@ -135,17 +118,13 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes( - &arena2, - Some(&props_bytes), - data_bytes, - id, - ).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); // Just verify basic structure instead of deep equality due to HashMap ordering assert_eq!(deserialized.id, id); assert_eq!(deserialized.label, "nested"); - assert_eq!(deserialized.data.len(), 3); + assert_eq!(deserialized.data_borrowed().len(), 3); assert!(deserialized.properties.is_some()); assert_eq!(deserialized.properties.unwrap().len(), 3); } @@ -230,7 +209,7 @@ mod vector_serialization_tests { assert_eq!(vector.id, id); assert_eq!(vector.label, label); - assert_eq!(vector.data.len(), 4); + assert_eq!(vector.len(), 4); assert_eq!(vector.version, 1); assert_eq!(vector.deleted, false); 
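        // Editorial note on the accessor changes visible in these hunks: the
        // public `data` field is no longer read directly. The replacement
        // surface (names taken from the call sites in this patch) is `len()`
        // for the dimension count, `data_borrowed()` for a zero-copy view, and
        // `data(&arena)` for an arena-materialized slice. A test written
        // against the new surface therefore reads:
        //
        //     assert_eq!(vector.len(), 4);
        //     assert!(vector.data_borrowed().iter().all(|v| v.is_finite()));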
assert_eq!(vector.level, 0); @@ -255,7 +234,7 @@ mod vector_serialization_tests { // Deserialize combining both let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); let deserialized = result.unwrap(); @@ -278,7 +257,7 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); let deserialized = result.unwrap(); @@ -297,77 +276,13 @@ mod vector_serialization_tests { // Deserialize with serialized empty properties let arena2 = Bump::new(); - let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id); + let result = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); let deserialized = result.unwrap(); assert_eq!(deserialized.id, id); assert_eq!(deserialized.label, "no_props"); - assert_eq!(deserialized.data.len(), 3); - assert!(deserialized.properties.is_none()); - } - - // ======================================================================== - // VECTOR WITHOUT DATA TESTS - // ======================================================================== - - #[test] - fn test_vector_without_data_serialization() { - let arena = Bump::new(); - let id = 999000u128; - let label = arena.alloc_str("metadata_only"); - let props = vec![("type", Value::String("embedding".to_string()))]; - let len = props.len(); - let props_iter = props.into_iter().map(|(k, v)| { - let key: &str = arena.alloc_str(k); - (key, v) - }); - let props_map = crate::utils::properties::ImmutablePropertiesMap::new(len, props_iter, &arena); - - let vector_without_data = VectorWithoutData { - id, - label, - version: 1, - deleted: false, - level: 0, - properties: Some(props_map), - }; - - // Serialize and deserialize - let bytes = bincode::serialize(&vector_without_data).unwrap(); - let arena2 = Bump::new(); - let result = VectorWithoutData::from_bincode_bytes(&arena2, &bytes, id); - println!("{:?}", result); - assert!(result.is_ok()); - let deserialized = result.unwrap(); - assert_eq!(deserialized.id, id); - assert_eq!(deserialized.label, label); - assert_eq!(deserialized.version, 1); - assert_eq!(deserialized.deleted, false); - } - - #[test] - fn test_vector_without_data_empty_properties() { - let arena = Bump::new(); - let id = 111000u128; - let label = arena.alloc_str("empty_meta"); - - let vector_without_data = VectorWithoutData { - id, - label, - version: 1, - deleted: false, - level: 0, - properties: None, - }; - - let bytes = bincode::serialize(&vector_without_data).unwrap(); - let arena2 = Bump::new(); - let result = VectorWithoutData::from_bincode_bytes(&arena2, &bytes, id); - - assert!(result.is_ok()); - let deserialized = result.unwrap(); - assert_eq!(deserialized.id, id); + assert_eq!(deserialized.len(), 3); assert!(deserialized.properties.is_none()); } @@ -387,7 +302,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); 
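        // Editorial note: the block deleted above drops the dedicated
        // `VectorWithoutData` tests because the type itself is folded into
        // `HVector` (see the `vector_serde.rs` hunk earlier, whose visitor now
        // returns `HVector { .., distance: None, data: None }`). A
        // metadata-only vector under the new model is just an `HVector` with
        // no payload — a sketch using exactly the fields shown in that hunk:
        //
        //     let meta_only = HVector {
        //         id,
        //         label,
        //         version: 1,
        //         deleted: false,
        //         level: 0,
        //         distance: None,
        //         data: None,       // the old VectorWithoutData case
        //         properties: None,
        //     };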
assert_eq!(deserialized.version, 5); } @@ -404,7 +320,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.deleted, true); } @@ -421,7 +338,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.deleted, false); } @@ -442,7 +360,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.label, "向量测试"); } @@ -459,7 +378,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.label, "🚀🔥💯"); } @@ -476,7 +396,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.label, ""); } @@ -494,7 +415,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.label.len(), 1000); assert_eq!(deserialized.label, long_label); @@ -522,7 +444,8 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); assert_eq!(deserialized.properties.unwrap().len(), 50); } @@ -543,10 +466,11 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); - assert_eq!(deserialized.data.len(), 1); - assert!((deserialized.data[0] - 42.0).abs() < 1e-10); + assert_eq!(deserialized.len(), 1); + assert!((deserialized.data_borrowed()[0] - 42.0).abs() < 1e-10); } #[test] @@ -561,9 +485,10 @@ mod vector_serialization_tests { let data_bytes = vector.vector_data_to_bytes().unwrap(); let arena2 = 
Bump::new(); - let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id).unwrap(); + let deserialized = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); - assert_eq!(deserialized.data.len(), 4096); + assert_eq!(deserialized.len(), 4096); } // ======================================================================== @@ -584,7 +509,9 @@ mod vector_serialization_tests { let data_bytes1 = vector.vector_data_to_bytes().unwrap(); let arena2 = Bump::new(); - let deserialized1 = HVector::from_bincode_bytes(&arena2, Some(&props_bytes1), data_bytes1, id).unwrap(); + let deserialized1 = + HVector::from_bincode_bytes(&arena2, Some(&props_bytes1), data_bytes1, id, true) + .unwrap(); // Second roundtrip let props_bytes2 = bincode::serialize(&deserialized1).unwrap(); diff --git a/helix-db/src/utils/properties.rs b/helix-db/src/utils/properties.rs index 843d170f5..42f9800e7 100644 --- a/helix-db/src/utils/properties.rs +++ b/helix-db/src/utils/properties.rs @@ -22,7 +22,7 @@ use crate::protocol::value::Value; /// - All required space is allocated in the arena upfront /// - Key lengths are stored packed for SIMD length check on get. /// - Small n means O(n) is faster than O(1) -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct ImmutablePropertiesMap<'arena> { len: usize, key_lengths: *const usize, From aceea7f2488886a9a9c2ec2cbbd73aeac3dfb61b Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Tue, 18 Nov 2025 23:48:39 -0300 Subject: [PATCH 25/48] Define maximum edges cannot be defined at compile time (unfortunately) --- helix-db/src/helix_engine/vector_core/hnsw.rs | 54 +++++++-------- helix-db/src/helix_engine/vector_core/mod.rs | 22 ++++-- .../src/helix_engine/vector_core/writer.rs | 67 +++---------------- 3 files changed, 54 insertions(+), 89 deletions(-) diff --git a/helix-db/src/helix_engine/vector_core/hnsw.rs b/helix-db/src/helix_engine/vector_core/hnsw.rs index efa75cdb4..b9b64b0f5 100644 --- a/helix-db/src/helix_engine/vector_core/hnsw.rs +++ b/helix-db/src/helix_engine/vector_core/hnsw.rs @@ -29,11 +29,11 @@ use crate::helix_engine::vector_core::{VectorCoreResult, VectorError}; pub(crate) type ScoredLink = (OrderedFloat, ItemId); -pub struct NodeState { - links: ArrayVec<[ScoredLink; M]>, +pub struct NodeState { + links: Vec, } -impl Debug for NodeState { +impl Debug for NodeState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // from [crate::unaligned_vector] struct Number(f32); @@ -53,19 +53,21 @@ impl Debug for NodeState { } } -pub struct HnswBuilder { +pub struct HnswBuilder { assign_probas: Vec, ef_construction: usize, alpha: f32, + m: usize, + m_max_0: usize, pub max_level: usize, pub entry_points: Vec, - pub layers: Vec>>, + pub layers: Vec>, distance: PhantomData, } -impl HnswBuilder { +impl HnswBuilder { pub fn new(opts: &BuildOption) -> Self { - let assign_probas = Self::get_default_probas(); + let assign_probas = Self::get_default_probas(opts.m); Self { assign_probas, ef_construction: opts.ef_construction, @@ -74,6 +76,8 @@ impl HnswBuilder { entry_points: Vec::new(), layers: vec![], distance: PhantomData, + m: opts.m, + m_max_0: opts.m_max_0, } } @@ -96,9 +100,9 @@ impl HnswBuilder { dist.sample(rng) } - fn get_default_probas() -> Vec { - let mut assign_probas = Vec::with_capacity(M); - let level_factor = 1.0 / (M as f32 + f32::EPSILON).ln(); + fn get_default_probas(m: usize) -> Vec { + let mut assign_probas = Vec::with_capacity(m); + let level_factor = 1.0 / (m as f32 + 
f32::EPSILON).ln(); let mut level = 0; loop { // P(L HnswBuilder { let _ = map_guard.insert( id, NodeState { - links: ArrayVec::from_iter(pruned), + links: Vec::from_iter(pruned), }, ); Ok(()) @@ -379,12 +383,8 @@ impl HnswBuilder { let Some(map) = self.layers.get(level) else { break; }; - map.pin().get_or_insert( - item_id, - NodeState { - links: array_vec![], - }, - ); + map.pin() + .get_or_insert(item_id, NodeState { links: vec![] }); } } @@ -501,26 +501,24 @@ impl HnswBuilder { let map_guard = map.pin(); // 'pure' links update function - let _add_link = |node_state: &NodeState| { - let mut links = node_state.links; - let cap = if level == 0 { M0 } else { M }; + let _add_link = |node_state: &NodeState| { + let mut links = node_state.links.clone(); + let cap = if level == 0 { self.m_max_0 } else { self.m }; - if links.len() < cap { + if node_state.links.len() < cap { links.push(q); return NodeState { links }; } let new_links = self - .robust_prune(links.to_vec(), level, self.alpha, lmdb) - .map(ArrayVec::from_iter) - .unwrap_or_else(|_| node_state.links); + .robust_prune(links, level, self.alpha, lmdb) + .map(Vec::from_iter) + .unwrap_or_else(|_| node_state.links.clone()); NodeState { links: new_links } }; - map_guard.update_or_insert_with(p, _add_link, || NodeState { - links: array_vec!([ScoredLink; M0] => q), - }); + map_guard.update_or_insert_with(p, _add_link, || NodeState { links: vec![q] }); Ok(()) } @@ -535,7 +533,7 @@ impl HnswBuilder { alpha: f32, lmdb: &FrozenReader<'_, D>, ) -> VectorCoreResult> { - let cap = if level == 0 { M0 } else { M }; + let cap = if level == 0 { self.m_max_0 } else { self.m }; candidates.sort_by(|a, b| b.cmp(a)); let mut selected: Vec = Vec::with_capacity(cap); diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index fa6dc31b4..e423f3636 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -45,6 +45,9 @@ pub mod unaligned_vector; pub mod version; pub mod writer; +const DB_VECTORS: &str = "vectors"; // for vector data (v:) +const DB_VECTOR_DATA: &str = "vector_data"; // for vector's properties + pub type ItemId = u32; pub type LayerId = u8; @@ -260,17 +263,28 @@ pub struct VectorCoreStats { // TODO: Properties filters // TODO: Support different distances for each database pub struct VectorCore { - /// One HNSW index per label - hsnw_index: HashMap>, + pub hsnw_index: CoreDatabase, pub stats: VectorCoreStats, pub vector_properties_db: Database, Bytes>, + pub config: HNSWConfig, } impl VectorCore { pub fn new(env: &Env, txn: &mut RwTxn, config: HNSWConfig) -> VectorCoreResult { - todo!() + let vectors_db: CoreDatabase = env.create_database(txn, Some(DB_VECTORS))?; + let vector_properties_db = env + .database_options() + .types::, Bytes>() + .name(DB_VECTOR_DATA) + .create(txn)?; + + Ok(Self { + hsnw_index: vectors_db, + stats: VectorCoreStats { num_vectors: 0 }, + vector_properties_db, + config, + }) } - pub fn search_by_vector<'a>(&self, txn: &RoTxn, vector: &'a [f32]) {} pub fn search<'arena>( &self, diff --git a/helix-db/src/helix_engine/vector_core/writer.rs b/helix-db/src/helix_engine/vector_core/writer.rs index ad543fc92..c7bf62e45 100644 --- a/helix-db/src/helix_engine/vector_core/writer.rs +++ b/helix-db/src/helix_engine/vector_core/writer.rs @@ -30,14 +30,18 @@ pub(crate) struct BuildOption { pub(crate) ef_construction: usize, pub(crate) alpha: f32, pub(crate) available_memory: Option, + pub(crate) m: usize, + pub(crate) m_max_0: usize, 
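    // Editorial note: `m` caps the neighbour count per node on the upper
    // layers and `m_max_0` caps layer 0; the docs below keep the rule of
    // thumb `m_max_0 = 2 * m`, hence the 16/32 defaults. Carrying them in
    // `BuildOption` is what the commit message refers to: they replace the
    // former `M`/`M0` const generics. With `level_factor = 1 / ln(m)` as in
    // `get_default_probas`, the sampled insertion level follows the usual
    // geometric HNSW distribution — assuming the truncated `P(L...` comment
    // completes the standard form:
    //
    //     P(level = l) = m^(-l) * (1 - 1/m)
    //
    // so with m = 16 roughly 93.8% of items stay on layer 0, about 5.9%
    // reach layer 1, and each further layer is 16x less likely.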
} -impl Default for BuildOption { +impl BuildOption { fn default() -> Self { Self { ef_construction: 100, alpha: 1.0, available_memory: None, + m: 16, + m_max_0: 32, } } } @@ -49,18 +53,6 @@ impl<'a, D: Distance, R: Rng + SeedableRng> VectorBuilder<'a, D, R> { /// Typical values range from 50 to 500, with larger `ef_construction` producing higher /// quality hnsw graphs at the expense of longer builds. The default value used in hannoy is /// 100. - /// - /// # Example - /// - /// ```no_run - /// # use hannoy::{Writer, distances::Euclidean}; - /// # let (writer, wtxn): (Writer, heed::RwTxn) = todo!(); - /// use rand::rngs::StdRng; - /// use rand::SeedableRng; - /// - /// let mut rng = StdRng::seed_from_u64(4729); - /// writer.builder(&mut rng).ef_construction(100).build::<16,32>(&mut wtxn); - /// ``` pub fn ef_construction(&mut self, ef_construction: usize) -> &mut Self { self.inner.ef_construction = ef_construction; self @@ -71,19 +63,7 @@ impl<'a, D: Distance, R: Rng + SeedableRng> VectorBuilder<'a, D, R> { /// more similar to DiskANN. Increasing alpha increases indexing times as more neighbours are /// considered per linking step, but results in higher recall. /// - /// DiskANN authors suggest using alpha=1.1 or alpha=1.2. By default alpha=1.0 in hannoy. - /// - /// # Example - /// - /// ```no_run - /// # use hannoy::{Writer, distances::Euclidean}; - /// # let (writer, wtxn): (Writer, heed::RwTxn) = todo!(); - /// use rand::rngs::StdRng; - /// use rand::SeedableRng; - /// - /// let mut rng = StdRng::seed_from_u64(4729); - /// writer.builder(&mut rng).alpha(1.1).build::<16,32>(&mut wtxn); - /// ``` + /// DiskANN authors suggest using alpha=1.1 or alpha=1.2. By default alpha=1.0. pub fn alpha(&mut self, alpha: f32) -> &mut Self { self.inner.alpha = alpha; self @@ -94,30 +74,8 @@ impl<'a, D: Distance, R: Rng + SeedableRng> VectorBuilder<'a, D, R> { /// A general rule of thumb is to take `M0`= 2*`M`, with `M` >=3. Some common choices for /// `M` include : 8, 12, 16, 32. Note that increasing `M` produces a denser graph at the cost /// of longer build times. - /// - /// This function is using rayon to spawn threads. It can be configured by using the - /// [`rayon::ThreadPoolBuilder`]. - /// - /// # Example - /// - /// ```no_run - /// # use hannoy::{Writer, distances::Euclidean}; - /// # let (writer, wtxn): (Writer, heed::RwTxn) = todo!(); - /// use rayon; - /// use rand::rngs::StdRng; - /// use rand::SeedableRng; - /// - /// // configure global threadpool if you want! 
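    /// (Editorial sketch) Since `M`/`M0` now live in `BuildOption` rather
    /// than const generics, the `build::<16,32>(..)` doc examples being
    /// deleted here no longer compile; an updated equivalent, using only the
    /// builder methods kept in this file, would look like:
    ///
    ///     let mut rng = StdRng::seed_from_u64(4729);
    ///     writer.builder(&mut rng)
    ///         .ef_construction(100)
    ///         .alpha(1.1)
    ///         .build(&mut wtxn);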
- /// rayon::ThreadPoolBuilder::new().num_threads(4).build_global().unwrap(); - /// - /// let mut rng = StdRng::seed_from_u64(4729); - /// writer.builder(&mut rng).build::<16,32>(&mut wtxn); - /// ``` - pub fn build( - &mut self, - wtxn: &mut RwTxn, - ) -> VectorCoreResult<()> { - self.writer.build::(wtxn, self.rng, &self.inner) + pub fn build(&mut self, wtxn: &mut RwTxn) -> VectorCoreResult<()> { + self.writer.build::(wtxn, self.rng, &self.inner) } } @@ -258,12 +216,7 @@ impl Writer { } } - fn build( - &self, - wtxn: &mut RwTxn, - rng: &mut R, - options: &BuildOption, - ) -> VectorCoreResult<()> + fn build(&self, wtxn: &mut RwTxn, rng: &mut R, options: &BuildOption) -> VectorCoreResult<()> where R: Rng + SeedableRng, { @@ -292,7 +245,7 @@ impl Writer { // we should not keep a reference to the metadata since they're going to be moved by LMDB drop(metadata); - let mut hnsw = HnswBuilder::::new(options) + let mut hnsw = HnswBuilder::::new(options) .with_entry_points(entry_points) .with_max_level(max_level); From 1599f92a2392324f163f16fdac7f1d6d5f530dc1 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Wed, 19 Nov 2025 10:46:11 -0300 Subject: [PATCH 26/48] Switch from f64 to f32 for vector representation --- helix-db/src/helix_engine/bm25/bm25.rs | 22 ++-- helix-db/src/helix_engine/bm25/bm25_tests.rs | 13 +-- .../src/helix_engine/reranker/adapters/mod.rs | 12 +- .../src/helix_engine/reranker/fusion/mmr.rs | 62 +++++----- .../src/helix_engine/reranker/fusion/rrf.rs | 27 ++--- .../reranker/models/cross_encoder.rs | 9 +- .../src/helix_engine/reranker/reranker.rs | 18 +-- .../hnsw_concurrent_tests.rs | 52 +++++---- helix-db/src/helix_engine/tests/hnsw_tests.rs | 16 ++- .../src/helix_engine/tests/vector_tests.rs | 7 +- .../traversal_core/ops/bm25/search_bm25.rs | 2 +- .../ops/vectors/brute_force_search.rs | 15 +-- .../traversal_core/ops/vectors/insert.rs | 4 +- .../traversal_core/ops/vectors/search.rs | 6 +- .../traversal_core/traversal_value.rs | 6 +- helix-db/src/helix_engine/vector_core/mod.rs | 109 +++++++++++++----- helix-db/src/helix_engine/vector_core/node.rs | 4 - .../src/helix_engine/vector_core/writer.rs | 3 +- .../helix_gateway/embedding_providers/mod.rs | 17 +-- helix-db/src/helix_gateway/mcp/mcp.rs | 4 +- helix-db/src/helix_gateway/mcp/tools.rs | 4 +- .../custom_serde/compatibility_tests.rs | 4 +- .../protocol/custom_serde/edge_case_tests.rs | 25 ++-- .../custom_serde/error_handling_tests.rs | 10 +- .../custom_serde/integration_tests.rs | 4 +- .../custom_serde/property_based_tests.rs | 14 +-- .../src/protocol/custom_serde/test_utils.rs | 19 ++- .../src/protocol/custom_serde/vector_serde.rs | 6 +- .../custom_serde/vector_serde_tests.rs | 37 +++--- 29 files changed, 296 insertions(+), 235 deletions(-) diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs index 5709e9975..7fa58543f 100644 --- a/helix-db/src/helix_engine/bm25/bm25.rs +++ b/helix-db/src/helix_engine/bm25/bm25.rs @@ -386,7 +386,7 @@ pub trait HybridSearch { fn hybrid_search( self, query: &str, - query_vector: &[f64], + query_vector: &[f32], alpha: f32, limit: usize, ) -> impl std::future::Future, GraphError>> + Send; @@ -396,12 +396,11 @@ impl HybridSearch for HelixGraphStorage { async fn hybrid_search( self, query: &str, - query_vector: &[f64], + query_vector: &[f32], alpha: f32, limit: usize, ) -> Result, GraphError> { let query_owned = query.to_string(); - let query_vector_owned = query_vector.to_vec(); let graph_env_bm25 = self.graph_env.clone(); let graph_env_vector = 
self.graph_env.clone(); @@ -414,18 +413,23 @@ impl HybridSearch for HelixGraphStorage { } }); + let query_vector_owned = query_vector.to_vec(); let vector_handle = - task::spawn_blocking(move || -> Result>, GraphError> { + task::spawn_blocking(move || -> Result>, GraphError> { let txn = graph_env_vector.read_txn()?; let arena = Bump::new(); // MOVE - let query_slice = arena.alloc_slice_copy(query_vector_owned.as_slice()); - let results = - self.vectors - .search(&txn, query_slice, limit * 2, "vector", false, &arena)?; + let results = self.vectors.search( + &txn, + query_vector_owned, + limit * 2, + "vector", + false, + &arena, + )?; let scores = results .into_iter() .map(|vec| (vec.id, vec.distance.unwrap_or(0.0))) - .collect::>(); + .collect::>(); Ok(Some(scores)) }); diff --git a/helix-db/src/helix_engine/bm25/bm25_tests.rs b/helix-db/src/helix_engine/bm25/bm25_tests.rs index 97c8e8535..1de78117c 100644 --- a/helix-db/src/helix_engine/bm25/bm25_tests.rs +++ b/helix-db/src/helix_engine/bm25/bm25_tests.rs @@ -49,14 +49,14 @@ mod tests { (storage, temp_dir) } - fn generate_random_vectors(n: usize, d: usize) -> Vec> { + fn generate_random_vectors(n: usize, d: usize) -> Vec> { let mut rng = rand::rng(); let mut vectors = Vec::with_capacity(n); for _ in 0..n { let mut vector = Vec::with_capacity(d); for _ in 0..d { - vector.push(rng.random::()); + vector.push(rng.random::()); } vectors.push(vector); } @@ -1431,7 +1431,7 @@ mod tests { #[tokio::test] async fn test_hybrid_search() { - let (storage, _temp_dir) = setup_helix_storage(); + let (mut storage, _temp_dir) = setup_helix_storage(); let mut wtxn = storage.graph_env.write_txn().unwrap(); let docs = vec![ @@ -1450,10 +1450,9 @@ mod tests { let vectors = generate_random_vectors(800, 650); let mut arena = Bump::new(); for vec in &vectors { - let slice = arena.alloc_slice_copy(vec.as_slice()); let _ = storage .vectors - .insert(&mut wtxn, "vector", slice, None, &arena); + .insert(&mut wtxn, "vector", vec.as_slice(), None, &arena); arena.reset(); } wtxn.commit().unwrap(); @@ -1475,7 +1474,7 @@ mod tests { #[tokio::test] async fn test_hybrid_search_alpha_vectors() { - let (storage, _temp_dir) = setup_helix_storage(); + let (mut storage, _temp_dir) = setup_helix_storage(); // Insert some test documents first let mut wtxn = storage.graph_env.write_txn().unwrap(); @@ -1521,7 +1520,7 @@ mod tests { #[tokio::test] async fn test_hybrid_search_alpha_bm25() { - let (storage, _temp_dir) = setup_helix_storage(); + let (mut storage, _temp_dir) = setup_helix_storage(); // Insert some test documents first let mut wtxn = storage.graph_env.write_txn().unwrap(); diff --git a/helix-db/src/helix_engine/reranker/adapters/mod.rs b/helix-db/src/helix_engine/reranker/adapters/mod.rs index 145ad5e68..fada0b5db 100644 --- a/helix-db/src/helix_engine/reranker/adapters/mod.rs +++ b/helix-db/src/helix_engine/reranker/adapters/mod.rs @@ -131,15 +131,11 @@ mod tests { #[test] fn test_rerank_iterator() { let arena = bumpalo::Bump::new(); - let data1 = arena.alloc_slice_copy(&[1.0]); - let data2 = arena.alloc_slice_copy(&[2.0]); + let data1 = bumpalo::vec![in &arena; 1.0]; + let data2 = bumpalo::vec![in &arena; 2.0]; let items = vec![ - Ok(TraversalValue::Vector(HVector::from_slice( - "test", 0, data1, &arena, - ))), - Ok(TraversalValue::Vector(HVector::from_slice( - "test", 0, data2, &arena, - ))), + Ok(TraversalValue::Vector(HVector::from_vec("test", data1))), + Ok(TraversalValue::Vector(HVector::from_vec("test", data2))), ]; let mut iter = RerankIterator { diff --git 
a/helix-db/src/helix_engine/reranker/fusion/mmr.rs b/helix-db/src/helix_engine/reranker/fusion/mmr.rs index 15316802f..f96e80983 100644 --- a/helix-db/src/helix_engine/reranker/fusion/mmr.rs +++ b/helix-db/src/helix_engine/reranker/fusion/mmr.rs @@ -35,18 +35,18 @@ pub struct MMRReranker { /// Lambda parameter: controls relevance vs diversity trade-off /// Higher values (closer to 1.0) favor relevance /// Lower values (closer to 0.0) favor diversity - lambda: f64, + lambda: f32, /// Distance metric for similarity calculation distance_method: DistanceMethod, /// Optional query vector for relevance calculation - query_vector: Option>, + query_vector: Option>, } impl MMRReranker { /// Create a new MMR reranker with default lambda=0.7 (favoring relevance). - pub fn new(lambda: f64) -> RerankerResult { + pub fn new(lambda: f32) -> RerankerResult { if !(0.0..=1.0).contains(&lambda) { return Err(RerankerError::InvalidParameter( "lambda must be between 0.0 and 1.0".to_string(), @@ -61,7 +61,7 @@ impl MMRReranker { } /// Create an MMR reranker with a custom distance metric. - pub fn with_distance(lambda: f64, distance_method: DistanceMethod) -> RerankerResult { + pub fn with_distance(lambda: f32, distance_method: DistanceMethod) -> RerankerResult { if !(0.0..=1.0).contains(&lambda) { return Err(RerankerError::InvalidParameter( "lambda must be between 0.0 and 1.0".to_string(), @@ -76,20 +76,23 @@ impl MMRReranker { } /// Set the query vector for relevance calculation. - pub fn with_query_vector(mut self, query: Vec) -> Self { + pub fn with_query_vector(mut self, query: Vec) -> Self { self.query_vector = Some(query); self } /// Extract vector data from a TraversalValue. - /// Note: This requires an arena to convert VectorPrecisionData to f64 slice fn extract_vector_data<'a>( &self, item: &'a TraversalValue<'a>, arena: &'a bumpalo::Bump, - ) -> RerankerResult> { + ) -> RerankerResult> { match item { - TraversalValue::Vector(v) => Ok(v.data(arena).to_vec()), + TraversalValue::Vector(v) => { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(v.data_borrowed()); + Ok(bump_vec) + } _ => Err(RerankerError::TextExtractionError( "Cannot extract vector from this item type (only Vector supported for MMR)" .to_string(), @@ -98,7 +101,7 @@ impl MMRReranker { } /// Calculate similarity between two items. 
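    /// (Editorial note) The selection loop further down combines these pieces
    /// in the usual Maximal Marginal Relevance form — assumed here, since the
    /// combining line itself is elided from this hunk:
    ///
    ///     mmr(d) = lambda * rel(d) - (1.0 - lambda) * max_sim(d, selected)
    ///
    /// so `lambda` near 1.0 ranks almost purely by the original relevance
    /// score, while `lambda` near 0.0 pushes results apart (diversity),
    /// matching the field docs above.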
- fn calculate_similarity(&self, item1: &[f64], item2: &[f64]) -> RerankerResult { + fn calculate_similarity(&self, item1: &[f32], item2: &[f32]) -> RerankerResult { if item1.len() != item2.len() { return Err(RerankerError::InvalidParameter( "Vector dimensions must match".to_string(), @@ -108,9 +111,13 @@ impl MMRReranker { let distance = match self.distance_method { DistanceMethod::Cosine => { // Calculate cosine similarity (1 - cosine distance) - let dot_product: f64 = item1.iter().zip(item2.iter()).map(|(a, b)| a * b).sum(); - let norm1: f64 = item1.iter().map(|x| x * x).sum::().sqrt(); - let norm2: f64 = item2.iter().map(|x| x * x).sum::().sqrt(); + let dot_product = item1 + .iter() + .zip(item2.iter()) + .map(|(a, b)| a * b) + .sum::(); + let norm1 = item1.iter().map(|x| x * x).sum::().sqrt(); + let norm2 = item2.iter().map(|x| x * x).sum::().sqrt(); if norm1 == 0.0 || norm2 == 0.0 { 0.0 @@ -120,11 +127,11 @@ impl MMRReranker { } DistanceMethod::Euclidean => { // Convert Euclidean distance to similarity (using negative exponential) - let dist_sq: f64 = item1 + let dist_sq = item1 .iter() .zip(item2.iter()) .map(|(a, b)| (a - b).powi(2)) - .sum(); + .sum::(); (-dist_sq.sqrt()).exp() } DistanceMethod::DotProduct => { @@ -149,7 +156,7 @@ impl MMRReranker { let n = items.len(); let mut selected: Vec> = Vec::with_capacity(n); - let mut remaining: Vec<(TraversalValue<'arena>, f64)> = Vec::with_capacity(n); + let mut remaining: Vec<(TraversalValue<'arena>, f32)> = Vec::with_capacity(n); // Extract original scores and prepare remaining items for item in items { @@ -158,7 +165,7 @@ impl MMRReranker { } // Cache for similarity calculations - let mut similarity_cache: HashMap<(usize, usize), f64> = HashMap::new(); + let mut similarity_cache: HashMap<(usize, usize), f32> = HashMap::new(); // Select first item (highest original score) remaining.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); @@ -168,7 +175,7 @@ impl MMRReranker { // Iteratively select remaining items while !remaining.is_empty() { let mut best_idx = 0; - let mut best_mmr_score = f64::NEG_INFINITY; + let mut best_mmr_score = f32::NEG_INFINITY; for (idx, (item, relevance_score)) in remaining.iter().enumerate() { let item_vec = self.extract_vector_data(item, &arena)?; @@ -181,7 +188,7 @@ impl MMRReranker { }; // Calculate diversity term (max similarity to selected items) - let mut max_similarity: f64 = 0.0; + let mut max_similarity: f32 = 0.0; for (sel_idx, selected_item) in selected.iter().enumerate() { // Check cache first let cache_key = (idx, sel_idx); @@ -240,9 +247,10 @@ mod tests { use crate::helix_engine::vector_core::HVector; use bumpalo::Bump; - fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> { - let slice = arena.alloc_slice_copy(data); - HVector::from_slice("test_vector", 0, slice, arena) + fn alloc_vector<'a>(arena: &'a Bump, data: &[f32]) -> HVector<'a> { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(data); + HVector::from_vec("test_vector", bump_vec) } #[test] @@ -648,9 +656,9 @@ mod tests { // Create 100 vectors let vectors: Vec = (0..100) .map(|i| { - let angle = (i as f64) * 0.1; + let angle = (i as f32) * 0.1; let mut v = alloc_vector(&arena, &[angle.cos(), angle.sin()]); - v.distance = Some(1.0 - i as f64 / 100.0); + v.distance = Some(1.0 - i as f32 / 100.0); v.id = i as u128; TraversalValue::Vector(v) }) @@ -705,8 +713,8 @@ mod tests { let vectors: Vec = (0..3) .map(|i| { - let mut v = alloc_vector(&arena, &[1.0 * i as f64, 
0.0]); - v.distance = Some(1.0 - i as f64 * 0.1); + let mut v = alloc_vector(&arena, &[1.0 * i as f32, 0.0]); + v.distance = Some(1.0 - i as f32 * 0.1); v.id = i as u128; TraversalValue::Vector(v) }) @@ -768,9 +776,9 @@ mod tests { let vectors: Vec = (0..5) .map(|i| { - let data: Vec = (0..100).map(|j| if j == i { 1.0 } else { 0.0 }).collect(); + let data: Vec = (0..100).map(|j| if j == i { 1.0 } else { 0.0 }).collect(); let mut v = alloc_vector(&arena, &data); - v.distance = Some(1.0 - i as f64 * 0.1); + v.distance = Some(1.0 - i as f32 * 0.1); v.id = i as u128; TraversalValue::Vector(v) }) diff --git a/helix-db/src/helix_engine/reranker/fusion/rrf.rs b/helix-db/src/helix_engine/reranker/fusion/rrf.rs index b93703497..750d59c56 100644 --- a/helix-db/src/helix_engine/reranker/fusion/rrf.rs +++ b/helix-db/src/helix_engine/reranker/fusion/rrf.rs @@ -23,7 +23,7 @@ use std::collections::HashMap; #[derive(Debug, Clone)] pub struct RRFReranker { /// The k parameter in the RRF formula (default: 60) - k: f64, + k: f32, } impl RRFReranker { @@ -36,7 +36,7 @@ impl RRFReranker { /// /// # Arguments /// * `k` - The k parameter in the RRF formula. Higher values give less weight to ranking position. - pub fn with_k(k: f64) -> RerankerResult { + pub fn with_k(k: f32) -> RerankerResult { if k <= 0.0 { return Err(RerankerError::InvalidParameter( "k must be positive".to_string(), @@ -55,7 +55,7 @@ impl RRFReranker { /// A vector of items reranked by RRF scores pub fn fuse_lists<'arena, I>( lists: Vec, - k: f64, + k: f32, ) -> RerankerResult>> where I: Iterator>, @@ -64,7 +64,7 @@ impl RRFReranker { return Err(RerankerError::EmptyInput); } - let mut rrf_scores: HashMap = HashMap::new(); + let mut rrf_scores: HashMap = HashMap::new(); let mut items_map: HashMap> = HashMap::new(); // Process each ranked list @@ -79,7 +79,7 @@ impl RRFReranker { // Calculate reciprocal rank: 1 / (k + rank) // rank starts at 0, so actual rank is rank + 1 - let rr_score = 1.0 / (k + (rank as f64) + 1.0); + let rr_score = 1.0 / (k + (rank as f32) + 1.0); // Sum reciprocal ranks across all lists *rrf_scores.entry(id).or_insert(0.0) += rr_score; @@ -90,7 +90,7 @@ impl RRFReranker { } // Convert to scored items and sort by RRF score (descending) - let mut scored_items: Vec<(u128, f64)> = rrf_scores.into_iter().collect(); + let mut scored_items: Vec<(u128, f32)> = rrf_scores.into_iter().collect(); scored_items.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); // Update scores and collect results @@ -132,7 +132,7 @@ impl Reranker for RRFReranker { for (rank, mut item) in items_vec.into_iter().enumerate() { // Calculate RRF score for this item based on its rank - let rrf_score = 1.0 / (self.k + (rank as f64) + 1.0); + let rrf_score = 1.0 / (self.k + (rank as f32) + 1.0); update_score(&mut item, rrf_score)?; results.push(item); } @@ -151,9 +151,10 @@ mod tests { use crate::{helix_engine::vector_core::HVector, utils::items::Node}; use bumpalo::Bump; - fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> { - let slice = arena.alloc_slice_copy(data); - HVector::from_slice("test_vector", 0, slice, arena) + fn alloc_vector<'a>(arena: &'a Bump, data: &[f32]) -> HVector<'a> { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(data); + HVector::from_vec("test_vector", bump_vec) } #[test] @@ -164,7 +165,7 @@ mod tests { let vectors: Vec = (0..5) .map(|i| { let mut v = alloc_vector(&arena, &[1.0, 2.0, 3.0]); - v.distance = Some((i + 1) as f64); + v.distance = Some((i + 
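For reference, the whole RRF computation in rrf.rs reduces to a few lines. This self-contained sketch uses the same 0-based rank convention and k = 60 default as the code above (names are illustrative); it shows why an item ranked well in several lists outscores one ranked first in only one:

    use std::collections::HashMap;

    // score(id) = sum over lists of 1 / (k + rank + 1)
    fn rrf_fuse(lists: &[Vec<u128>], k: f32) -> Vec<(u128, f32)> {
        let mut scores: HashMap<u128, f32> = HashMap::new();
        for list in lists {
            for (rank, id) in list.iter().enumerate() {
                *scores.entry(*id).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
            }
        }
        let mut out: Vec<_> = scores.into_iter().collect();
        out.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        out
    }

    fn main() {
        // Id 2 sits near the top of both lists, so it wins the fusion.
        let fused = rrf_fuse(&[vec![1, 2, 3], vec![2, 3, 1]], 60.0);
        assert_eq!(fused[0].0, 2);
    }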
1) as f32); v.id = i as u128; TraversalValue::Vector(v) }) @@ -177,7 +178,7 @@ mod tests { // Check that RRF scores are calculated correctly for (rank, item) in results.iter().enumerate() { if let TraversalValue::Vector(v) = item { - let expected_score = 1.0 / (60.0 + (rank as f64) + 1.0); + let expected_score = 1.0 / (60.0 + (rank as f32) + 1.0); assert!((v.distance.unwrap() - expected_score).abs() < 1e-10); } } @@ -589,7 +590,7 @@ mod tests { let vectors: Vec = (0..3) .map(|i| { - let mut v = alloc_vector(&arena, &[1.0 * i as f64, 2.0 * i as f64]); + let mut v = alloc_vector(&arena, &[1.0 * i as f32, 2.0 * i as f32]); v.id = i as u128; TraversalValue::Vector(v) }) diff --git a/helix-db/src/helix_engine/reranker/models/cross_encoder.rs b/helix-db/src/helix_engine/reranker/models/cross_encoder.rs index db3feb74b..bc78b95fc 100644 --- a/helix-db/src/helix_engine/reranker/models/cross_encoder.rs +++ b/helix-db/src/helix_engine/reranker/models/cross_encoder.rs @@ -124,7 +124,7 @@ impl CrossEncoderReranker { /// /// This is a placeholder for actual model inference. /// TODO: Implement actual model loading and inference. - fn score_pair(&self, _query: &str, _document: &str) -> RerankerResult { + fn score_pair(&self, _query: &str, _document: &str) -> RerankerResult { todo!(); } } @@ -175,9 +175,10 @@ mod tests { use crate::helix_engine::vector_core::HVector; use bumpalo::Bump; - fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> { - let slice = arena.alloc_slice_copy(data); - HVector::from_slice("test_vector", 0, slice, arena) + fn alloc_vector<'a>(arena: &'a Bump, data: &[f32]) -> HVector<'a> { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(data); + HVector::from_vec("test_vector", bump_vec) } #[ignore] diff --git a/helix-db/src/helix_engine/reranker/reranker.rs b/helix-db/src/helix_engine/reranker/reranker.rs index 2642c6f61..eb8f6024e 100644 --- a/helix-db/src/helix_engine/reranker/reranker.rs +++ b/helix-db/src/helix_engine/reranker/reranker.rs @@ -3,11 +3,9 @@ //! Core Reranker trait and related types. -use crate::{ - helix_engine::{ - reranker::errors::{RerankerError, RerankerResult}, - traversal_core::traversal_value::TraversalValue, - }, +use crate::helix_engine::{ + reranker::errors::{RerankerError, RerankerResult}, + traversal_core::traversal_value::TraversalValue, }; /// Represents a scored item for reranking. @@ -41,7 +39,11 @@ pub trait Reranker: Send + Sync { /// /// # Returns /// A vector of reranked items with updated scores - fn rerank<'arena, I>(&self, items: I, query: Option<&str>) -> RerankerResult>> + fn rerank<'arena, I>( + &self, + items: I, + query: Option<&str>, + ) -> RerankerResult>> where I: Iterator>; @@ -53,7 +55,7 @@ pub trait Reranker: Send + Sync { /// /// This handles the different types (Node, Edge, Vector) and extracts /// their associated score/distance value. -pub fn extract_score(item: &TraversalValue) -> RerankerResult { +pub fn extract_score(item: &TraversalValue) -> RerankerResult { match item { TraversalValue::Vector(v) => Ok(v.score()), TraversalValue::NodeWithScore { score, .. } => Ok(*score), @@ -69,7 +71,7 @@ pub fn extract_score(item: &TraversalValue) -> RerankerResult { /// /// This modifies the distance/score field of the item to reflect /// the new reranked score. 
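extract_score and update_score below are two halves of one contract: read a score out of a TraversalValue, write a reranked one back. A simplified, runnable mirror of that pair, using a stand-in enum so it compiles on its own (the real functions match on TraversalValue and treat a missing score as an error rather than None):

    enum ScoredItem {
        Vector { distance: Option<f32> },
        NodeWithScore { score: f32 },
    }

    fn extract_score(item: &ScoredItem) -> Option<f32> {
        match item {
            ScoredItem::Vector { distance } => *distance,
            ScoredItem::NodeWithScore { score } => Some(*score),
        }
    }

    fn update_score(item: &mut ScoredItem, new_score: f32) {
        match item {
            ScoredItem::Vector { distance } => *distance = Some(new_score),
            ScoredItem::NodeWithScore { score } => *score = new_score,
        }
    }

    fn main() {
        let mut item = ScoredItem::Vector { distance: Some(0.3) };
        update_score(&mut item, 0.9);
        assert_eq!(extract_score(&item), Some(0.9));
    }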
-pub fn update_score(item: &mut TraversalValue, new_score: f64) -> RerankerResult<()> { +pub fn update_score(item: &mut TraversalValue, new_score: f32) -> RerankerResult<()> { match item { TraversalValue::Vector(v) => { v.distance = Some(new_score); diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs index 73830d0aa..5eb75b7cd 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs @@ -44,7 +44,7 @@ fn setup_concurrent_env() -> (TempDir, Env) { } /// Generate a random vector of given dimensionality -fn random_vector(dim: usize) -> Vec { +fn random_vector(dim: usize) -> Vec { (0..dim) .map(|_| rand::rng().random_range(0.0..1.0)) .collect() @@ -96,12 +96,17 @@ fn test_concurrent_inserts_single_label() { let mut wtxn = env.write_txn().unwrap(); let arena = Bump::new(); let vector = random_vector(128); - let data = arena.alloc_slice_copy(&vector); // Open the existing databases and insert - let index = open_vector_core(&env, &mut wtxn).unwrap(); + let mut index = open_vector_core(&env, &mut wtxn).unwrap(); index - .insert(&mut wtxn, "concurrent_test", data, None, &arena) + .insert( + &mut wtxn, + "concurrent_test", + vector.as_slice(), + None, + &arena, + ) .expect("Insert should succeed"); wtxn.commit().expect("Commit should succeed"); } @@ -134,7 +139,7 @@ fn test_concurrent_inserts_single_label() { // Additional consistency check: Verify we can perform searches (entry point exists implicitly) let arena = Bump::new(); let query = [0.5; 128]; - let search_result = index.search(&rtxn, &query, 10, "concurrent_test", false, &arena); + let search_result = index.search(&rtxn, query.to_vec(), 10, "concurrent_test", false, &arena); assert!( search_result.is_ok(), "Should be able to search after concurrent inserts (entry point exists)" @@ -156,7 +161,7 @@ fn test_concurrent_searches_during_inserts() { // Initialize with some initial vectors { let mut txn = env.write_txn().unwrap(); - let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let arena = Bump::new(); for _ in 0..50 { @@ -198,7 +203,7 @@ fn test_concurrent_searches_during_inserts() { let rtxn = env.read_txn().unwrap(); let arena = Bump::new(); - match index.search(&rtxn, &query[..], 10, "search_test", false, &arena) { + match index.search(&rtxn, query.to_vec(), 10, "search_test", false, &arena) { Ok(results) => { total_searches += 1; total_results += results.len(); @@ -245,7 +250,7 @@ fn test_concurrent_searches_during_inserts() { let vector = random_vector(128); let data = arena.alloc_slice_copy(&vector); - let index = open_vector_core(&env, &mut wtxn).unwrap(); + let mut index = open_vector_core(&env, &mut wtxn).unwrap(); index .insert(&mut wtxn, "search_test", data, None, &arena) .expect("Insert should succeed"); @@ -277,7 +282,7 @@ fn test_concurrent_searches_during_inserts() { // Verify we can still search successfully let arena = Bump::new(); let results = index - .search(&rtxn, &query[..], 10, "search_test", false, &arena) + .search(&rtxn, query.to_vec(), 10, "search_test", false, &arena) .unwrap(); assert!( !results.is_empty(), @@ -317,7 +322,7 @@ fn test_concurrent_inserts_multiple_labels() { for i in 0..vectors_per_label { let mut wtxn = env.write_txn().unwrap(); - let index = 
open_vector_core(&env, &mut wtxn).unwrap(); + let mut index = open_vector_core(&env, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(64); @@ -350,7 +355,7 @@ fn test_concurrent_inserts_multiple_labels() { // Verify we can search for each label (entry point exists implicitly) let query = [0.5; 64]; - let search_result = index.search(&rtxn, &query, 5, &label, false, &arena); + let search_result = index.search(&rtxn, query.to_vec(), 5, &label, false, &arena); assert!( search_result.is_ok(), "Should be able to search label {}", @@ -402,7 +407,7 @@ fn test_entry_point_consistency() { for _ in 0..vectors_per_thread { let mut wtxn = env.write_txn().unwrap(); - let index = open_vector_core(&env, &mut wtxn).unwrap(); + let mut index = open_vector_core(&env, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(32); @@ -430,7 +435,7 @@ fn test_entry_point_consistency() { // If we can successfully search, entry point must be valid let query = [0.5; 32]; - let search_result = index.search(&rtxn, &query, 10, "entry_test", false, &arena); + let search_result = index.search(&rtxn, query.to_vec(), 10, "entry_test", false, &arena); assert!( search_result.is_ok(), "Entry point should exist and be valid" @@ -483,7 +488,7 @@ fn test_graph_connectivity_after_concurrent_inserts() { for _ in 0..vectors_per_thread { let mut wtxn = env.write_txn().unwrap(); - let index = open_vector_core(&env, &mut wtxn).unwrap(); + let mut index = open_vector_core(&env, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(64); @@ -513,7 +518,14 @@ fn test_graph_connectivity_after_concurrent_inserts() { for i in 0..10 { let query = random_vector(64); let results = index - .search(&rtxn, &query, 10, "connectivity_test", false, &arena) + .search( + &rtxn, + query.to_vec(), + 10, + "connectivity_test", + false, + &arena, + ) .unwrap(); assert!( @@ -545,14 +557,13 @@ fn test_transaction_isolation() { let initial_count = 10; { let mut txn = env.write_txn().unwrap(); - let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let arena = Bump::new(); for _ in 0..initial_count { let vector = random_vector(32); - let data = arena.alloc_slice_copy(&vector); index - .insert(&mut txn, "isolation_test", data, None, &arena) + .insert(&mut txn, "isolation_test", vector.as_slice(), None, &arena) .unwrap(); } txn.commit().unwrap(); @@ -580,13 +591,12 @@ fn test_transaction_isolation() { let handle = thread::spawn(move || { for _ in 0..20 { let mut wtxn = env_clone.write_txn().unwrap(); - let index = open_vector_core(&env_clone, &mut wtxn).unwrap(); + let mut index = open_vector_core(&env_clone, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(32); - let data = arena.alloc_slice_copy(&vector); index - .insert(&mut wtxn, "isolation_test", data, None, &arena) + .insert(&mut wtxn, "isolation_test", vector.as_slice(), None, &arena) .unwrap(); wtxn.commit().unwrap(); } diff --git a/helix-db/src/helix_engine/tests/hnsw_tests.rs b/helix-db/src/helix_engine/tests/hnsw_tests.rs index a0668eb53..0c7c460fd 100644 --- a/helix-db/src/helix_engine/tests/hnsw_tests.rs +++ b/helix-db/src/helix_engine/tests/hnsw_tests.rs @@ -25,14 +25,13 @@ fn setup_env() -> (Env, TempDir) { fn test_hnsw_insert_and_count() { let (env, _temp_dir) = setup_env(); let mut txn = env.write_txn().unwrap(); - let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, 
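These concurrency tests all share one shape: each thread opens its own short-lived write transaction, inserts, and commits, relying on LMDB's single-writer lock to serialize writers while readers proceed. A rough standalone model of that pattern, with a mutex-guarded Vec standing in for the environment:

    use std::sync::{Arc, Mutex};
    use std::thread;

    struct FakeDb {
        rows: Mutex<Vec<Vec<f32>>>,
    }

    fn main() {
        let db = Arc::new(FakeDb { rows: Mutex::new(Vec::new()) });
        let handles: Vec<_> = (0..4)
            .map(|t| {
                let db = Arc::clone(&db);
                thread::spawn(move || {
                    for i in 0..25 {
                        // begin "write txn" (exclusive lock), insert, "commit"
                        db.rows.lock().unwrap().push(vec![t as f32, i as f32]);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(db.rows.lock().unwrap().len(), 100);
    }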
None, None)).unwrap(); + let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); - let vector: Vec = (0..4).map(|_| rand::rng().random_range(0.0..1.0)).collect(); + let vector: Vec = (0..4).map(|_| rand::rng().random_range(0.0..1.0)).collect(); for _ in 0..10 { let arena = Bump::new(); - let data = arena.alloc_slice_copy(&vector); let _ = index - .insert(&mut txn, "vector", data, None, &arena) + .insert(&mut txn, "vector", vector.as_slice(), None, &arena) .unwrap(); } @@ -44,15 +43,14 @@ fn test_hnsw_insert_and_count() { fn test_hnsw_search_returns_results() { let (env, _temp_dir) = setup_env(); let mut txn = env.write_txn().unwrap(); - let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let mut rng = rand::rng(); for _ in 0..128 { let arena = Bump::new(); - let vector: Vec = (0..4).map(|_| rng.random_range(0.0..1.0)).collect(); - let data = arena.alloc_slice_copy(&vector); + let vector: Vec = (0..4).map(|_| rng.random_range(0.0..1.0)).collect(); let _ = index - .insert(&mut txn, "vector", data, None, &arena) + .insert(&mut txn, "vector", vector.as_slice(), None, &arena) .unwrap(); } txn.commit().unwrap(); @@ -61,7 +59,7 @@ fn test_hnsw_search_returns_results() { let txn = env.read_txn().unwrap(); let query = [0.5, 0.5, 0.5, 0.5]; let results = index - .search(&txn, &query, 5, "vector", false, &arena) + .search(&txn, query.to_vec(), 5, "vector", false, &arena) .unwrap(); assert!(!results.is_empty()); } diff --git a/helix-db/src/helix_engine/tests/vector_tests.rs b/helix-db/src/helix_engine/tests/vector_tests.rs index bda172d8a..48547db57 100644 --- a/helix-db/src/helix_engine/tests/vector_tests.rs +++ b/helix-db/src/helix_engine/tests/vector_tests.rs @@ -4,9 +4,10 @@ use crate::helix_engine::vector_core::{ }; use bumpalo::Bump; -fn alloc_vector<'a>(arena: &'a Bump, data: &[f64]) -> HVector<'a> { - let slice = arena.alloc_slice_copy(data); - HVector::from_slice("vector", 0, slice, arena) +fn alloc_vector<'a>(arena: &'a Bump, data: &[f32]) -> HVector<'a> { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(data); + HVector::from_vec("test_vector", bump_vec) } #[test] diff --git a/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs b/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs index 5114e0eef..8933a4fd5 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs @@ -82,7 +82,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE if label_in_lmdb == label_as_bytes { match Node::<'arena>::from_bincode_bytes(id, value, self.arena) { Ok(node) => { - return Some(Ok(TraversalValue::NodeWithScore { node, score: score as f64 })); + return Some(Ok(TraversalValue::NodeWithScore { node, score: score })); } Err(e) => { println!("{} Error decoding node: {:?}", line!(), e); diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs index 8273a8713..670cc6c24 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs @@ -13,7 +13,7 @@ pub trait BruteForceSearchVAdapter<'db, 'arena, 'txn>: { fn brute_force_search_v( self, - query: &'arena [f64], + query: &'arena [f32], k: K, 
) -> RoTraversalIterator< 'db, @@ -31,7 +31,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE { fn brute_force_search_v( self, - query: &'arena [f64], + query: &'arena [f32], k: K, ) -> RoTraversalIterator< 'db, @@ -48,11 +48,12 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE .inner .filter_map(|v| match v { Ok(TraversalValue::Vector(mut v)) => { - let d = Cosine::distance( - v.data.as_ref().unwrap(), - &Item::::from(query, &arena), - ); - v.set_distance(d as f64); + let mut bump_vec = bumpalo::collections::Vec::new_in(&self.arena); + bump_vec.extend_from_slice(v.data_borrowed()); + + let d = + Cosine::distance(v.data.as_ref().unwrap(), &Item::::new(bump_vec)); + v.set_distance(d); Some(v) } _ => None, diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs index 50f257a18..7125533ec 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/insert.rs @@ -12,7 +12,7 @@ pub trait InsertVAdapter<'db, 'arena, 'txn>: { fn insert_v( self, - query: &'arena [f64], + query: &'arena [f32], label: &'arena str, properties: Option>, ) -> RwTraversalIterator< @@ -28,7 +28,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE { fn insert_v( self, - query: &'arena [f64], + query: &'arena [f32], label: &'arena str, properties: Option>, ) -> RwTraversalIterator< diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs index 1f414f7e8..61f1ed896 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs @@ -12,7 +12,7 @@ pub trait SearchVAdapter<'db, 'arena, 'txn>: { fn search_v( self, - query: &'arena [f64], + query: &'arena [f32], k: K, label: &'arena str, filter: Option<&'arena [F]>, @@ -33,7 +33,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE { fn search_v( self, - query: &'arena [f64], + query: &'arena [f32], k: K, label: &'arena str, filter: Option<&'arena [F]>, @@ -50,7 +50,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE { let vectors = self.storage.vectors.search( self.txn, - query, + query.to_vec(), k.try_into().unwrap(), label, false, diff --git a/helix-db/src/helix_engine/traversal_core/traversal_value.rs b/helix-db/src/helix_engine/traversal_core/traversal_value.rs index 8f28ba02b..bd1689bf1 100644 --- a/helix-db/src/helix_engine/traversal_core/traversal_value.rs +++ b/helix-db/src/helix_engine/traversal_core/traversal_value.rs @@ -25,7 +25,7 @@ pub enum TraversalValue<'arena> { Value(Value), /// Item With Score - NodeWithScore { node: Node<'arena>, score: f64 }, + NodeWithScore { node: Node<'arena>, score: f32 }, /// An empty traversal value Empty, } @@ -65,14 +65,14 @@ impl<'arena> TraversalValue<'arena> { } } - pub fn data(&'arena self) -> &'arena [f64] { + pub fn data(&'arena self) -> &'arena [f32] { match self { TraversalValue::Vector(vector) => vector.data_borrowed(), _ => unimplemented!(), } } - pub fn score(&self) -> f64 { + pub fn score(&self) -> f32 { match self { TraversalValue::Vector(vector) => vector.score(), _ => unimplemented!(), diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index e423f3636..f91550137 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, cmp::Ordering}; +use std::{borrow::Cow, cmp::Ordering, 
hash::Hash}; use bincode::Options; use byteorder::BE; @@ -7,6 +7,7 @@ use heed3::{ Database, Env, Error as LmdbError, RoTxn, RwTxn, types::{Bytes, U128}, }; +use rand::{SeedableRng, rngs::StdRng}; use serde::{Deserialize, Serialize}; use crate::{ @@ -61,47 +62,33 @@ pub type CoreDatabase = heed3::Database>; #[derive(Debug, Serialize, Clone)] pub struct HVector<'arena> { pub id: u128, - pub distance: Option, + pub distance: Option, pub label: &'arena str, pub deleted: bool, pub version: u8, - pub level: usize, pub properties: Option>, pub data: Option>, } impl<'arena> HVector<'arena> { - // FIXME: this allocates twice - pub fn data(&self, arena: &'arena bumpalo::Bump) -> &'arena [f64] { - let vec_f32 = self.data.as_ref().unwrap().vector.as_ref().to_vec(arena); - - arena.alloc_slice_fill_iter(vec_f32.iter().map(|&x| x as f64)) - } - - pub fn data_borrowed(&self) -> &[f64] { - bytemuck::cast_slice(self.data.as_ref().unwrap().vector.as_ref().as_bytes()) + pub fn data_borrowed(&self) -> &[f32] { + bytemuck::cast_slice(self.data.as_ref().unwrap().vector.as_bytes()) } - pub fn from_slice( - label: &'arena str, - level: usize, - data: &'arena [f64], - arena: &'arena bumpalo::Bump, - ) -> Self { + pub fn from_vec(label: &'arena str, data: bumpalo::collections::Vec<'arena, f32>) -> Self { let id = v6_uuid(); HVector { id, version: 1, - level, label, - data: Some(Item::::from(data, arena)), + data: Some(Item::::new(data)), distance: None, properties: None, deleted: false, } } - pub fn score(&self) -> f64 { + pub fn score(&self) -> f32 { self.distance.unwrap_or(2.0) } @@ -163,11 +150,11 @@ impl<'arena> HVector<'arena> { todo!() } - pub fn set_distance(&mut self, distance: f64) { + pub fn set_distance(&mut self, distance: f32) { self.distance = Some(distance); } - pub fn get_distance(&self) -> f64 { + pub fn get_distance(&self) -> f32 { self.distance.unwrap() } @@ -184,11 +171,13 @@ impl<'arena> HVector<'arena> { self.properties.as_ref().and_then(|value| value.get(key)) } - pub fn cast_raw_vector_data<'txn>( - arena: &'arena bumpalo::Bump, + pub fn raw_vector_data_to_vec<'txn>( raw_vector_data: &'txn [u8], - ) -> &'txn [f64] { - todo!() + arena: &'arena bumpalo::Bump, + ) -> bumpalo::collections::Vec<'arena, f32> { + let mut bump_vec = bumpalo::collections::Vec::<'arena, f32>::new_in(arena); + bump_vec.extend_from_slice(bytemuck::cast_slice(raw_vector_data)); + bump_vec } pub fn from_raw_vector_data<'txn>( @@ -256,7 +245,7 @@ impl HNSWConfig { } pub struct VectorCoreStats { - // Do it atomical? + // Do it an atomic? 
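data_borrowed() above leans on bytemuck::cast_slice to reinterpret the stored bytes as f32 with no copy, which is also why an f32 payload is exactly 4 bytes per dimension. A small sketch, assuming the bytemuck crate that is already a dependency here; the cast panics unless the byte length is a multiple of 4 and the pointer is suitably aligned, which is what the misaligned-buffer tests later in this series exercise on purpose:

    fn main() {
        let floats = [1.0f32, 2.0, 3.0];
        let bytes: &[u8] = bytemuck::cast_slice(&floats);
        assert_eq!(bytes.len(), 12); // 3 dims * 4 bytes per f32
        let back: &[f32] = bytemuck::cast_slice(bytes);
        assert_eq!(back, &floats);
    }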
pub num_vectors: usize, } @@ -267,6 +256,19 @@ pub struct VectorCore { pub stats: VectorCoreStats, pub vector_properties_db: Database, Bytes>, pub config: HNSWConfig, + + /// Map labels to a different index + pub label_to_index: HashMap, + /// Track the last index + curr_index: u16, + + pub id_map: HashMap, + curr_id: u32, + + /// The actual index is lazily build during the first access + // TODO: We should choose a better strategy to build the index + writer: Option>, + reader: Option>, } impl VectorCore { @@ -283,29 +285,74 @@ impl VectorCore { stats: VectorCoreStats { num_vectors: 0 }, vector_properties_db, config, + writer: None, + reader: None, + label_to_index: HashMap::new(), + curr_index: 0, + id_map: HashMap::new(), + curr_id: 0, }) } pub fn search<'arena>( &self, txn: &RoTxn, - query: &'arena [f64], + query: Vec, k: usize, label: &'arena str, should_trickle: bool, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { + if self.reader.is_none() { + return Ok(bumpalo::collections::Vec::new_in(&arena)); + } todo!() } + fn get_or_create(&mut self, label: &str) -> u16 { + if let Some(&index) = self.label_to_index.get(label) { + index + } else { + self.curr_index += 1; + self.label_to_index + .insert(label.to_string(), self.curr_index); + self.curr_index + } + } + pub fn insert<'arena>( - &self, + &mut self, txn: &mut RwTxn, label: &'arena str, - data: &'arena [f64], + data: &'arena [f32], properties: Option>, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { + // index hasn't been built yet + if self.writer.is_none() { + let idx = self.get_or_create(label); + // assume the len of the first insertion as the + // index dimension + self.writer = Some(Writer::new(self.hsnw_index, idx, data.len())); + let mut rng = StdRng::from_os_rng(); + let mut builder = self.writer.as_ref().unwrap().builder(&mut rng); + builder + .ef_construction(self.config.ef_construct) + .build(txn)?; + } + + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(data); + let hvector = HVector::from_vec(label, bump_vec); + + self.curr_id += 1; + self.id_map.insert(hvector.id, self.curr_id); + + self.writer + .as_ref() + .unwrap() + .add_item(txn, self.curr_id, data); + todo!() } diff --git a/helix-db/src/helix_engine/vector_core/node.rs b/helix-db/src/helix_engine/vector_core/node.rs index b2a7948c5..b0e6164b5 100644 --- a/helix-db/src/helix_engine/vector_core/node.rs +++ b/helix-db/src/helix_engine/vector_core/node.rs @@ -83,10 +83,6 @@ impl Item<'_, D> { let header = D::new_header(&vector); Self { header, vector } } - - pub fn from<'arena>(vec: &[f64], arena: &'arena bumpalo::Bump) -> Self { - Self::new(vec.into_iter().map(|x| *x as f32).collect_in(arena)) - } } #[derive(Clone, Debug)] diff --git a/helix-db/src/helix_engine/vector_core/writer.rs b/helix-db/src/helix_engine/vector_core/writer.rs index c7bf62e45..a09bc4986 100644 --- a/helix-db/src/helix_engine/vector_core/writer.rs +++ b/helix-db/src/helix_engine/vector_core/writer.rs @@ -47,8 +47,7 @@ impl BuildOption { } impl<'a, D: Distance, R: Rng + SeedableRng> VectorBuilder<'a, D, R> { - /// Controls the search range when inserting a new item into the graph. This value must be - /// greater than or equal to the `M` used in [`Self::build`] + /// Controls the search range when inserting a new item into the graph. /// /// Typical values range from 50 to 500, with larger `ef_construction` producing higher /// quality hnsw graphs at the expense of longer builds. 
The default value used in hannoy is diff --git a/helix-db/src/helix_gateway/embedding_providers/mod.rs b/helix-db/src/helix_gateway/embedding_providers/mod.rs index 1f7dadaba..4554b5964 100644 --- a/helix-db/src/helix_gateway/embedding_providers/mod.rs +++ b/helix-db/src/helix_gateway/embedding_providers/mod.rs @@ -8,8 +8,8 @@ use url::Url; /// Trait for embedding models to fetch text embeddings. #[allow(async_fn_in_trait)] pub trait EmbeddingModel { - fn fetch_embedding(&self, text: &str) -> Result, GraphError>; - async fn fetch_embedding_async(&self, text: &str) -> Result, GraphError>; + fn fetch_embedding(&self, text: &str) -> Result, GraphError>; + async fn fetch_embedding_async(&self, text: &str) -> Result, GraphError>; } #[derive(Debug, Clone)] @@ -111,12 +111,12 @@ impl EmbeddingModelImpl { impl EmbeddingModel for EmbeddingModelImpl { /// Must be called with an active tokio context - fn fetch_embedding(&self, text: &str) -> Result, GraphError> { + fn fetch_embedding(&self, text: &str) -> Result, GraphError> { let handle = tokio::runtime::Handle::current(); handle.block_on(self.fetch_embedding_async(text)) } - async fn fetch_embedding_async(&self, text: &str) -> Result, GraphError> { + async fn fetch_embedding_async(&self, text: &str) -> Result, GraphError> { match &self.provider { EmbeddingProvider::OpenAI => { let api_key = self @@ -151,8 +151,9 @@ impl EmbeddingModel for EmbeddingModelImpl { .map(|v| { v.as_f64() .ok_or_else(|| GraphError::from("Invalid float value")) + .map(|f| f as f32) }) - .collect::, GraphError>>()?; + .collect::, GraphError>>()?; Ok(embedding) } @@ -198,8 +199,9 @@ impl EmbeddingModel for EmbeddingModelImpl { .map(|v| { v.as_f64() .ok_or_else(|| GraphError::from("Invalid float value")) + .map(|f| f as f32) }) - .collect::, GraphError>>()?; + .collect::, GraphError>>()?; Ok(embedding) } @@ -237,8 +239,9 @@ impl EmbeddingModel for EmbeddingModelImpl { .map(|v| { v.as_f64() .ok_or_else(|| GraphError::from("Invalid float value")) + .map(|f| f as f32) }) - .collect::, GraphError>>()?; + .collect::, GraphError>>()?; Ok(embedding) } diff --git a/helix-db/src/helix_gateway/mcp/mcp.rs b/helix-db/src/helix_gateway/mcp/mcp.rs index 269495d8c..be63687f2 100644 --- a/helix-db/src/helix_gateway/mcp/mcp.rs +++ b/helix-db/src/helix_gateway/mcp/mcp.rs @@ -1012,9 +1012,9 @@ pub fn search_vector_text(input: &mut MCPToolInput) -> Result, + pub vector: Vec, pub k: usize, - pub min_score: Option, + pub min_score: Option, } #[derive(Debug, Deserialize)] diff --git a/helix-db/src/helix_gateway/mcp/tools.rs b/helix-db/src/helix_gateway/mcp/tools.rs index d130a685c..9d5431c20 100644 --- a/helix-db/src/helix_gateway/mcp/tools.rs +++ b/helix-db/src/helix_gateway/mcp/tools.rs @@ -75,10 +75,10 @@ pub enum ToolArgs { k: usize, }, SearchVec { - vector: Vec, + vector: Vec, k: usize, - min_score: Option, cutoff: Option, + min_score: Option, }, } diff --git a/helix-db/src/protocol/custom_serde/compatibility_tests.rs b/helix-db/src/protocol/custom_serde/compatibility_tests.rs index 27fcd2557..303ed2d4a 100644 --- a/helix-db/src/protocol/custom_serde/compatibility_tests.rs +++ b/helix-db/src/protocol/custom_serde/compatibility_tests.rs @@ -355,8 +355,8 @@ mod compatibility_tests { let data = vec![1.0, 2.0]; // Different vector versions - let vec_v1 = create_arena_vector(&arena, id, "V1", 1, false, 0, &data, vec![]); - let vec_v2 = create_arena_vector(&arena, id, "V2", 2, false, 0, &data, vec![]); + let vec_v1 = create_arena_vector(&arena, id, "V1", 1, false, &data, vec![]); + let vec_v2 = 
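The embedding providers above still have to parse through f64 because serde_json only exposes JSON numbers as f64; the patch then narrows each element to f32. Minimal shape of that conversion, assuming the serde_json crate:

    fn embedding_from_json(v: &serde_json::Value) -> Option<Vec<f32>> {
        v.as_array()?
            .iter()
            .map(|x| x.as_f64().map(|f| f as f32))
            .collect()
    }

    fn main() {
        let v: serde_json::Value = serde_json::from_str("[0.1, 0.2, 0.3]").unwrap();
        assert_eq!(embedding_from_json(&v).unwrap().len(), 3);
    }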
create_arena_vector(&arena, id, "V2", 2, false, &data, vec![]); let props_v1 = bincode::serialize(&vec_v1).unwrap(); let props_v2 = bincode::serialize(&vec_v2).unwrap(); diff --git a/helix-db/src/protocol/custom_serde/edge_case_tests.rs b/helix-db/src/protocol/custom_serde/edge_case_tests.rs index 5428c6b9e..4a416b47e 100644 --- a/helix-db/src/protocol/custom_serde/edge_case_tests.rs +++ b/helix-db/src/protocol/custom_serde/edge_case_tests.rs @@ -79,7 +79,7 @@ mod edge_case_tests { }) .collect(); - let vector = create_arena_vector(&arena, id, "many_props", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "many_props", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -240,7 +240,7 @@ mod edge_case_tests { ("Ключ", Value::String("Значение".to_string())), ]; - let vector = create_arena_vector(&arena, id, "unicode", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "unicode", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -364,7 +364,7 @@ mod edge_case_tests { ); let props = vec![("complex", Value::Object(map))]; - let vector = create_arena_vector(&arena, id, "complex", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "complex", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -449,7 +449,7 @@ mod edge_case_tests { let data = vec![1.0]; let props = vec![("empty_obj", Value::Object(HashMap::new()))]; - let vector = create_arena_vector(&arena, id, "test", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "test", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -519,7 +519,7 @@ mod edge_case_tests { let id = 800800u128; // Subnormal (denormalized) numbers - let data = vec![f64::MIN_POSITIVE, f64::MIN_POSITIVE / 2.0, 1e-308, 1e-320]; + let data = vec![f32::MIN_POSITIVE, f32::MIN_POSITIVE / 2.0, 1e-308, 1e-320]; let vector = create_simple_vector(&arena, id, "subnormal", &data); let props_bytes = bincode::serialize(&vector).unwrap(); @@ -598,7 +598,7 @@ mod edge_case_tests { ("123", Value::String("one-two-three".to_string())), ]; - let vector = create_arena_vector(&arena, id, "numeric_keys", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "numeric_keys", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -663,7 +663,7 @@ mod edge_case_tests { ]); let props = vec![("mixed", mixed_array)]; - let vector = create_arena_vector(&arena, id, "test", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "test", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -680,7 +680,7 @@ mod edge_case_tests { fn test_vector_with_8192_dimensions() { let arena = Bump::new(); let id = 707707u128; - let data: Vec = (0..8192).map(|i| (i as f64) * 0.0001).collect(); + let data: Vec = (0..8192).map(|i| (i as f32) * 0.0001).collect(); let vector = create_simple_vector(&arena, id, "8k_dims", &data); let props_bytes = bincode::serialize(&vector).unwrap(); @@ -706,7 +706,7 @@ mod edge_case_tests { let result = 
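One caveat worth flagging in the subnormal test above: it kept the 1e-308 and 1e-320 magnitudes when moving to f32, but those are far below f32's smallest subnormal (about 1.4e-45), so they simply become 0.0 and no longer exercise the subnormal range. A quick check of the boundaries:

    fn main() {
        assert!(f32::MIN_POSITIVE / 2.0 > 0.0); // still subnormal in f32
        assert!((f32::MIN_POSITIVE / 2.0).is_subnormal());
        assert_eq!(1e-308_f64 as f32, 0.0); // underflows to zero in f32
        assert_eq!(1e-320_f64 as f32, 0.0);
    }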
HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true); assert!(result.is_ok()); let deserialized = result.unwrap(); - assert!(deserialized.data(&arena).iter().all(|&v| v == 0.0)); + assert!(deserialized.data_borrowed().iter().all(|&v| v == 0.0)); } #[test] @@ -725,7 +725,7 @@ mod edge_case_tests { let deserialized = result.unwrap(); assert!( deserialized - .data(&arena) + .data_borrowed() .iter() .all(|&v| (v - 42.42).abs() < 1e-10) ); @@ -787,7 +787,7 @@ mod edge_case_tests { fn test_vector_max_complexity() { let arena = Bump::new(); let id = u128::MAX; - let data: Vec = (0..2048).map(|i| (i as f64).sin()).collect(); + let data: Vec = (0..2048).map(|i| (i as f32).sin()).collect(); let props: Vec<(&str, Value)> = (0..200) .map(|i| { @@ -796,8 +796,7 @@ mod edge_case_tests { }) .collect(); - let vector = - create_arena_vector(&arena, id, &"Vec".repeat(200), 255, true, 0, &data, props); + let vector = create_arena_vector(&arena, id, &"Vec".repeat(200), 255, true, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs index e1bc3b411..603e8095b 100644 --- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs +++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs @@ -228,7 +228,7 @@ mod error_handling_tests { fn test_vector_cast_empty_raw_data_panics() { let arena = Bump::new(); let empty_data: &[u8] = &[]; - HVector::cast_raw_vector_data(&arena, empty_data); + HVector::raw_vector_data_to_vec(empty_data, &arena); } #[test] @@ -236,7 +236,7 @@ mod error_handling_tests { let arena = Bump::new(); let id = 666777u128; let props = vec![("key", Value::String("value".to_string()))]; - let vector = create_arena_vector(&arena, id, "test", 1, false, 0, &[1.0], props); + let vector = create_arena_vector(&arena, id, "test", 1, false, &[1.0], props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -265,7 +265,7 @@ mod error_handling_tests { let arena = Bump::new(); // 7 bytes is not a multiple of 8 (size of f64) let misaligned: &[u8] = &[0, 1, 2, 3, 4, 5, 6]; - HVector::cast_raw_vector_data(&arena, misaligned); + HVector::raw_vector_data_to_vec(&misaligned, &arena); } #[test] @@ -470,7 +470,7 @@ mod error_handling_tests { let arena = Bump::new(); let id = 012012u128; - let vector = create_arena_vector(&arena, id, "test", 255, false, 0, &[1.0], vec![]); + let vector = create_arena_vector(&arena, id, "test", 255, false, &[1.0], vec![]); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -585,7 +585,7 @@ mod error_handling_tests { let id = 987654u128; // Vector with NaN, infinity, and other special values - let data = vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0, -0.0]; + let data = vec![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0, -0.0]; let vector = create_simple_vector(&arena, id, "special", &data); let props_bytes = bincode::serialize(&vector).unwrap(); diff --git a/helix-db/src/protocol/custom_serde/integration_tests.rs b/helix-db/src/protocol/custom_serde/integration_tests.rs index db6dd19af..33a6ae312 100644 --- a/helix-db/src/protocol/custom_serde/integration_tests.rs +++ b/helix-db/src/protocol/custom_serde/integration_tests.rs @@ -266,7 +266,7 @@ mod integration_tests { ("dimensions", Value::I32(3)), ]; - let 
vector = create_arena_vector(&arena, id, "doc_vector", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "doc_vector", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -328,7 +328,7 @@ mod integration_tests { let vectors: Vec = (0..15) .map(|i| { - let data = vec![i as f64, (i + 1) as f64, (i + 2) as f64]; + let data = vec![i as f32, (i + 1) as f32, (i + 2) as f32]; create_simple_vector(&arena, i as u128, &format!("vec_{}", i), &data) }) .collect(); diff --git a/helix-db/src/protocol/custom_serde/property_based_tests.rs b/helix-db/src/protocol/custom_serde/property_based_tests.rs index cbb1bbbc1..f88495241 100644 --- a/helix-db/src/protocol/custom_serde/property_based_tests.rs +++ b/helix-db/src/protocol/custom_serde/property_based_tests.rs @@ -54,9 +54,9 @@ mod property_based_tests { } // Strategy for generating vector data - fn arb_vector_data() -> impl Strategy> { + fn arb_vector_data() -> impl Strategy> { prop::collection::vec( - any::().prop_filter("Not NaN", |f| !f.is_nan()), + any::().prop_filter("Not NaN", |f| !f.is_nan()), 1..128, // 1 to 128 dimensions ) } @@ -298,8 +298,8 @@ mod property_based_tests { prop_assert_eq!(deserialized.len(), data.len()); // Check each data point (with floating point tolerance) - for (i, (&orig, &deser)) in data.iter().zip(deserialized.data(&arena).iter()).enumerate() { - let diff = (orig - deser).abs(); + for (i, (&orig, &deser)) in data.iter().zip(deserialized.data_borrowed().iter()).enumerate() { + let diff = (orig as f64 - deser as f64).abs(); prop_assert!(diff < 1e-10, "Data mismatch at index {}: {} vs {}", i, orig, deser); } } @@ -318,7 +318,7 @@ mod property_based_tests { .map(|(k, v)| (k.as_str(), v.clone())) .collect(); - let vector = create_arena_vector(&arena, id, &label, 1, deleted, 0, &data, props_refs); + let vector = create_arena_vector(&arena, id, &label, 1, deleted, &data, props_refs); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -344,12 +344,12 @@ mod property_based_tests { // Convert to bytes and back let bytes = create_vector_bytes(&data); - let restored = HVector::cast_raw_vector_data(&arena, &bytes); + let restored = HVector::raw_vector_data_to_vec( &bytes,&arena); prop_assert_eq!(restored.len(), data.len()); for (i, (&orig, &rest)) in data.iter().zip(restored.iter()).enumerate() { - let diff = (orig - rest).abs(); + let diff = (orig as f64 - rest as f64).abs(); prop_assert!(diff < 1e-10, "Data mismatch at index {}: {} vs {}", i, orig, rest); } } diff --git a/helix-db/src/protocol/custom_serde/test_utils.rs b/helix-db/src/protocol/custom_serde/test_utils.rs index 5da31ad64..0df86e710 100644 --- a/helix-db/src/protocol/custom_serde/test_utils.rs +++ b/helix-db/src/protocol/custom_serde/test_utils.rs @@ -227,12 +227,13 @@ pub fn create_arena_vector<'arena>( label: &str, version: u8, deleted: bool, - level: usize, - data: &[f64], + data: &[f32], props: Vec<(&str, Value)>, ) -> HVector<'arena> { let label_ref = arena.alloc_str(label); - let data_ref = arena.alloc_slice_copy(data); + + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + bump_vec.extend_from_slice(data); if props.is_empty() { HVector { @@ -240,9 +241,8 @@ pub fn create_arena_vector<'arena>( label: label_ref, version, deleted, - level, distance: None, - data: Some(Item::::from(data_ref, arena)), + data: Some(Item::::new(bump_vec)), properties: None, } } else { @@ -258,9 
+258,8 @@ pub fn create_arena_vector<'arena>( label: label_ref, version, deleted, - level, distance: None, - data: Some(Item::::from(data_ref, arena)), + data: Some(Item::::new(bump_vec)), properties: Some(props_map), } } @@ -271,13 +270,13 @@ pub fn create_simple_vector<'arena>( arena: &'arena Bump, id: u128, label: &str, - data: &[f64], + data: &[f32], ) -> HVector<'arena> { - create_arena_vector(arena, id, label, 1, false, 0, data, vec![]) + create_arena_vector(arena, id, label, 1, false, data, vec![]) } /// Creates vector data as raw bytes -pub fn create_vector_bytes(data: &[f64]) -> Vec { +pub fn create_vector_bytes(data: &[f32]) -> Vec { bytemuck::cast_slice(data).to_vec() } diff --git a/helix-db/src/protocol/custom_serde/vector_serde.rs b/helix-db/src/protocol/custom_serde/vector_serde.rs index 0ab2ef6a4..ddbf9f90a 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde.rs @@ -94,16 +94,15 @@ impl<'de, 'txn, 'arena> serde::de::DeserializeSeed<'de> for VectorDeSeed<'txn, ' .next_element_seed(OptionPropertiesMapDeSeed { arena: self.arena })? .ok_or_else(|| serde::de::Error::custom("Expected properties field"))?; - let data = HVector::cast_raw_vector_data(self.arena, self.raw_vector_data); + let data = HVector::raw_vector_data_to_vec(self.raw_vector_data, self.arena); Ok(HVector { id: self.id, label, deleted, version, - level: 0, distance: None, - data: Some(Item::::from(data, &self.arena)), + data: Some(Item::::new(data)), properties, }) } @@ -169,7 +168,6 @@ impl<'de, 'arena> serde::de::DeserializeSeed<'de> for VectoWithoutDataDeSeed<'ar label, version, deleted, - level: 0, properties, distance: None, data: None, diff --git a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs index 8b35f1496..faca11621 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs @@ -51,7 +51,7 @@ mod vector_serialization_tests { let data = vec![0.5, -0.5, 1.5, -1.5]; let props = vec![("name", Value::String("test".to_string()))]; - let vector = create_arena_vector(&arena, id, "labeled_vector", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "labeled_vector", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -74,7 +74,7 @@ mod vector_serialization_tests { ("score", Value::F64(0.95)), ]; - let vector = create_arena_vector(&arena, id, "vector_label", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "vector_label", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -93,7 +93,7 @@ mod vector_serialization_tests { let data = vec![0.0; 128]; // Standard embedding dimension let props = all_value_types_props(); - let vector = create_arena_vector(&arena, id, "all_types", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "all_types", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -112,7 +112,7 @@ mod vector_serialization_tests { let data = vec![1.0, 2.0, 3.0]; let props = nested_value_props(); - let vector = create_arena_vector(&arena, id, "nested", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "nested", 1, false, &data, props); let 
props_bytes = bincode::serialize(&vector).unwrap();
        let data_bytes = vector.vector_data_to_bytes().unwrap();

@@ -137,19 +137,19 @@ mod vector_serialization_tests {
     fn test_vector_data_to_bytes_128d() {
         let arena = Bump::new();
         let id = 111111u128;
-        let data: Vec<f64> = (0..128).map(|i| i as f64 * 0.1).collect();
+        let data: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
         let vector = create_simple_vector(&arena, id, "vector_128", &data);

         let bytes = vector.vector_data_to_bytes().unwrap();
-        assert_eq!(bytes.len(), 128 * 8); // 128 dimensions * 8 bytes per f64
+        assert_eq!(bytes.len(), 128 * 4); // 128 dimensions * 4 bytes per f32
     }

     #[test]
     fn test_vector_data_to_bytes_384d() {
         let arena = Bump::new();
         let id = 222222u128;
-        let data: Vec<f64> = (0..384).map(|i| i as f64 * 0.01).collect();
+        let data: Vec<f32> = (0..384).map(|i| i as f32 * 0.01).collect();
         let vector = create_simple_vector(&arena, id, "vector_384", &data);

         let bytes = vector.vector_data_to_bytes().unwrap();
@@ -161,7 +161,7 @@
     fn test_vector_data_to_bytes_1536d() {
         let arena = Bump::new();
         let id = 333333u128;
-        let data: Vec<f64> = (0..1536).map(|i| (i as f64).sin()).collect();
+        let data: Vec<f32> = (0..1536).map(|i| (i as f32).sin()).collect();
         let vector = create_simple_vector(&arena, id, "vector_1536", &data);

         let bytes = vector.vector_data_to_bytes().unwrap();
@@ -172,10 +172,10 @@
     #[test]
     fn test_cast_raw_vector_data_128d() {
         let arena = Bump::new();
-        let original_data: Vec<f64> = (0..128).map(|i| i as f64).collect();
+        let original_data: Vec<f32> = (0..128).map(|i| i as f32).collect();
         let raw_bytes = create_vector_bytes(&original_data);

-        let casted_data = HVector::cast_raw_vector_data(&arena, &raw_bytes);
+        let casted_data = HVector::raw_vector_data_to_vec(&raw_bytes, &arena);

         assert_eq!(casted_data.len(), 128);
         for (i, &val) in casted_data.iter().enumerate() {
@@ -189,7 +189,7 @@
         let original_data = vec![3.14159, 2.71828, 1.41421, 1.73205];
         let raw_bytes = create_vector_bytes(&original_data);

-        let casted_data = HVector::cast_raw_vector_data(&arena, &raw_bytes);
+        let casted_data = HVector::raw_vector_data_to_vec(&raw_bytes, &arena);

         assert_eq!(casted_data.len(), original_data.len());
         for (orig, casted) in original_data.iter().zip(casted_data.iter()) {
@@ -212,7 +212,6 @@
         assert_eq!(vector.len(), 4);
         assert_eq!(vector.version, 1);
         assert_eq!(vector.deleted, false);
-        assert_eq!(vector.level, 0);
         assert!(vector.properties.is_none());
     }

@@ -251,7 +250,7 @@
             ("dimension", Value::I32(4)),
         ];

-        let vector = create_arena_vector(&arena, id, "embedding", 1, false, 0, &data, props);
+        let vector = create_arena_vector(&arena, id, "embedding", 1, false, &data, props);
         let props_bytes = bincode::serialize(&vector).unwrap();
         let data_bytes = vector.vector_data_to_bytes().unwrap();
@@ -296,7 +295,7 @@
         let id = 123456u128;
         let data = vec![1.0, 2.0];

-        let vector = create_arena_vector(&arena, id, "versioned", 5, false, 0, &data, vec![]);
+        let vector = create_arena_vector(&arena, id, "versioned", 5, false, &data, vec![]);
         let props_bytes = bincode::serialize(&vector).unwrap();
         let data_bytes = vector.vector_data_to_bytes().unwrap();
@@ -314,7 +313,7 @@
         let id = 654321u128;
         let data = vec![0.0, 1.0];

-        let vector = create_arena_vector(&arena, id, "deleted", 1, true, 0, &data, vec![]);
"deleted", 1, true, &data, vec![]); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -332,7 +331,7 @@ mod vector_serialization_tests { let id = 987654u128; let data = vec![1.0, 0.0]; - let vector = create_arena_vector(&arena, id, "active", 1, false, 0, &data, vec![]); + let vector = create_arena_vector(&arena, id, "active", 1, false, &data, vec![]); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -438,7 +437,7 @@ mod vector_serialization_tests { }) .collect(); - let vector = create_arena_vector(&arena, id, "many_props", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "many_props", 1, false, &data, props); let props_bytes = bincode::serialize(&vector).unwrap(); let data_bytes = vector.vector_data_to_bytes().unwrap(); @@ -477,7 +476,7 @@ mod vector_serialization_tests { fn test_vector_large_dimension_4096() { let arena = Bump::new(); let id = 951753u128; - let data: Vec = (0..4096).map(|i| i as f64 * 0.001).collect(); + let data: Vec = (0..4096).map(|i| i as f32 * 0.001).collect(); let vector = create_simple_vector(&arena, id, "4096d", &data); @@ -502,7 +501,7 @@ mod vector_serialization_tests { let data = vec![1.1, 2.2, 3.3]; let props = vec![("test", Value::String("value".to_string()))]; - let vector = create_arena_vector(&arena, id, "byte_test", 1, false, 0, &data, props); + let vector = create_arena_vector(&arena, id, "byte_test", 1, false, &data, props); // First roundtrip let props_bytes1 = bincode::serialize(&vector).unwrap(); From cf59e8d87bc58bee73f21d8e9a32a5542a76fd5f Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Wed, 19 Nov 2025 11:56:15 -0300 Subject: [PATCH 27/48] Make indexes thread-safe by saving it inside a RwLock --- helix-db/src/helix_engine/vector_core/mod.rs | 93 ++++++++++---------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index f91550137..e4982b344 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -1,4 +1,13 @@ -use std::{borrow::Cow, cmp::Ordering, hash::Hash}; +use std::{ + borrow::Cow, + cell::RefCell, + cmp::Ordering, + hash::Hash, + sync::{ + RwLock, + atomic::{self, AtomicU16, AtomicU32}, + }, +}; use bincode::Options; use byteorder::BE; @@ -257,18 +266,13 @@ pub struct VectorCore { pub vector_properties_db: Database, Bytes>, pub config: HNSWConfig, - /// Map labels to a different index - pub label_to_index: HashMap, + /// Map labels to a different index and dimension + pub label_to_index: RwLock>, /// Track the last index - curr_index: u16, - - pub id_map: HashMap, - curr_id: u32, + curr_index: AtomicU16, - /// The actual index is lazily build during the first access - // TODO: We should choose a better strategy to build the index - writer: Option>, - reader: Option>, + pub id_map: RwLock>, + curr_id: AtomicU32, } impl VectorCore { @@ -285,12 +289,10 @@ impl VectorCore { stats: VectorCoreStats { num_vectors: 0 }, vector_properties_db, config, - writer: None, - reader: None, - label_to_index: HashMap::new(), - curr_index: 0, - id_map: HashMap::new(), - curr_id: 0, + label_to_index: RwLock::new(HashMap::new()), + curr_index: AtomicU16::new(0), + id_map: RwLock::new(HashMap::new()), + curr_id: AtomicU32::new(0), }) } @@ -303,25 +305,39 @@ impl VectorCore { should_trickle: bool, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> 
+    /// Get a writer based on label. If it doesn't exist, build a new index
+    /// and return a writer to it
+    fn get_writer_or_create_index(
+        &self,
+        label: &str,
+        dimension: usize,
+        txn: &mut RwTxn,
+    ) -> VectorCoreResult<Writer<Cosine>> {
+        if let Some(&(idx, dimension)) = self.label_to_index.read().unwrap().get(label) {
+            Ok(Writer::new(self.hsnw_index, idx, dimension))
+        } else {
+            // Index does not exist yet, so build it
+            let idx = self.curr_index.fetch_add(1, atomic::Ordering::SeqCst);
+            self.label_to_index
+                .write()
+                .unwrap()
+                .insert(label.to_string(), (idx, dimension));
+            let writer = Writer::new(self.hsnw_index, idx, dimension);
+            let mut rng = StdRng::from_os_rng();
+            let mut builder = writer.builder(&mut rng);
+
+            builder
+                .ef_construction(self.config.ef_construct)
+                .build(txn)?;
+            Ok(writer)
+        }
     }

     pub fn insert<'arena>(
-        &mut self,
+        &self,
         txn: &mut RwTxn,
         label: &'arena str,
         data: &'arena [f32],
@@ -329,31 +345,18 @@ impl VectorCore {
         arena: &'arena bumpalo::Bump,
     ) -> VectorCoreResult<HVector<'arena>> {
         // index hasn't been built yet
-        if self.writer.is_none() {
-            let idx = self.get_or_create(label);
-            // assume the len of the first insertion as the
-            // index dimension
-            self.writer = Some(Writer::new(self.hsnw_index, idx, data.len()));
-            let mut rng = StdRng::from_os_rng();
-            let mut builder = self.writer.as_ref().unwrap().builder(&mut rng);
-            builder
-                .ef_construction(self.config.ef_construct)
-                .build(txn)?;
-        }
+        let writer = self.get_writer_or_create_index(label, data.len(), txn)?;

         let mut bump_vec = bumpalo::collections::Vec::new_in(arena);
         bump_vec.extend_from_slice(data);
         let hvector = HVector::from_vec(label, bump_vec);

-        self.curr_id += 1;
-        self.id_map.insert(hvector.id, self.curr_id);
+        let idx = self.curr_id.fetch_add(1, atomic::Ordering::SeqCst);
+        self.id_map.write().unwrap().insert(hvector.id, idx);

-        self.writer
-            .as_ref()
-            .unwrap()
-            .add_item(txn, self.curr_id, data);
+        writer.add_item(txn, idx, data)?;

-        todo!()
+        Ok(hvector)
     }

     pub fn delete(&self, txn: &RwTxn, id: u128, arena: &bumpalo::Bump) -> VectorCoreResult<()> {

From da81d6d6bbee1ce21818ccd582b7ef3f9bf97956 Mon Sep 17 00:00:00 2001
From: Diego Reis
Date: Wed, 19 Nov 2025 14:03:33 -0300
Subject: [PATCH 28/48] Adjust spaces

---
 helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs | 7 ++++---
 .../src/helix_engine/vector_core/spaces/simple_neon.rs | 7 ++++++-
 helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs | 7 +++----
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs
index 720b1211e..a381b2ede 100644
--- a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs
+++ b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs
@@ -120,12 +120,13 @@ pub(crate) unsafe fn dot_similarity_avx(

 #[cfg(test)]
 mod tests {
-    use crate::helix_engine::vector_core::spaces::simple::*;
+    use super::*;
+    use crate::helix_engine::vector_core::spaces::simple::{
+        dot_product_non_optimized, euclidean_distance_non_optimized,
+    };

     #[test]
     fn test_spaces_avx() {
-        use super::*;
-
         if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
             let v1: Vec<f32> = vec![
                 10., 11., 12., 13., 14.,
15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs index a176ce11e..1894fadd7 100644 --- a/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs @@ -1,5 +1,5 @@ #[cfg(target_feature = "neon")] -use crate::unaligned_vector::UnalignedVector; +use crate::helix_engine::vector_core::unaligned_vector::UnalignedVector; use std::arch::aarch64::*; use std::ptr::read_unaligned; @@ -117,6 +117,11 @@ unsafe fn unaligned_float32x4_t(ptr: *const f32) -> float32x4_t { #[cfg(test)] mod tests { + use super::*; + use crate::helix_engine::vector_core::spaces::simple::{ + dot_product_non_optimized, euclidean_distance_non_optimized, + }; + #[cfg(target_feature = "neon")] #[test] fn test_spaces_neon() { diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs index 06ad0fa09..53705f6c6 100644 --- a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs @@ -115,13 +115,12 @@ pub(crate) unsafe fn dot_similarity_sse( #[cfg(test)] mod tests { - use crate::helix_engine::vector_core::spaces::simple::{ - dot_product_non_optimized, euclidean_distance_non_optimized, - }; - #[test] fn test_spaces_sse() { use super::*; + use crate::helix_engine::vector_core::spaces::simple::{ + dot_product_non_optimized, euclidean_distance_non_optimized, + }; if is_x86_feature_detected!("sse") { let v1: Vec = vec![ From 7ca245f1d18e37fbdbf0dfd30028d9b45599d740 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Wed, 19 Nov 2025 14:39:52 -0300 Subject: [PATCH 29/48] Implement hnsw search the idea is to delay the HVector creation as much as possible, since its creation is expensive --- helix-db/src/helix_engine/bm25/bm25.rs | 6 +- helix-db/src/helix_engine/bm25/bm25_tests.rs | 6 +- helix-db/src/helix_engine/storage_core/mod.rs | 6 +- .../hnsw_concurrent_tests.rs | 43 ++++------ helix-db/src/helix_engine/tests/hnsw_tests.rs | 2 +- .../traversal_core/ops/vectors/search.rs | 20 +++-- helix-db/src/helix_engine/vector_core/hnsw.rs | 1 - helix-db/src/helix_engine/vector_core/mod.rs | 84 +++++++++++++++---- helix-db/src/helix_engine/vector_core/node.rs | 1 - .../src/helix_engine/vector_core/reader.rs | 8 ++ 10 files changed, 114 insertions(+), 63 deletions(-) diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs index 7fa58543f..0793a604f 100644 --- a/helix-db/src/helix_engine/bm25/bm25.rs +++ b/helix-db/src/helix_engine/bm25/bm25.rs @@ -426,10 +426,8 @@ impl HybridSearch for HelixGraphStorage { false, &arena, )?; - let scores = results - .into_iter() - .map(|vec| (vec.id, vec.distance.unwrap_or(0.0))) - .collect::>(); + let scores = + results.into_global_id(&self.vectors.local_to_global_id.read().unwrap()); Ok(Some(scores)) }); diff --git a/helix-db/src/helix_engine/bm25/bm25_tests.rs b/helix-db/src/helix_engine/bm25/bm25_tests.rs index 1de78117c..74aa4a456 100644 --- a/helix-db/src/helix_engine/bm25/bm25_tests.rs +++ b/helix-db/src/helix_engine/bm25/bm25_tests.rs @@ -1431,7 +1431,7 @@ mod tests { #[tokio::test] async fn test_hybrid_search() { - let (mut storage, _temp_dir) = setup_helix_storage(); + let (storage, _temp_dir) = setup_helix_storage(); let mut wtxn = storage.graph_env.write_txn().unwrap(); let docs = vec![ @@ -1474,7 +1474,7 @@ mod tests { 
#[tokio::test] async fn test_hybrid_search_alpha_vectors() { - let (mut storage, _temp_dir) = setup_helix_storage(); + let (storage, _temp_dir) = setup_helix_storage(); // Insert some test documents first let mut wtxn = storage.graph_env.write_txn().unwrap(); @@ -1520,7 +1520,7 @@ mod tests { #[tokio::test] async fn test_hybrid_search_alpha_bm25() { - let (mut storage, _temp_dir) = setup_helix_storage(); + let (storage, _temp_dir) = setup_helix_storage(); // Insert some test documents first let mut wtxn = storage.graph_env.write_txn().unwrap(); diff --git a/helix-db/src/helix_engine/storage_core/mod.rs b/helix-db/src/helix_engine/storage_core/mod.rs index ba1229e22..d70b88888 100644 --- a/helix-db/src/helix_engine/storage_core/mod.rs +++ b/helix-db/src/helix_engine/storage_core/mod.rs @@ -462,7 +462,6 @@ impl StorageMethods for HelixGraphStorage { } fn drop_vector(&self, txn: &mut RwTxn, id: &u128) -> Result<(), GraphError> { - let arena = bumpalo::Bump::new(); let mut edges = HashSet::new(); let mut out_edges = HashSet::new(); let mut in_edges = HashSet::new(); @@ -499,9 +498,6 @@ impl StorageMethods for HelixGraphStorage { other_out_edges.push((from_node_id, label, edge_id)); } - // println!("In edges: {}", in_edges.len()); - - // println!("Deleting edges: {}", ); // Delete all related data for edge in edges { self.edges_db.delete(txn, Self::edge_key(&edge))?; @@ -531,7 +527,7 @@ impl StorageMethods for HelixGraphStorage { } // Delete vector data - self.vectors.delete(txn, *id, &arena)?; + self.vectors.delete(txn, *id)?; Ok(()) } diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs index 5eb75b7cd..64f64f711 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs @@ -98,7 +98,7 @@ fn test_concurrent_inserts_single_label() { let vector = random_vector(128); // Open the existing databases and insert - let mut index = open_vector_core(&env, &mut wtxn).unwrap(); + let index = open_vector_core(&env, &mut wtxn).unwrap(); index .insert( &mut wtxn, @@ -206,15 +206,11 @@ fn test_concurrent_searches_during_inserts() { match index.search(&rtxn, query.to_vec(), 10, "search_test", false, &arena) { Ok(results) => { total_searches += 1; - total_results += results.len(); + total_results += results.nns.len(); // Validate result consistency - for (i, result) in results.iter().enumerate() { - assert!( - result.distance.is_some(), - "Result {} should have distance", - i - ); + for (i, &(_, distance)) in results.into_nns().iter().enumerate() { + assert!(distance > 0_f32, "Result {} should have distance", i); } } Err(e) => { @@ -285,7 +281,7 @@ fn test_concurrent_searches_during_inserts() { .search(&rtxn, query.to_vec(), 10, "search_test", false, &arena) .unwrap(); assert!( - !results.is_empty(), + !results.nns.is_empty(), "Should find results after concurrent operations" ); } @@ -443,18 +439,18 @@ fn test_entry_point_consistency() { let results = search_result.unwrap(); assert!( - !results.is_empty(), + !results.nns.is_empty(), "Should return results if entry point is valid" ); // Verify results have valid properties - for result in results.iter() { - assert!(result.id > 0, "Result ID should be valid"); - assert!(!result.deleted, "Results should not be deleted"); - assert!( - !result.data_borrowed().is_empty(), - "Results should have data" - ); + for &(id, distance) in 
results.into_nns().iter() { + // assert!(result.id > 0, "Result ID should be valid"); + // assert!(!result.deleted, "Results should not be deleted"); + // assert!( + // !result.data_borrowed().is_empty(), + // "Results should have data" + // ); } } @@ -529,17 +525,14 @@ fn test_graph_connectivity_after_concurrent_inserts() { .unwrap(); assert!( - !results.is_empty(), + !results.nns.is_empty(), "Query {} should return results (graph should be connected)", i ); // All results should have valid distances - for result in results { - assert!( - result.distance.is_some() && result.distance.unwrap() >= 0.0, - "Result should have valid distance" - ); + for &(_, distance) in results.into_nns().iter() { + assert!(distance >= 0.0, "Result should have valid distance"); } } } @@ -557,7 +550,7 @@ fn test_transaction_isolation() { let initial_count = 10; { let mut txn = env.write_txn().unwrap(); - let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let arena = Bump::new(); for _ in 0..initial_count { @@ -591,7 +584,7 @@ fn test_transaction_isolation() { let handle = thread::spawn(move || { for _ in 0..20 { let mut wtxn = env_clone.write_txn().unwrap(); - let mut index = open_vector_core(&env_clone, &mut wtxn).unwrap(); + let index = open_vector_core(&env_clone, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(32); diff --git a/helix-db/src/helix_engine/tests/hnsw_tests.rs b/helix-db/src/helix_engine/tests/hnsw_tests.rs index 0c7c460fd..61c7c28a0 100644 --- a/helix-db/src/helix_engine/tests/hnsw_tests.rs +++ b/helix-db/src/helix_engine/tests/hnsw_tests.rs @@ -61,5 +61,5 @@ fn test_hnsw_search_returns_results() { let results = index .search(&txn, query.to_vec(), 5, "vector", false, &arena) .unwrap(); - assert!(!results.is_empty()); + assert!(!results.nns.is_empty()); } diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs index 61f1ed896..8466d4abb 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs @@ -5,7 +5,7 @@ use crate::helix_engine::{ types::{GraphError, VectorError}, vector_core::HVector, }; -use std::iter::once; +use std::{iter::once, vec}; pub trait SearchVAdapter<'db, 'arena, 'txn>: Iterator, GraphError>> @@ -58,12 +58,18 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE ); let iter = match vectors { - Ok(vectors) => vectors - .into_iter() - // copying here! 
- .map(|vector| Ok::(TraversalValue::Vector(vector))) - .collect::>() - .into_iter(), + Ok(vectors) => { + let hvectors = + self.storage + .vectors + .nns_to_hvectors(vectors.into_nns(), false, self.arena); + + hvectors + .into_iter() + .map(|vector| Ok::(TraversalValue::Vector(vector))) + .collect::>() + .into_iter() + } Err(VectorError::VectorNotFound(id)) => { let error = GraphError::VectorError(format!("vector not found for id {id}")); once(Err(error)).collect::>().into_iter() diff --git a/helix-db/src/helix_engine/vector_core/hnsw.rs b/helix-db/src/helix_engine/vector_core/hnsw.rs index b9b64b0f5..defec54ca 100644 --- a/helix-db/src/helix_engine/vector_core/hnsw.rs +++ b/helix-db/src/helix_engine/vector_core/hnsw.rs @@ -12,7 +12,6 @@ use rand::distr::Distribution; use rand::distr::weighted::WeightedIndex; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use roaring::RoaringBitmap; -use tinyvec::{ArrayVec, array_vec}; use crate::helix_engine::vector_core::node::{Item, Node}; use crate::helix_engine::vector_core::{ diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index e4982b344..1e589910e 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -5,7 +5,7 @@ use std::{ hash::Hash, sync::{ RwLock, - atomic::{self, AtomicU16, AtomicU32}, + atomic::{self, AtomicU16, AtomicU32, AtomicUsize}, }, }; @@ -27,7 +27,7 @@ use crate::{ key::{Key, KeyCodec}, node::{Item, NodeCodec}, node_id::NodeMode, - reader::Reader, + reader::{Reader, Searched}, unaligned_vector::UnalignedVector, writer::Writer, }, @@ -254,8 +254,7 @@ impl HNSWConfig { } pub struct VectorCoreStats { - // Do it an atomic? - pub num_vectors: usize, + pub num_vectors: AtomicUsize, } // TODO: Properties filters @@ -271,7 +270,9 @@ pub struct VectorCore { /// Track the last index curr_index: AtomicU16, - pub id_map: RwLock>, + /// Maps global id (u128) to internal id (u32) and label + pub global_to_local_id: RwLock>, + pub local_to_global_id: RwLock>, curr_id: AtomicU32, } @@ -286,13 +287,16 @@ impl VectorCore { Ok(Self { hsnw_index: vectors_db, - stats: VectorCoreStats { num_vectors: 0 }, + stats: VectorCoreStats { + num_vectors: AtomicUsize::new(0), + }, vector_properties_db, config, label_to_index: RwLock::new(HashMap::new()), curr_index: AtomicU16::new(0), - id_map: RwLock::new(HashMap::new()), + global_to_local_id: RwLock::new(HashMap::new()), curr_id: AtomicU32::new(0), + local_to_global_id: RwLock::new(HashMap::new()), }) } @@ -302,10 +306,20 @@ impl VectorCore { query: Vec, k: usize, label: &'arena str, - should_trickle: bool, + _should_trickle: bool, arena: &'arena bumpalo::Bump, - ) -> VectorCoreResult>> { - todo!() + ) -> VectorCoreResult> { + match self.label_to_index.read().unwrap().get(label) { + Some(&(index, dimension)) => { + if dimension != query.len() { + return Err(VectorError::InvalidVectorLength); + } + + let reader = Reader::open(txn, index, self.hsnw_index)?; + reader.nns(k).by_vector(txn, query.as_slice(), arena) + } + None => Ok(Searched::new(bumpalo::vec![in &arena])), + } } /// Get a writer based on label. 
If it doesn't exist build a new index @@ -352,15 +366,53 @@ impl VectorCore { let hvector = HVector::from_vec(label, bump_vec); let idx = self.curr_id.fetch_add(1, atomic::Ordering::SeqCst); - self.id_map.write().unwrap().insert(hvector.id, idx); - - writer.add_item(txn, idx, data); + self.global_to_local_id + .write() + .unwrap() + .insert(hvector.id, (idx, label.to_string())); + self.local_to_global_id + .write() + .unwrap() + .insert(idx, hvector.id); + self.stats + .num_vectors + .fetch_add(1, atomic::Ordering::SeqCst); + + writer.add_item(txn, idx, data)?; Ok(hvector) } - pub fn delete(&self, txn: &RwTxn, id: u128, arena: &bumpalo::Bump) -> VectorCoreResult<()> { - Ok(()) + pub fn delete(&self, txn: &mut RwTxn, id: u128) -> VectorCoreResult<()> { + match self.global_to_local_id.read().unwrap().get(&id) { + Some(&(idx, ref label)) => { + let &(index, dimension) = self + .label_to_index + .read() + .unwrap() + .get(label) + .expect("if index exist label should also exist"); + let writer = Writer::new(self.hsnw_index, index, dimension); + writer.del_item(txn, idx)?; + self.stats + .num_vectors + .fetch_add(1, atomic::Ordering::SeqCst); + Ok(()) + } + None => Err(VectorError::VectorNotFound(format!( + "vector {} doesn't exist", + id + ))), + } + } + + pub fn nns_to_hvectors<'arena>( + &self, + nns: bumpalo::collections::Vec<'arena, (ItemId, f32)>, + with_data: bool, + arena: &bumpalo::Bump, + ) -> bumpalo::collections::Vec<'arena, HVector<'arena>> { + todo!() } pub fn get_full_vector<'arena>( @@ -382,6 +434,6 @@ impl VectorCore { } pub fn num_inserted_vectors(&self) -> usize { - self.stats.num_vectors + self.stats.num_vectors.load(atomic::Ordering::SeqCst) } } diff --git a/helix-db/src/helix_engine/vector_core/node.rs b/helix-db/src/helix_engine/vector_core/node.rs index b0e6164b5..1c2de31d7 100644 --- a/helix-db/src/helix_engine/vector_core/node.rs +++ b/helix-db/src/helix_engine/vector_core/node.rs @@ -1,7 +1,6 @@ use core::fmt; use std::{borrow::Cow, ops::Deref}; -use bumpalo::collections::CollectIn; use bytemuck::{bytes_of, cast_slice, pod_read_unaligned}; use byteorder::{ByteOrder, NativeEndian}; use heed3::{BoxedError, BytesDecode, BytesEncode}; diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs index dad45417a..6320a8dfd 100644 --- a/helix-db/src/helix_engine/vector_core/reader.rs +++ b/helix-db/src/helix_engine/vector_core/reader.rs @@ -4,6 +4,7 @@ use std::marker; use std::num::NonZeroUsize; use bumpalo::collections::CollectIn; +use hashbrown::HashMap; use heed3::RoTxn; use heed3::types::Bytes; use heed3::types::DecodeIgnore; @@ -58,6 +59,13 @@ impl<'arena> Searched<'arena> { pub fn into_nns(self) -> bumpalo::collections::Vec<'arena, (ItemId, f32)> { self.nns } + + pub fn into_global_id(&self, map: &HashMap) -> Vec<(u128, f32)> { + self.nns + .iter() + .map(|(item_id, score)| (*map.get(item_id).unwrap(), *score)) + .collect() + } } /// Options used to make a query against an hannoy [`Reader`]. 
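[Editorial sketch, not part of the patch series.] Taken together, PATCH 29 and the following PATCH 30 split the search path in two: search now returns lightweight (ItemId, f32) pairs wrapped in a Searched value, and full HVector values are only materialized when a caller actually needs them, since (per the commit message) HVector creation is expensive. A minimal call-site sketch of that flow, assuming the VectorCore::search, Searched::into_global_id, and Searched::into_nns signatures shown above and the nns_to_hvectors signature introduced in PATCH 30 below; the env, vectors, query, and "docs" label bindings are hypothetical:

    // Search only returns internal (ItemId, f32) pairs; no HVector is built yet.
    let arena = bumpalo::Bump::new();
    let rtxn = env.read_txn()?;
    let searched = vectors.search(&rtxn, query.to_vec(), 10, "docs", false, &arena)?;

    // Cheap path (used by hybrid search): map internal ids back to global u128
    // ids and keep just the scores. into_global_id borrows, so `searched` is
    // still usable afterwards.
    let scores = searched.into_global_id(&vectors.local_to_global_id.read().unwrap());

    // Expensive path (used by the traversal ops): materialize HVectors,
    // optionally loading the raw vector data (with_data = false here).
    let hvectors = vectors.nns_to_hvectors(&rtxn, searched.into_nns(), false, &arena);

The hot search loop thus never allocates an HVector; the id-map and LMDB item lookups are deferred to the consumers that need them.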
From 55ac1199597586a8be5931289cf57b961da5bc34 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Fri, 21 Nov 2025 14:51:00 -0300 Subject: [PATCH 30/48] Implement nns_to_hvector --- .../traversal_tests/vector_traversal_tests.rs | 1 - .../traversal_core/ops/vectors/search.rs | 14 +++-- helix-db/src/helix_engine/vector_core/mod.rs | 63 +++++++++++++++++-- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs index a572dc4d3..2101c3abd 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs @@ -176,7 +176,6 @@ fn test_drop_vector_removes_edges() { .search_v::(&[0.5, 0.5, 0.5], 10, "vector", None) .collect::, _>>() .unwrap(); - drop(txn); let mut txn = storage.graph_env.write_txn().unwrap(); Drop::drop_traversal( diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs index 8466d4abb..ce82d6b8f 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs @@ -25,7 +25,8 @@ pub trait SearchVAdapter<'db, 'arena, 'txn>: where F: Fn(&HVector, &RoTxn) -> bool, K: TryInto, - K::Error: std::fmt::Debug; + K::Error: std::fmt::Debug, + 'txn: 'arena; } impl<'db, 'arena, 'txn, I: Iterator, GraphError>>> @@ -47,6 +48,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE F: Fn(&HVector, &RoTxn) -> bool, K: TryInto, K::Error: std::fmt::Debug, + 'txn: 'arena, { let vectors = self.storage.vectors.search( self.txn, @@ -59,10 +61,12 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE let iter = match vectors { Ok(vectors) => { - let hvectors = - self.storage - .vectors - .nns_to_hvectors(vectors.into_nns(), false, self.arena); + let hvectors = self.storage.vectors.nns_to_hvectors( + self.txn, + vectors.into_nns(), + false, + self.arena, + ); hvectors .into_iter() diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 1e589910e..bf49a8a42 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -3,6 +3,7 @@ use std::{ cell::RefCell, cmp::Ordering, hash::Hash, + io::Read, sync::{ RwLock, atomic::{self, AtomicU16, AtomicU32, AtomicUsize}, @@ -27,7 +28,7 @@ use crate::{ key::{Key, KeyCodec}, node::{Item, NodeCodec}, node_id::NodeMode, - reader::{Reader, Searched}, + reader::{Reader, Searched, get_item}, unaligned_vector::UnalignedVector, writer::Writer, }, @@ -72,6 +73,7 @@ pub type CoreDatabase = heed3::Database>; pub struct HVector<'arena> { pub id: u128, pub distance: Option, + // TODO: String Interning pub label: &'arena str, pub deleted: bool, pub version: u8, @@ -88,8 +90,8 @@ impl<'arena> HVector<'arena> { let id = v6_uuid(); HVector { id, - version: 1, label, + version: 1, data: Some(Item::::new(data)), distance: None, properties: None, @@ -406,13 +408,62 @@ impl VectorCore { } } - pub fn nns_to_hvectors<'arena>( + pub fn nns_to_hvectors<'arena, 'txn>( &self, + txn: &'txn RoTxn, nns: bumpalo::collections::Vec<'arena, (ItemId, f32)>, with_data: bool, - arena: &bumpalo::Bump, - ) -> bumpalo::collections::Vec<'arena, HVector<'arena>> { - todo!() + arena: &'arena bumpalo::Bump, + ) -> bumpalo::collections::Vec<'arena, HVector<'arena>> + where + 'txn: 'arena, + { + let mut results = 
bumpalo::collections::Vec::<'arena, HVector<'arena>>::with_capacity_in( + nns.len(), + arena, + ); + + let local_to_global_id = self.local_to_global_id.read().unwrap(); + let label_to_index = self.label_to_index.read().unwrap(); + let global_to_local_id = self.global_to_local_id.read().unwrap(); + + let (item_id, _) = nns.first().unwrap(); + let global_id = local_to_global_id.get(item_id).unwrap(); + let (_, label) = global_to_local_id.get(global_id).unwrap(); + let (index, _) = label_to_index.get(label).unwrap(); + let label = arena.alloc_str(&label); + + if with_data { + for (item_id, distance) in nns.into_iter() { + let global_id = local_to_global_id.get(&item_id).unwrap(); + + results.push(HVector { + id: *global_id, + distance: Some(distance), + label, + deleted: false, + version: 0, + properties: None, + data: get_item(self.hsnw_index, *index, txn, item_id).unwrap(), + }); + } + } else { + for (item_id, distance) in nns.into_iter() { + let global_id = local_to_global_id.get(&item_id).unwrap(); + + results.push(HVector { + id: *global_id, + distance: Some(distance), + label, + deleted: false, + version: 0, + properties: None, + data: None, + }); + } + } + + results } pub fn get_full_vector<'arena>( From d54a2368f94ebe05e01b55c65444e5a0e99866fa Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Fri, 21 Nov 2025 18:13:31 -0300 Subject: [PATCH 31/48] Properly implement get_full_vector Still room for improvement, since it does two allocations per vector :/ --- helix-db/src/helix_engine/vector_core/mod.rs | 21 +++++++++++++++++-- helix-db/src/helix_engine/vector_core/node.rs | 13 ++++++++++++ .../src/helix_engine/vector_core/reader.rs | 6 +++--- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index bf49a8a42..6f73acac1 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -470,9 +470,26 @@ impl VectorCore { &self, txn: &RoTxn, id: u128, - arena: &bumpalo::Bump, + arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { - todo!() + let label_to_index = self.label_to_index.read().unwrap(); + let global_to_local_id = self.global_to_local_id.read().unwrap(); + + let (item_id, label) = global_to_local_id.get(&id).unwrap(); + let (index, _) = label_to_index.get(label).unwrap(); + let label = arena.alloc_str(&label); + + let item = get_item(self.hsnw_index, *index, txn, *item_id)?.map(|i| i.clone_in(arena)); + + Ok(HVector { + id: id, + distance: None, + label, + deleted: false, + version: 0, + properties: None, + data: item.clone(), + }) } pub fn get_vector_properties<'arena>( diff --git a/helix-db/src/helix_engine/vector_core/node.rs b/helix-db/src/helix_engine/vector_core/node.rs index 1c2de31d7..20c3bd66d 100644 --- a/helix-db/src/helix_engine/vector_core/node.rs +++ b/helix-db/src/helix_engine/vector_core/node.rs @@ -76,6 +76,19 @@ impl Item<'_, D> { } } + /// Clones the item into the provided arena, returning a new Item + /// with a lifetime tied to the arena. + pub fn clone_in<'bump>(&self, arena: &'bump bumpalo::Bump) -> Item<'bump, D> { + // TODO: This does two allocations, we should do only one! + let vec_data = self.vector.to_vec(arena); + let vector = UnalignedVector::from_vec(vec_data); + + Item { + header: self.header, + vector, + } + } + /// Builds a new item from a `Vec`.
pub fn new(vec: bumpalo::collections::Vec) -> Self { let vector = UnalignedVector::from_vec(vec); diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs index 6320a8dfd..a6fdc287d 100644 --- a/helix-db/src/helix_engine/vector_core/reader.rs +++ b/helix-db/src/helix_engine/vector_core/reader.rs @@ -734,12 +734,12 @@ impl Reader { } } -pub fn get_item<'a, D: Distance>( +pub fn get_item<'txn, D: Distance>( database: CoreDatabase, index: u16, - rtxn: &'a RoTxn, + rtxn: &'txn RoTxn, item: ItemId, -) -> VectorCoreResult>> { +) -> VectorCoreResult>> { match database.get(rtxn, &Key::item(index, item))? { Some(Node::Item(item)) => Ok(Some(item)), Some(Node::Links(_)) => Ok(None), From 8b09c8ce8c5801461ce4aed24b90c65a420364ad Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Fri, 21 Nov 2025 21:29:22 -0300 Subject: [PATCH 32/48] Fix (some) clippy complaints --- helix-db/Cargo.toml | 6 +- helix-db/benches/bm25_benches.rs | 8 +- helix-db/benches/hnsw_benches.rs | 127 ++---------------- helix-db/src/helix_engine/bm25/bm25.rs | 2 +- .../hnsw_concurrent_tests.rs | 12 +- .../concurrency_tests/hnsw_loom_tests.rs | 4 +- .../traversal_concurrent_tests.rs | 2 +- helix-db/src/helix_engine/tests/hnsw_tests.rs | 4 +- .../tests/traversal_tests/count_tests.rs | 6 +- .../traversal_tests/node_traversal_tests.rs | 16 +-- .../traversal_core/ops/bm25/search_bm25.rs | 2 +- .../traversal_core/ops/source/v_from_type.rs | 10 +- .../ops/vectors/brute_force_search.rs | 10 +- .../traversal_core/ops/vectors/search.rs | 4 +- .../traversal_core/traversal_iter.rs | 12 +- .../src/helix_engine/vector_core/item_iter.rs | 4 +- .../src/helix_engine/vector_core/metadata.rs | 2 +- helix-db/src/helix_engine/vector_core/mod.rs | 48 ++++--- helix-db/src/helix_engine/vector_core/node.rs | 4 + .../src/helix_engine/vector_core/reader.rs | 17 ++- .../vector_core/spaces/simple_avx.rs | 8 +- .../vector_core/spaces/simple_sse.rs | 8 +- .../src/helix_engine/vector_core/stats.rs | 8 +- .../vector_core/unaligned_vector/f32.rs | 2 +- .../vector_core/unaligned_vector/mod.rs | 2 +- .../src/helix_engine/vector_core/writer.rs | 16 +-- .../builtin/all_nodes_and_edges.rs | 70 ++++++---- .../src/helix_gateway/builtin/node_by_id.rs | 8 +- .../helix_gateway/builtin/node_connections.rs | 1 - .../helix_gateway/builtin/nodes_by_label.rs | 6 +- .../src/helix_gateway/tests/gateway_tests.rs | 2 +- helix-db/src/helix_gateway/tests/mcp_tests.rs | 6 +- .../custom_serde/compatibility_tests.rs | 4 +- .../protocol/custom_serde/edge_case_tests.rs | 2 +- .../custom_serde/error_handling_tests.rs | 2 +- .../src/protocol/custom_serde/test_utils.rs | 2 + helix-db/src/protocol/custom_serde/tests.rs | 16 +-- .../src/protocol/custom_serde/vector_serde.rs | 2 + .../custom_serde/vector_serde_tests.rs | 6 +- helix-db/src/protocol/request.rs | 2 +- helix-db/src/protocol/value.rs | 8 +- helix-db/src/utils/id.rs | 4 +- helix-db/src/utils/label_hash.rs | 6 +- metrics/src/lib.rs | 2 +- 44 files changed, 207 insertions(+), 286 deletions(-) diff --git a/helix-db/Cargo.toml b/helix-db/Cargo.toml index 0957d8e53..37907ae11 100644 --- a/helix-db/Cargo.toml +++ b/helix-db/Cargo.toml @@ -52,11 +52,7 @@ tracing = "0.1.41" core_affinity = "0.8.3" async-trait = "0.1.88" thiserror = "2.0.12" -polars = { version = "0.46.0", features = [ - "parquet", - "lazy", - "json", -], optional = true } +polars = { version = "0.46.0", features = ["parquet", "lazy", "json"], optional = true } subtle = "2.6.1" sha_256 = "=0.1.1" byteorder = "1.5.0" diff --git 
a/helix-db/benches/bm25_benches.rs b/helix-db/benches/bm25_benches.rs index 2d39dfa9d..d6b86ad64 100644 --- a/helix-db/benches/bm25_benches.rs +++ b/helix-db/benches/bm25_benches.rs @@ -51,7 +51,7 @@ mod tests { let mut rng = rand::rng(); let mut docs = vec![]; - let relevant_count = 4000 as usize; + let relevant_count = 4000_usize; let total_docs = 1_000_000; for i in tqdm::new( @@ -126,7 +126,7 @@ mod tests { let id = v6_uuid(); let doc_lower = doc.to_lowercase(); - let _ = bm25.insert_doc(&mut wtxn, id, &doc_lower).unwrap(); + bm25.insert_doc(&mut wtxn, id, &doc_lower).unwrap(); for term in &query_terms { if doc_lower.contains(term) { @@ -139,7 +139,7 @@ mod tests { for query_term in query_terms { let rtxn = bm25.graph_env.read_txn().unwrap(); - let term_count = query_term_counts.get(query_term).unwrap().clone(); + let term_count = *query_term_counts.get(query_term).unwrap(); let results = bm25.search(&rtxn, query_term, limit).unwrap(); @@ -148,7 +148,7 @@ mod tests { debug_println!("term count: {}, results len: {}", term_count, results.len()); assert!( - precision >= 0.9 && precision <= 1.0, + (0.9..=1.0).contains(&precision), "precision {} below 0.9 or above 1.0", precision ); diff --git a/helix-db/benches/hnsw_benches.rs b/helix-db/benches/hnsw_benches.rs index 62667b7a4..c274ffb88 100644 --- a/helix-db/benches/hnsw_benches.rs +++ b/helix-db/benches/hnsw_benches.rs @@ -3,11 +3,7 @@ mod tests { use heed3::{Env, EnvOpenOptions, RoTxn}; use helix_db::{ - helix_engine::vector_core::{ - hnsw::HNSW, - unaligned_vector::HVector, - vector_core::{HNSWConfig, VectorCore}, - }, + helix_engine::vector_core::{HNSWConfig, HVector, VectorCore}, utils::tqdm::tqdm, }; use polars::prelude::*; @@ -57,12 +53,14 @@ mod tests { /// Returns query ids and their associated closest k vectors (by vec id) fn calc_ground_truths( base_vectors: Vec, - query_vectors: &Vec<(usize, Vec)>, + query_vectors: &Vec<(usize, Vec)>, k: usize, ) -> HashMap> { let base_vectors = Arc::new(base_vectors); let results = Arc::new(Mutex::new(HashMap::new())); let chunk_size = (query_vectors.len() + num_cpus::get() - 1) / num_cpus::get(); + let arena = bumpalo::Bump::new(); + let label = arena.alloc_str("test"); let handles: Vec<_> = query_vectors .chunks(chunk_size) @@ -75,14 +73,16 @@ mod tests { let local_results: HashMap> = chunk .into_iter() .map(|(query_id, query_vec)| { - let query_hvector = HVector::from_slice(0, query_vec); + let mut vecs = bumpalo::collections::Vec::new_in(&arena); + vecs.extend_from_slice(query_vec.as_slice()); + let query_hvector = HVector::from_vec(&label, vecs); let mut distances: Vec<(u128, f64)> = base_vectors .iter() .filter_map(|base_vec| { query_hvector .distance_to(base_vec) - .map(|dist| (base_vec.id.clone(), dist)) + .map(|dist| (base_vec.id, dist)) .ok() }) .collect(); @@ -179,109 +179,6 @@ mod tests { vectors } - /* - #[test] - fn bench_hnsw_search_short() { - //fetch_parquet_vectors().unwrap(); - let n_base = 4_000; - let dims = 950; - let vectors = gen_sim_vecs(n_base, dims, 0.8); - - let n_query = 400; - let mut rng = rand::rng(); - let mut shuffled_vectors = vectors.clone(); - shuffled_vectors.shuffle(&mut rng); - let base_vectors = &shuffled_vectors[..n_base - n_query]; - let query_vectors = &shuffled_vectors[n_base - n_query..]; - - println!("num of base vecs: {}", base_vectors.len()); - println!("num of query vecs: {}", query_vectors.len()); - - let k = 10; - - let env = setup_temp_env(); - let mut txn = env.write_txn().unwrap(); - - let mut total_insertion_time = 
std::time::Duration::from_secs(0); - let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); - - let mut all_vectors: Vec = Vec::new(); - let over_all_time = Instant::now(); - for (i, data) in vectors.iter().enumerate() { - let start_time = Instant::now(); - let vec = index.insert::(&mut txn, &data, None).unwrap(); - let time = start_time.elapsed(); - all_vectors.push(vec); - if i % 1000 == 0 { - println!("{} => inserting in {} ms", i, time.as_millis()); - println!("time taken so far: {:?}", over_all_time.elapsed()); - } - total_insertion_time += time; - } - txn.commit().unwrap(); - - let txn = env.read_txn().unwrap(); - println!("{:?}", index.config); - - println!( - "total insertion time: {:.2?} seconds", - total_insertion_time.as_secs_f64() - ); - println!( - "average insertion time per vec: {:.2?} milliseconds", - total_insertion_time.as_millis() as f64 / n_base as f64 - ); - - println!("calculating ground truths"); - let ground_truths = calc_ground_truths(all_vectors, query_vectors.to_vec(), k); - - println!("searching and comparing..."); - let test_id = format!("k = {} with {} queries", k, n_query); - - let mut total_recall = 0.0; - let mut total_precision = 0.0; - let mut total_search_time = std::time::Duration::from_secs(0); - for ((_, query), gt) in query_vectors.iter().zip(ground_truths.iter()) { - let start_time = Instant::now(); - let results = index.search::(&txn, query, k, None, false).unwrap(); - let search_duration = start_time.elapsed(); - total_search_time += search_duration; - - let result_indices: HashSet = results - .into_iter() - .map(|hvector| hvector.get_id().to_string()) - .collect(); - - let gt_indices: HashSet = gt.iter().cloned().collect(); - //println!("gt: {:?}\nresults: {:?}\n", gt_indices, result_indices); - let true_positives = result_indices.intersection(>_indices).count(); - - let recall: f64 = true_positives as f64 / gt_indices.len() as f64; - let precision: f64 = true_positives as f64 / result_indices.len() as f64; - - total_recall += recall; - total_precision += precision; - } - - println!( - "total search time: {:.2?} seconds", - total_search_time.as_secs_f64() - ); - println!( - "average search time per query: {:.2?} milliseconds", - total_search_time.as_millis() as f64 / n_query as f64 - ); - - total_recall = total_recall / n_query as f64; - total_precision = total_precision / n_query as f64; - println!( - "{}: avg. recall: {:.4?}, avg. 
precision: {:.4?}", - test_id, total_recall, total_precision - ); - assert!(total_recall >= 0.8, "recall not high enough!"); - } - */ - /// Test the precision of the HNSW search algorithm #[test] fn bench_hnsw_search_long() { @@ -289,6 +186,8 @@ mod tests { let n_query = 1000; // 10-20% let k = 10; let mut vectors = load_dbpedia_vectors(n_base).unwrap(); + let arena = bumpalo::Bump::new(); + let label = arena.alloc_str("test"); let mut rng = rand::rng(); vectors.shuffle(&mut rng); @@ -299,7 +198,7 @@ mod tests { .iter() .enumerate() .map(|(i, x)| (i + 1, x.clone())) - .collect::)>>(); + .collect::)>>(); println!("num of base vecs: {}", base_vectors.len()); println!("num of query vecs: {}", query_vectors.len()); @@ -313,7 +212,7 @@ mod tests { let over_all_time = Instant::now(); for (i, data) in base_vectors.iter().enumerate() { let start_time = Instant::now(); - let vec = index.insert::(&mut txn, &data, None).unwrap(); + let vec = index.insert(&mut txn, label, &data, None, &arena).unwrap(); let time = start_time.elapsed(); base_all_vectors.push(vec); //println!("{} => inserting in {} ms", i, time.as_millis()); @@ -349,7 +248,7 @@ mod tests { for (qid, query) in query_vectors.iter() { let start_time = Instant::now(); let results = index - .search::(&txn, query, k, "vector", None, false) + .search(&txn, query, k, "vector", false, &arena) .unwrap(); let search_duration = start_time.elapsed(); total_search_time += search_duration; diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs index 0793a604f..f3145a216 100644 --- a/helix-db/src/helix_engine/bm25/bm25.rs +++ b/helix-db/src/helix_engine/bm25/bm25.rs @@ -445,7 +445,7 @@ impl HybridSearch for HelixGraphStorage { // correct_score = alpha * bm25_score + (1.0 - alpha) * vector_score if let Some(vector_results) = vector_results? 
{ for (doc_id, score) in vector_results { - let similarity = (1.0 / (1.0 + score)) as f32; + let similarity = 1.0 / (1.0 + score) ; combined_scores .entry(doc_id) .and_modify(|existing_score| *existing_score += (1.0 - alpha) * similarity) diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs index 64f64f711..9a05cb434 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs @@ -161,7 +161,7 @@ fn test_concurrent_searches_during_inserts() { // Initialize with some initial vectors { let mut txn = env.write_txn().unwrap(); - let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let arena = Bump::new(); for _ in 0..50 { @@ -246,7 +246,7 @@ fn test_concurrent_searches_during_inserts() { let vector = random_vector(128); let data = arena.alloc_slice_copy(&vector); - let mut index = open_vector_core(&env, &mut wtxn).unwrap(); + let index = open_vector_core(&env, &mut wtxn).unwrap(); index .insert(&mut wtxn, "search_test", data, None, &arena) .expect("Insert should succeed"); @@ -318,7 +318,7 @@ fn test_concurrent_inserts_multiple_labels() { for i in 0..vectors_per_label { let mut wtxn = env.write_txn().unwrap(); - let mut index = open_vector_core(&env, &mut wtxn).unwrap(); + let index = open_vector_core(&env, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(64); @@ -403,7 +403,7 @@ fn test_entry_point_consistency() { for _ in 0..vectors_per_thread { let mut wtxn = env.write_txn().unwrap(); - let mut index = open_vector_core(&env, &mut wtxn).unwrap(); + let index = open_vector_core(&env, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(32); @@ -444,7 +444,7 @@ fn test_entry_point_consistency() { ); // Verify results have valid properties - for &(id, distance) in results.into_nns().iter() { + for &(_id, _distance) in results.into_nns().iter() { // assert!(result.id > 0, "Result ID should be valid"); // assert!(!result.deleted, "Results should not be deleted"); // assert!( @@ -484,7 +484,7 @@ fn test_graph_connectivity_after_concurrent_inserts() { for _ in 0..vectors_per_thread { let mut wtxn = env.write_txn().unwrap(); - let mut index = open_vector_core(&env, &mut wtxn).unwrap(); + let index = open_vector_core(&env, &mut wtxn).unwrap(); let arena = Bump::new(); let vector = random_vector(64); diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs index 8ba2b4259..709c0b9f5 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs @@ -137,7 +137,7 @@ fn loom_neighbor_count_race() { // This test demonstrates the lost update problem // In real code, this should use fetch_add assert!( - final_count >= 1 && final_count <= 2, + (1..=2).contains(&final_count), "Expected 1 or 2, got {}", final_count ); @@ -176,7 +176,7 @@ fn loom_max_level_update_race() { // Should end up with max level of 3 let final_max = max_level.load(Ordering::SeqCst); assert!( - final_max >= 2 && final_max <= 3, + (2..=3).contains(&final_max), "Expected 2 or 3, got {}", final_max ); diff --git 
a/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs index 932874b79..03ece0217 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs @@ -654,7 +654,7 @@ fn test_concurrent_graph_topology_consistency() { // Verify all edges point to valid nodes for result in storage.edges_db.iter(&rtxn).unwrap() { let (edge_id, edge_bytes) = result.unwrap(); - let edge = crate::utils::items::Edge::from_bincode_bytes(edge_id, &edge_bytes, &arena).unwrap(); + let edge = crate::utils::items::Edge::from_bincode_bytes(edge_id, edge_bytes, &arena).unwrap(); // Verify source exists assert!( diff --git a/helix-db/src/helix_engine/tests/hnsw_tests.rs b/helix-db/src/helix_engine/tests/hnsw_tests.rs index 61c7c28a0..8563cff88 100644 --- a/helix-db/src/helix_engine/tests/hnsw_tests.rs +++ b/helix-db/src/helix_engine/tests/hnsw_tests.rs @@ -25,7 +25,7 @@ fn setup_env() -> (Env, TempDir) { fn test_hnsw_insert_and_count() { let (env, _temp_dir) = setup_env(); let mut txn = env.write_txn().unwrap(); - let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let vector: Vec = (0..4).map(|_| rand::rng().random_range(0.0..1.0)).collect(); for _ in 0..10 { @@ -43,7 +43,7 @@ fn test_hnsw_insert_and_count() { fn test_hnsw_search_returns_results() { let (env, _temp_dir) = setup_env(); let mut txn = env.write_txn().unwrap(); - let mut index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); + let index = VectorCore::new(&env, &mut txn, HNSWConfig::new(None, None, None)).unwrap(); let mut rng = rand::rng(); for _ in 0..128 { diff --git a/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs index 8e03ee32e..f895da874 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs @@ -180,16 +180,16 @@ fn test_count_filter_ref() { .filter_ref(|val, txn| { if let Ok(val) = val { let val_id = val.id(); - Ok(G::new(&storage, &txn, &arena) + Ok(G::new(&storage, txn, &arena) .n_from_id(&val_id) .out_node("Country_to_City") .count_to_val() .map_value_or(false, |v| { println!( "v: {v:?}, res: {:?}", - *v > 10.clone() + *v > 10 ); - *v > 10.clone() + *v > 10 })?) } else { Ok(false) diff --git a/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs index 5b6de6daf..72ff80068 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs @@ -368,7 +368,7 @@ fn test_double_add_and_double_fetch() { let arena = Bump::new(); let mut txn = db.graph_env.write_txn().unwrap(); - let original_node1 = G::new_mut(&db, &arena, &mut txn) + let original_node1 = G::new_mut(db, &arena, &mut txn) .add_n( "person", props_option(&arena, props! { "entity_name" => "person1" }), @@ -376,7 +376,7 @@ fn test_double_add_and_double_fetch() { ) .collect_to_obj().unwrap(); - let original_node2 = G::new_mut(&db, &arena, &mut txn) + let original_node2 = G::new_mut(db, &arena, &mut txn) .add_n( "person", props_option(&arena, props! 
{ "entity_name" => "person2" }), @@ -387,26 +387,26 @@ fn test_double_add_and_double_fetch() { txn.commit().unwrap(); let mut txn = db.graph_env.write_txn().unwrap(); - let node1 = G::new(&db, &txn, &arena) + let node1 = G::new(db, &txn, &arena) .n_from_type("person") .filter_ref(|val, _| { if let Ok(val) = val { Ok(val .get_property("entity_name") - .map_or(false, |v| *v == "person1")) + .is_some_and(|v| *v == "person1")) } else { Ok(false) } }) .collect::,_>>().unwrap(); - let node2 = G::new(&db, &txn, &arena) + let node2 = G::new(db, &txn, &arena) .n_from_type("person") .filter_ref(|val, _| { if let Ok(val) = val { Ok(val .get_property("entity_name") - .map_or(false, |v| *v == "person2")) + .is_some_and(|v| *v == "person2")) } else { Ok(false) } @@ -418,7 +418,7 @@ fn test_double_add_and_double_fetch() { assert_eq!(node2.len(), 1); assert_eq!(node2[0].id(), original_node2.id()); - let _e = G::new_mut(&db, &arena, &mut txn) + let _e = G::new_mut(db, &arena, &mut txn) .add_edge( "knows", None, @@ -431,7 +431,7 @@ fn test_double_add_and_double_fetch() { txn.commit().unwrap(); let txn = db.graph_env.read_txn().unwrap(); - let e = G::new(&db, &txn, &arena) + let e = G::new(db, &txn, &arena) .e_from_type("knows") .collect::,_>>().unwrap(); assert_eq!(e.len(), 1); diff --git a/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs b/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs index 8933a4fd5..2a1728e14 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/bm25/search_bm25.rs @@ -82,7 +82,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE if label_in_lmdb == label_as_bytes { match Node::<'arena>::from_bincode_bytes(id, value, self.arena) { Ok(node) => { - return Some(Ok(TraversalValue::NodeWithScore { node, score: score })); + return Some(Ok(TraversalValue::NodeWithScore { node, score })); } Err(e) => { println!("{} Error decoding node: {:?}", line!(), e); diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs index cae3bbff7..b013028d4 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs @@ -1,9 +1,9 @@ use crate::helix_engine::{ traversal_core::{ - LMDB_STRING_HEADER_LENGTH, traversal_iter::RoTraversalIterator, + traversal_iter::RoTraversalIterator, traversal_value::TraversalValue, }, - types::{GraphError, VectorError}, + types::GraphError, }; pub trait VFromTypeAdapter<'db, 'arena, 'txn>: @@ -31,21 +31,21 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE fn v_from_type( self, label: &'arena str, - get_vector_data: bool, + _get_vector_data: bool, ) -> RoTraversalIterator< 'db, 'arena, 'txn, impl Iterator, GraphError>>, > { - let label_bytes = label.as_bytes(); + let _label_bytes = label.as_bytes(); let iter = self .storage .vectors .vector_properties_db .iter(self.txn) .unwrap() - .filter_map(move |item| todo!()); + .filter_map(move |_item| todo!()); RoTraversalIterator { storage: self.storage, diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs index 670cc6c24..2ddc20f14 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/brute_force_search.rs @@ -31,7 +31,7 @@ impl<'db, 'arena, 'txn, I: Iterator, 
GraphE { fn brute_force_search_v( self, - query: &'arena [f32], + _query: &'arena [f32], k: K, ) -> RoTraversalIterator< 'db, @@ -43,12 +43,12 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE K: TryInto, K::Error: std::fmt::Debug, { - let arena = bumpalo::Bump::new(); + let _arena = bumpalo::Bump::new(); let iter = self .inner .filter_map(|v| match v { Ok(TraversalValue::Vector(mut v)) => { - let mut bump_vec = bumpalo::collections::Vec::new_in(&self.arena); + let mut bump_vec = bumpalo::collections::Vec::new_in(self.arena); bump_vec.extend_from_slice(v.data_borrowed()); let d = @@ -60,13 +60,13 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE }) .sorted_by(|v1, v2| v1.partial_cmp(v2).unwrap()) .take(k.try_into().unwrap()) - .filter_map(move |mut item| { + .filter_map(move |item| { match self .storage .vectors .get_vector_properties(self.txn, item.id, self.arena) { - Ok(Some(vector_without_data)) => { + Ok(Some(_vector_without_data)) => { // todo! // item.expand_from_vector_without_data(vector_without_data); Some(item) diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs index ce82d6b8f..c81be45ad 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs @@ -5,7 +5,7 @@ use crate::helix_engine::{ types::{GraphError, VectorError}, vector_core::HVector, }; -use std::{iter::once, vec}; +use std::iter::once; pub trait SearchVAdapter<'db, 'arena, 'txn>: Iterator, GraphError>> @@ -37,7 +37,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE query: &'arena [f32], k: K, label: &'arena str, - filter: Option<&'arena [F]>, + _filter: Option<&'arena [F]>, ) -> RoTraversalIterator< 'db, 'arena, diff --git a/helix-db/src/helix_engine/traversal_core/traversal_iter.rs b/helix-db/src/helix_engine/traversal_core/traversal_iter.rs index 302062394..5cab22e6f 100644 --- a/helix-db/src/helix_engine/traversal_core/traversal_iter.rs +++ b/helix-db/src/helix_engine/traversal_core/traversal_iter.rs @@ -64,15 +64,15 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE default: bool, f: impl Fn(&Value) -> bool, ) -> Result { - let val = match &self.inner.next() { + + match &self.inner.next() { Some(Ok(TraversalValue::Value(val))) => Ok(f(val)), Some(Ok(_)) => Err(GraphError::ConversionError( "Expected value, got something else".to_string(), )), Some(Err(err)) => Err(GraphError::from(err.to_string())), None => Ok(default), - }; - val + } } } @@ -138,14 +138,14 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE default: bool, f: impl Fn(&Value) -> bool, ) -> Result { - let val = match &self.inner.next() { + + match &self.inner.next() { Some(Ok(TraversalValue::Value(val))) => Ok(f(val)), Some(Ok(_)) => Err(GraphError::ConversionError( "Expected value, got something else".to_string(), )), Some(Err(err)) => Err(GraphError::from(err.to_string())), None => Ok(default), - }; - val + } } } diff --git a/helix-db/src/helix_engine/vector_core/item_iter.rs b/helix-db/src/helix_engine/vector_core/item_iter.rs index e6fccebe0..b0c82ae7d 100644 --- a/helix-db/src/helix_engine/vector_core/item_iter.rs +++ b/helix-db/src/helix_engine/vector_core/item_iter.rs @@ -40,7 +40,7 @@ impl<'t, D: Distance> Iterator for ItemIter<'t, D> { match self.inner.next() { Some(Ok((key, node))) => match node { Node::Item(Item { header: _, vector }) => { - let mut vector = vector.to_vec(&self.arena); + let mut vector = vector.to_vec(self.arena); if vector.len() != self.dimensions { // quantized 
codecs pad to 8-bytes so we truncate to recover len vector.truncate(self.dimensions); @@ -49,7 +49,7 @@ impl<'t, D: Distance> Iterator for ItemIter<'t, D> { } Node::Links(_) => unreachable!("Node must not be a link"), }, - Some(Err(e)) => Some(Err(e.into())), + Some(Err(e)) => Some(Err(e)), None => None, } } diff --git a/helix-db/src/helix_engine/vector_core/metadata.rs b/helix-db/src/helix_engine/vector_core/metadata.rs index a0c21645a..0d251ba5c 100644 --- a/helix-db/src/helix_engine/vector_core/metadata.rs +++ b/helix-db/src/helix_engine/vector_core/metadata.rs @@ -28,7 +28,7 @@ impl<'a> heed3::BytesEncode<'a> for MetadataCodec { distance, max_level, } = item; - debug_assert!(!distance.as_bytes().iter().any(|&b| b == 0)); + debug_assert!(!distance.as_bytes().contains(&0)); let mut output = Vec::with_capacity( size_of::() diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 6f73acac1..55738c60d 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -1,9 +1,5 @@ use std::{ - borrow::Cow, - cell::RefCell, cmp::Ordering, - hash::Hash, - io::Read, sync::{ RwLock, atomic::{self, AtomicU16, AtomicU32, AtomicUsize}, @@ -24,12 +20,10 @@ use crate::{ helix_engine::{ types::VectorError, vector_core::{ - distance::{Cosine, Distance, DistanceValue}, - key::{Key, KeyCodec}, + distance::Cosine, + key::KeyCodec, node::{Item, NodeCodec}, - node_id::NodeMode, reader::{Reader, Searched, get_item}, - unaligned_vector::UnalignedVector, writer::Writer, }, }, @@ -76,6 +70,7 @@ pub struct HVector<'arena> { // TODO: String Interning pub label: &'arena str, pub deleted: bool, + pub level: Option, pub version: u8, pub properties: Option>, pub data: Option>, @@ -96,6 +91,7 @@ impl<'arena> HVector<'arena> { distance: None, properties: None, deleted: false, + level: None, } } @@ -157,7 +153,7 @@ impl<'arena> HVector<'arena> { bincode::serialize(self) } - pub fn distance_to(&self, rhs: &HVector<'arena>) -> VectorCoreResult { + pub fn distance_to(&self, _rhs: &HVector<'arena>) -> VectorCoreResult { todo!() } @@ -192,10 +188,10 @@ impl<'arena> HVector<'arena> { } pub fn from_raw_vector_data<'txn>( - arena: &'arena bumpalo::Bump, - raw_vector_data: &'txn [u8], - label: &'arena str, - id: u128, + _arena: &'arena bumpalo::Bump, + _raw_vector_data: &'txn [u8], + _label: &'arena str, + _id: u128, ) -> Result { todo!() } @@ -357,7 +353,7 @@ impl VectorCore { txn: &mut RwTxn, label: &'arena str, data: &'arena [f32], - properties: Option>, + _properties: Option>, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { // index hasn't been built yet @@ -431,7 +427,7 @@ impl VectorCore { let global_id = local_to_global_id.get(item_id).unwrap(); let (_, label) = global_to_local_id.get(global_id).unwrap(); let (index, _) = label_to_index.get(label).unwrap(); - let label = arena.alloc_str(&label); + let label = arena.alloc_str(label); if with_data { for (item_id, distance) in nns.into_iter() { @@ -442,6 +438,7 @@ impl VectorCore { distance: Some(distance), label, deleted: false, + level: None, version: 0, properties: None, data: get_item(self.hsnw_index, *index, txn, item_id).unwrap(), @@ -458,6 +455,7 @@ impl VectorCore { deleted: false, version: 0, properties: None, + level: None, data: None, }); } @@ -477,16 +475,17 @@ impl VectorCore { let (item_id, label) = global_to_local_id.get(&id).unwrap(); let (index, _) = label_to_index.get(label).unwrap(); - let label = arena.alloc_str(&label); + let label = 
arena.alloc_str(label); let item = get_item(self.hsnw_index, *index, txn, *item_id)?.map(|i| i.clone_in(arena)); Ok(HVector { - id: id, + id, distance: None, label, deleted: false, version: 0, + level: None, properties: None, data: item.clone(), }) @@ -494,9 +493,9 @@ impl VectorCore { pub fn get_vector_properties<'arena>( &self, - txn: &RoTxn, - id: u128, - arena: &bumpalo::Bump, + _txn: &RoTxn, + _id: u128, + _arena: &bumpalo::Bump, ) -> VectorCoreResult>> { todo!() } @@ -504,4 +503,13 @@ impl VectorCore { pub fn num_inserted_vectors(&self) -> usize { self.stats.num_vectors.load(atomic::Ordering::SeqCst) } + + pub fn get_all_vectors<'arena>( + &self, + _txn: &RoTxn, + _level: Option, + _arena: &bumpalo::Bump, + ) -> VectorCoreResult>> { + todo!() + } } diff --git a/helix-db/src/helix_engine/vector_core/node.rs b/helix-db/src/helix_engine/vector_core/node.rs index 20c3bd66d..21d0bc698 100644 --- a/helix-db/src/helix_engine/vector_core/node.rs +++ b/helix-db/src/helix_engine/vector_core/node.rs @@ -131,6 +131,10 @@ impl<'a> ItemIds<'a> { self.bytes.len() / size_of::() } + pub fn is_empty(&self) -> bool { + self.bytes.is_empty() + } + pub fn iter(&self) -> impl Iterator + 'a { self.bytes .chunks_exact(size_of::()) diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs index a6fdc287d..70867848b 100644 --- a/helix-db/src/helix_engine/vector_core/reader.rs +++ b/helix-db/src/helix_engine/vector_core/reader.rs @@ -86,7 +86,7 @@ impl<'a, D: Distance> QueryBuilder<'a, D> { let res = self .reader .nns_by_item(rtxn, item, self, arena)? - .map(|res| Searched::new(res)); + .map(Searched::new); Ok(res) } @@ -222,11 +222,10 @@ impl<'a> Visitor<'a> { // candidates bitmap, but the final result must *not* include them. if res.len() < self.ef || dist < f_max { search_queue.push((Reverse(OrderedFloat(dist)), point)); - if let Some(c) = self.candidates { - if !c.contains(point) { + if let Some(c) = self.candidates + && !c.contains(point) { continue; } - } if res.len() == self.ef { let _ = res.push_pop_max((OrderedFloat(dist), point)); } else { @@ -481,7 +480,7 @@ impl Reader { ) -> VectorCoreResult>> { Ok( get_item(self.database, self.index, rtxn, item_id)?.map(|item| { - let mut vec = item.vector.to_vec(&arena); + let mut vec = item.vector.to_vec(arena); vec.truncate(self.dimensions()); vec }), @@ -535,16 +534,16 @@ impl Reader { .candidates .is_some_and(|c| self.item_ids().is_disjoint(c)) { - return Ok(bumpalo::collections::Vec::new_in(&arena)); + return Ok(bumpalo::collections::Vec::new_in(arena)); } // If the number of candidates is less than a given threshold, perform linear search if let Some(candidates) = opt.candidates.filter(|c| c.len() < LINEAR_SEARCH_THRESHOLD) { - return self.brute_force_search(query, rtxn, candidates, opt.count, &arena); + return self.brute_force_search(query, rtxn, candidates, opt.count, arena); } // exhaustive search - self.hnsw_search(query, rtxn, opt, &arena) + self.hnsw_search(query, rtxn, opt, arena) } /// Directly retrieves items in the candidate list and ranks them by distance to the query. @@ -557,7 +556,7 @@ impl Reader { arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { let mut item_distances = - bumpalo::collections::Vec::with_capacity_in(candidates.len() as usize, &arena); + bumpalo::collections::Vec::with_capacity_in(candidates.len() as usize, arena); for item_id in candidates { let Some(vector) = self.item_vector(rtxn, item_id, arena)? 
else { diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs index a381b2ede..1efc3edca 100644 --- a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs @@ -17,7 +17,7 @@ unsafe fn hsum256_ps_avx(x: __m256) -> f32 { pub(crate) unsafe fn euclid_similarity_avx( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { +) -> f32 { unsafe { // It is safe to load unaligned floats from a pointer. // @@ -62,14 +62,14 @@ pub(crate) unsafe fn euclid_similarity_avx( result += (a - b).powi(2); } result -} +}} #[target_feature(enable = "avx")] #[target_feature(enable = "fma")] pub(crate) unsafe fn dot_similarity_avx( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { +) -> f32 { unsafe { // It is safe to load unaligned floats from a pointer. // @@ -116,7 +116,7 @@ pub(crate) unsafe fn dot_similarity_avx( result += a * b; } result -} +}} #[cfg(test)] mod tests { diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs index 53705f6c6..ec018afa4 100644 --- a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs @@ -17,7 +17,7 @@ unsafe fn hsum128_ps_sse(x: __m128) -> f32 { pub(crate) unsafe fn euclid_similarity_sse( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { +) -> f32 { unsafe { // It is safe to load unaligned floats from a pointer. // @@ -58,13 +58,13 @@ pub(crate) unsafe fn euclid_similarity_sse( result += (a - b).powi(2); } result -} +}} #[target_feature(enable = "sse")] pub(crate) unsafe fn dot_similarity_sse( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { +) -> f32 { unsafe { // It is safe to load unaligned floats from a pointer. 
// @@ -111,7 +111,7 @@ pub(crate) unsafe fn dot_similarity_sse( result += a * b; } result -} +}} #[cfg(test)] mod tests { diff --git a/helix-db/src/helix_engine/vector_core/stats.rs b/helix-db/src/helix_engine/vector_core/stats.rs index ef652ed98..e88e72a2c 100644 --- a/helix-db/src/helix_engine/vector_core/stats.rs +++ b/helix-db/src/helix_engine/vector_core/stats.rs @@ -11,7 +11,7 @@ use crate::helix_engine::vector_core::node::{Links, Node}; // TODO: ignore the phantom #[derive(Debug)] -pub(crate) struct BuildStats { +pub struct BuildStats { /// a counter to see how many times `HnswBuilder.add_link` is invoked pub n_links_added: AtomicUsize, /// a counter tracking how many times we hit lmdb @@ -26,6 +26,12 @@ pub(crate) struct BuildStats { _phantom: PhantomData, } +impl Default for BuildStats { + fn default() -> Self { + Self::new() + } +} + impl BuildStats { pub fn new() -> BuildStats { BuildStats { diff --git a/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs b/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs index 4ffbbe910..98ed2ee7d 100644 --- a/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs +++ b/helix-db/src/helix_engine/vector_core/unaligned_vector/f32.rs @@ -46,7 +46,7 @@ impl VectorCodec for f32 { arena: &'arena bumpalo::Bump, ) -> bumpalo::collections::Vec<'arena, f32> { let iter = vec.iter(); - let mut ret = bumpalo::collections::Vec::with_capacity_in(iter.len(), &arena); + let mut ret = bumpalo::collections::Vec::with_capacity_in(iter.len(), arena); ret.extend(iter); ret } diff --git a/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs b/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs index 55688983a..48d77fc99 100644 --- a/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs +++ b/helix-db/src/helix_engine/vector_core/unaligned_vector/mod.rs @@ -95,7 +95,7 @@ impl UnalignedVector { &self, arena: &'arena bumpalo::Bump, ) -> bumpalo::collections::Vec<'arena, f32> { - Codec::to_vec(self, &arena) + Codec::to_vec(self, arena) } /// Returns the len of the vector in terms of elements. diff --git a/helix-db/src/helix_engine/vector_core/writer.rs b/helix-db/src/helix_engine/vector_core/writer.rs index a09bc4986..6b4bc5f75 100644 --- a/helix-db/src/helix_engine/vector_core/writer.rs +++ b/helix-db/src/helix_engine/vector_core/writer.rs @@ -26,12 +26,12 @@ pub struct VectorBuilder<'a, D: Distance, R: Rng + SeedableRng> { inner: BuildOption, } -pub(crate) struct BuildOption { - pub(crate) ef_construction: usize, - pub(crate) alpha: f32, - pub(crate) available_memory: Option, - pub(crate) m: usize, - pub(crate) m_max_0: usize, +pub struct BuildOption { + pub ef_construction: usize, + pub alpha: f32, + pub available_memory: Option, + pub m: usize, + pub m_max_0: usize, } impl BuildOption { @@ -299,12 +299,11 @@ impl Writer { // Fetches the item's ids, not the links. fn item_indices(&self, wtxn: &mut RwTxn) -> VectorCoreResult { let mut indices = RoaringBitmap::new(); - for (_, result) in self + for result in self .database .remap_types::() .prefix_iter(wtxn, &Prefix::item(self.index))? .remap_key_type::() - .enumerate() { let (i, _) = result?; indices.insert(i.node.unwrap_item()); @@ -363,6 +362,7 @@ impl<'a, D: Distance> FrozenReader<'a, D> { } /// Clears all the links. Starts from the last node and stops at the first item. 
+#[allow(dead_code)] fn clear_links( wtxn: &mut RwTxn, database: CoreDatabase, diff --git a/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs b/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs index 9a8c618b0..3215a62d8 100644 --- a/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs +++ b/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs @@ -108,7 +108,7 @@ pub fn nodes_edges_inner(input: HandlerInput) -> Result (HelixGraphEngine, TempDir) { let temp_dir = TempDir::new().unwrap(); @@ -239,7 +236,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_edges_inner(input); @@ -263,10 +259,12 @@ mod tests { let mut txn = engine.storage.graph_env.write_txn().unwrap(); let arena = bumpalo::Bump::new(); - let props1 = vec![("name", Value::String("Alice".to_string()))]; + let props1 = [("name", Value::String("Alice".to_string()))]; let props_map1 = ImmutablePropertiesMap::new( props1.len(), - props1.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props1 + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -274,10 +272,12 @@ mod tests { .add_n(arena.alloc_str("person"), Some(props_map1), None) .collect_to_obj()?; - let props2 = vec![("name", Value::String("Bob".to_string()))]; + let props2 = [("name", Value::String("Bob".to_string()))]; let props_map2 = ImmutablePropertiesMap::new( props2.len(), - props2.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props2 + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -286,7 +286,13 @@ mod tests { .collect_to_obj()?; let _edge = G::new_mut(&engine.storage, &arena, &mut txn) - .add_edge(arena.alloc_str("knows"), None, node1.id(), node2.id(), false) + .add_edge( + arena.alloc_str("knows"), + None, + node1.id(), + node2.id(), + false, + ) .collect_to_obj()?; txn.commit().unwrap(); @@ -303,7 +309,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_edges_inner(input); @@ -327,10 +332,12 @@ mod tests { let mut nodes = Vec::new(); for i in 0..10 { - let props = vec![("index", Value::I64(i))]; + let props = [("index", Value::I64(i))]; let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -343,7 +350,13 @@ mod tests { // Add some edges to satisfy the nodes_edges_to_json method for i in 0..5 { let _edge = G::new_mut(&engine.storage, &arena, &mut txn) - .add_edge(arena.alloc_str("connects"), None, nodes[i].id(), nodes[i+1].id(), false) + .add_edge( + arena.alloc_str("connects"), + None, + nodes[i].id(), + nodes[i + 1].id(), + false, + ) .collect_to_obj()?; } @@ -362,7 +375,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_edges_inner(input); @@ -382,10 +394,12 @@ mod tests { let mut txn = engine.storage.graph_env.write_txn().unwrap(); let arena = bumpalo::Bump::new(); - let props = vec![("name", Value::String("Test".to_string()))]; + let props = [("name", Value::String("Test".to_string()))]; let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -408,7 +422,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_edges_inner(input); @@ 
-431,7 +444,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_edges_inner(input); diff --git a/helix-db/src/helix_gateway/builtin/node_by_id.rs b/helix-db/src/helix_gateway/builtin/node_by_id.rs index 8731a6c31..3cd06f42c 100644 --- a/helix-db/src/helix_gateway/builtin/node_by_id.rs +++ b/helix-db/src/helix_gateway/builtin/node_by_id.rs @@ -168,7 +168,7 @@ mod tests { let mut txn = engine.storage.graph_env.write_txn().unwrap(); let arena = bumpalo::Bump::new(); - let props = vec![("name", Value::String("Alice".to_string()))]; + let props = [("name", Value::String("Alice".to_string()))]; let props_map = ImmutablePropertiesMap::new( props.len(), props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), @@ -294,10 +294,8 @@ mod tests { let mut txn = engine.storage.graph_env.write_txn().unwrap(); let arena = bumpalo::Bump::new(); - let props = vec![ - ("name", Value::String("Alice".to_string())), - ("age", Value::I64(30)), - ]; + let props = [("name", Value::String("Alice".to_string())), + ("age", Value::I64(30))]; let props_map = ImmutablePropertiesMap::new( props.len(), props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), diff --git a/helix-db/src/helix_gateway/builtin/node_connections.rs b/helix-db/src/helix_gateway/builtin/node_connections.rs index 61df0ed79..5937b3791 100644 --- a/helix-db/src/helix_gateway/builtin/node_connections.rs +++ b/helix-db/src/helix_gateway/builtin/node_connections.rs @@ -229,7 +229,6 @@ mod tests { protocol::{request::Request, request::RequestType, Format}, helix_gateway::router::router::HandlerInput, utils::id::ID, - helixc::generator::traversal_steps::EdgeType, }; fn setup_test_engine() -> (HelixGraphEngine, TempDir) { diff --git a/helix-db/src/helix_gateway/builtin/nodes_by_label.rs b/helix-db/src/helix_gateway/builtin/nodes_by_label.rs index 61e0a39d0..820eba754 100644 --- a/helix-db/src/helix_gateway/builtin/nodes_by_label.rs +++ b/helix-db/src/helix_gateway/builtin/nodes_by_label.rs @@ -176,7 +176,7 @@ mod tests { let mut txn = engine.storage.graph_env.write_txn().unwrap(); let arena = bumpalo::Bump::new(); - let props1 = vec![("name", Value::String("Alice".to_string()))]; + let props1 = [("name", Value::String("Alice".to_string()))]; let props_map1 = ImmutablePropertiesMap::new( props1.len(), props1.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), @@ -187,7 +187,7 @@ mod tests { .add_n(arena.alloc_str("person"), Some(props_map1), None) .collect_to_obj()?; - let props2 = vec![("name", Value::String("Bob".to_string()))]; + let props2 = [("name", Value::String("Bob".to_string()))]; let props_map2 = ImmutablePropertiesMap::new( props2.len(), props2.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), @@ -235,7 +235,7 @@ mod tests { let arena = bumpalo::Bump::new(); for i in 0..10 { - let props = vec![("index", Value::I64(i))]; + let props = [("index", Value::I64(i))]; let props_map = ImmutablePropertiesMap::new( props.len(), props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), diff --git a/helix-db/src/helix_gateway/tests/gateway_tests.rs b/helix-db/src/helix_gateway/tests/gateway_tests.rs index fd0aa359c..ddfbec850 100644 --- a/helix-db/src/helix_gateway/tests/gateway_tests.rs +++ b/helix-db/src/helix_gateway/tests/gateway_tests.rs @@ -336,7 +336,7 @@ fn test_gateway_opts_default_workers_per_core() { #[cfg(feature = "api-key")] mod api_key_tests { - use super::*; + use crate::helix_gateway::key_verification::verify_key; use 
crate::protocol::{HelixError, request::Request}; use axum::body::Bytes; diff --git a/helix-db/src/helix_gateway/tests/mcp_tests.rs b/helix-db/src/helix_gateway/tests/mcp_tests.rs index 0907475f8..1ca9e7bf5 100644 --- a/helix-db/src/helix_gateway/tests/mcp_tests.rs +++ b/helix-db/src/helix_gateway/tests/mcp_tests.rs @@ -340,7 +340,7 @@ mod mcp_tests { let response = out_step(&mut input).unwrap(); let body = String::from_utf8(response.body.clone()).unwrap(); - assert!(body.contains(&uuid_str(person2.id(), &arena))); + assert!(body.contains(uuid_str(person2.id(), &arena))); } #[test] @@ -906,7 +906,7 @@ mod mcp_tests { .unwrap(); let results = stream.collect().unwrap(); - assert!(results.len() > 0); + assert!(!results.is_empty()); } #[test] @@ -1141,7 +1141,7 @@ mod mcp_tests { .unwrap(); let results = stream.collect().unwrap(); - assert!(results.len() > 0); + assert!(!results.is_empty()); } #[test] diff --git a/helix-db/src/protocol/custom_serde/compatibility_tests.rs b/helix-db/src/protocol/custom_serde/compatibility_tests.rs index 303ed2d4a..7ccc4ba59 100644 --- a/helix-db/src/protocol/custom_serde/compatibility_tests.rs +++ b/helix-db/src/protocol/custom_serde/compatibility_tests.rs @@ -258,7 +258,7 @@ mod compatibility_tests { assert_eq!(restored.id, id); assert_eq!(restored.label, "LegacyVector"); assert_eq!(restored.version, 1); - assert_eq!(restored.deleted, false); + assert!(!restored.deleted); } #[test] @@ -274,7 +274,7 @@ mod compatibility_tests { let new_vector = HVector::from_bincode_bytes(&arena, Some(&old_bytes), &data_bytes, id, true).unwrap(); - assert_eq!(new_vector.deleted, true); + assert!(new_vector.deleted); } #[test] diff --git a/helix-db/src/protocol/custom_serde/edge_case_tests.rs b/helix-db/src/protocol/custom_serde/edge_case_tests.rs index 4a416b47e..3bd17ad1c 100644 --- a/helix-db/src/protocol/custom_serde/edge_case_tests.rs +++ b/helix-db/src/protocol/custom_serde/edge_case_tests.rs @@ -616,7 +616,7 @@ mod edge_case_tests { let arena = Bump::new(); let id = 404404u128; - let large_array = Value::Array((0..1000).map(|i| Value::I32(i)).collect()); + let large_array = Value::Array((0..1000).map(Value::I32).collect()); let props = vec![("big_array", large_array)]; let node = create_arena_node(&arena, id, "test", 0, props); diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs index 603e8095b..40dfeab3e 100644 --- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs +++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs @@ -265,7 +265,7 @@ mod error_handling_tests { let arena = Bump::new(); // 7 bytes is not a multiple of 8 (size of f64) let misaligned: &[u8] = &[0, 1, 2, 3, 4, 5, 6]; - HVector::raw_vector_data_to_vec(&misaligned, &arena); + HVector::raw_vector_data_to_vec(misaligned, &arena); } #[test] diff --git a/helix-db/src/protocol/custom_serde/test_utils.rs b/helix-db/src/protocol/custom_serde/test_utils.rs index 0df86e710..a9218aaec 100644 --- a/helix-db/src/protocol/custom_serde/test_utils.rs +++ b/helix-db/src/protocol/custom_serde/test_utils.rs @@ -244,6 +244,7 @@ pub fn create_arena_vector<'arena>( distance: None, data: Some(Item::::new(bump_vec)), properties: None, + level: None, } } else { let len = props.len(); @@ -261,6 +262,7 @@ pub fn create_arena_vector<'arena>( distance: None, data: Some(Item::::new(bump_vec)), properties: Some(props_map), + level: None, } } } diff --git a/helix-db/src/protocol/custom_serde/tests.rs 
b/helix-db/src/protocol/custom_serde/tests.rs index 50eb59d99..e04154428 100644 --- a/helix-db/src/protocol/custom_serde/tests.rs +++ b/helix-db/src/protocol/custom_serde/tests.rs @@ -386,7 +386,7 @@ mod node_serialization_tests { // Check that both have the same keys and values (regardless of order) for (key, old_value) in old_props.iter() { - let new_value = new_props.get(key).expect(&format!("Missing key: {}", key)); + let new_value = new_props.get(key).unwrap_or_else(|| panic!("Missing key: {}", key)); // For nested objects, we need to compare recursively since HashMap order may differ assert!(values_equal(old_value, new_value), "Value mismatch for key {}: {:?} != {:?}", key, old_value, new_value); } @@ -413,7 +413,7 @@ mod node_serialization_tests { a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y)) } (Value::Object(a), Value::Object(b)) => { - a.len() == b.len() && a.iter().all(|(k, v)| b.get(k).map_or(false, |bv| values_equal(v, bv))) + a.len() == b.len() && a.iter().all(|(k, v)| b.get(k).is_some_and(|bv| values_equal(v, bv))) } (Value::Date(a), Value::Date(b)) => a == b, (Value::Id(a), Value::Id(b)) => a == b, @@ -602,14 +602,12 @@ mod node_serialization_tests { fn test_node_serialization_utf8_labels() { let arena = Bump::new(); - let utf8_labels = vec![ - "Hello", + let utf8_labels = ["Hello", "世界", "🚀🌟", "Привет", "مرحبا", - "Ñoño", - ]; + "Ñoño"]; for (idx, label) in utf8_labels.iter().enumerate() { let id = idx as u128; @@ -843,7 +841,7 @@ mod edge_serialization_tests { a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y)) } (Value::Object(a), Value::Object(b)) => { - a.len() == b.len() && a.iter().all(|(k, v)| b.get(k).map_or(false, |bv| values_equal(v, bv))) + a.len() == b.len() && a.iter().all(|(k, v)| b.get(k).is_some_and(|bv| values_equal(v, bv))) } (Value::Date(a), Value::Date(b)) => a == b, (Value::Id(a), Value::Id(b)) => a == b, @@ -938,7 +936,7 @@ mod edge_serialization_tests { // Check semantic equality (order may differ) for (key, old_value) in old_props.iter() { - let new_value = new_props.get(key).expect(&format!("Missing key: {}", key)); + let new_value = new_props.get(key).unwrap_or_else(|| panic!("Missing key: {}", key)); assert!(values_equal(old_value, new_value), "Value mismatch for key {}", key); } } @@ -1045,7 +1043,7 @@ mod edge_serialization_tests { // Compare nested values for (key, old_value) in old_props.iter() { - let new_value = new_props.get(key).expect(&format!("Missing key: {}", key)); + let new_value = new_props.get(key).unwrap_or_else(|| panic!("Missing key: {}", key)); assert!(values_equal(old_value, new_value), "Value mismatch for key {}: {:?} != {:?}", key, old_value, new_value); } } diff --git a/helix-db/src/protocol/custom_serde/vector_serde.rs b/helix-db/src/protocol/custom_serde/vector_serde.rs index ddbf9f90a..42121a958 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde.rs @@ -104,6 +104,7 @@ impl<'de, 'txn, 'arena> serde::de::DeserializeSeed<'de> for VectorDeSeed<'txn, ' distance: None, data: Some(Item::::new(data)), properties, + level: None, }) } } @@ -170,6 +171,7 @@ impl<'de, 'arena> serde::de::DeserializeSeed<'de> for VectoWithoutDataDeSeed<'ar deleted, properties, distance: None, + level: None, data: None, }) } diff --git a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs index faca11621..36fe41c98 100644 --- 
a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs @@ -211,7 +211,7 @@ mod vector_serialization_tests { assert_eq!(vector.label, label); assert_eq!(vector.len(), 4); assert_eq!(vector.version, 1); - assert_eq!(vector.deleted, false); + assert!(!vector.deleted); assert!(vector.properties.is_none()); } @@ -322,7 +322,7 @@ mod vector_serialization_tests { let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); - assert_eq!(deserialized.deleted, true); + assert!(deserialized.deleted); } #[test] @@ -340,7 +340,7 @@ mod vector_serialization_tests { let deserialized = HVector::from_bincode_bytes(&arena2, Some(&props_bytes), data_bytes, id, true).unwrap(); - assert_eq!(deserialized.deleted, false); + assert!(!deserialized.deleted); } // ======================================================================== diff --git a/helix-db/src/protocol/request.rs b/helix-db/src/protocol/request.rs index 3fe3c314d..c0d84f3ce 100644 --- a/helix-db/src/protocol/request.rs +++ b/helix-db/src/protocol/request.rs @@ -203,7 +203,7 @@ mod tests { #[test] fn test_request_type_clone() { let rt1 = RequestType::MCP; - let rt2 = rt1.clone(); + let rt2 = rt1; assert!(matches!(rt1, RequestType::MCP)); assert!(matches!(rt2, RequestType::MCP)); diff --git a/helix-db/src/protocol/value.rs b/helix-db/src/protocol/value.rs index cdf2cac6e..c1830ac22 100644 --- a/helix-db/src/protocol/value.rs +++ b/helix-db/src/protocol/value.rs @@ -1730,8 +1730,8 @@ mod tests { assert_eq!(Value::F64(1.0), Value::F64(1.0)); assert_eq!(Value::I64(1), Value::U64(1)); assert_eq!(Value::U64(1), Value::I64(1)); - assert_eq!(Value::I32(1), 1 as i32); - assert_eq!(Value::U32(1), 1 as i32); + assert_eq!(Value::I32(1), 1_i32); + assert_eq!(Value::U32(1), 1_i32); } #[test] @@ -1991,7 +1991,7 @@ mod tests { let val = Value::Boolean(true); let b: bool = val.into(); - assert_eq!(b, true); + assert!(b); let val = Value::String("test".to_string()); let s: String = val.into(); @@ -2068,7 +2068,7 @@ mod tests { let val = Value::Boolean(true); let b: &bool = val.into_primitive(); - assert_eq!(*b, true); + assert!(*b); let val = Value::String("test".to_string()); let s = val.as_str(); diff --git a/helix-db/src/utils/id.rs b/helix-db/src/utils/id.rs index 4500911d8..3bdfcf578 100644 --- a/helix-db/src/utils/id.rs +++ b/helix-db/src/utils/id.rs @@ -274,7 +274,7 @@ mod tests { let id = ID::from(value); // Test Deref trait - let deref_value: &u128 = &*id; + let deref_value: &u128 = &id; assert_eq!(*deref_value, value); } @@ -301,7 +301,7 @@ mod tests { #[test] fn test_id_ordering() { - let mut ids = vec![ID::from(300u128), ID::from(100u128), ID::from(200u128)]; + let mut ids = [ID::from(300u128), ID::from(100u128), ID::from(200u128)]; ids.sort(); diff --git a/helix-db/src/utils/label_hash.rs b/helix-db/src/utils/label_hash.rs index 8bd42f71f..bacc79820 100644 --- a/helix-db/src/utils/label_hash.rs +++ b/helix-db/src/utils/label_hash.rs @@ -149,13 +149,11 @@ mod tests { #[test] fn test_hash_label_similar_strings() { // Test labels that differ by only one character - let labels = vec![ - "person", + let labels = ["person", "persons", "person1", "person_", - "Person", - ]; + "Person"]; let hashes: Vec<[u8; 4]> = labels.iter() .map(|l| hash_label(l, None)) diff --git a/metrics/src/lib.rs b/metrics/src/lib.rs index ad6039f47..48dadd54f 100644 --- a/metrics/src/lib.rs +++ b/metrics/src/lib.rs @@ -684,7 +684,7 @@ mod tests { // Should be able to 
serialize batch let json_bytes = sonic_rs::to_vec(&events).unwrap(); - assert!(json_bytes.len() > 0); + assert!(!json_bytes.is_empty()); // Should be valid JSON array let json_str = String::from_utf8(json_bytes).unwrap(); From 38529d83e5bdc1ce6cdebdc4d5374c454e89c50c Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Sat, 22 Nov 2025 00:15:47 -0300 Subject: [PATCH 33/48] Implement v_from_type and improve iterator implementation --- .../custom_serde/property_based_tests.txt | 8 + .../traversal_core/ops/source/v_from_type.rs | 29 ++-- .../src/helix_engine/vector_core/item_iter.rs | 28 ++-- helix-db/src/helix_engine/vector_core/key.rs | 2 +- helix-db/src/helix_engine/vector_core/mod.rs | 145 ++++++++++++++++-- .../src/helix_engine/vector_core/writer.rs | 7 +- .../builtin/all_nodes_and_edges.rs | 2 +- hql-tests/run.sh | 0 8 files changed, 175 insertions(+), 46 deletions(-) create mode 100644 helix-db/proptest-regressions/protocol/custom_serde/property_based_tests.txt mode change 100644 => 100755 hql-tests/run.sh diff --git a/helix-db/proptest-regressions/protocol/custom_serde/property_based_tests.txt b/helix-db/proptest-regressions/protocol/custom_serde/property_based_tests.txt new file mode 100644 index 000000000..671e9f480 --- /dev/null +++ b/helix-db/proptest-regressions/protocol/custom_serde/property_based_tests.txt @@ -0,0 +1,8 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 77500f2f8b3836370e7b9cac8ca607972d1c177022f9d547f49358b339b2f680 # shrinks to id = 0, label = "A" +cc 6dd62308707837f8c6c3890ad62075b4cb7db3bac79adfc33c941c0f10501116 # shrinks to id = 7, label = "_" diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs index b013028d4..ad9d90e8c 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/v_from_type.rs @@ -1,8 +1,5 @@ use crate::helix_engine::{ - traversal_core::{ - traversal_iter::RoTraversalIterator, - traversal_value::TraversalValue, - }, + traversal_core::{traversal_iter::RoTraversalIterator, traversal_value::TraversalValue}, types::GraphError, }; @@ -31,27 +28,31 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE fn v_from_type( self, label: &'arena str, - _get_vector_data: bool, + get_vector_data: bool, ) -> RoTraversalIterator< 'db, 'arena, 'txn, impl Iterator, GraphError>>, > { - let _label_bytes = label.as_bytes(); - let iter = self - .storage - .vectors - .vector_properties_db - .iter(self.txn) - .unwrap() - .filter_map(move |_item| todo!()); + let mut inner = Vec::new(); + match self.storage.vectors.get_all_vectors_by_label( + self.txn, + label, + get_vector_data, + self.arena, + ) { + Ok(vec) => vec + .into_iter() + .for_each(|v| inner.push(Ok(TraversalValue::Vector(v)))), + Err(err) => inner.push(Err(GraphError::from(err))), + } RoTraversalIterator { storage: self.storage, arena: self.arena, txn: self.txn, - inner: iter, + inner: inner.into_iter(), } } } diff --git a/helix-db/src/helix_engine/vector_core/item_iter.rs b/helix-db/src/helix_engine/vector_core/item_iter.rs index b0c82ae7d..2c19a33a0 100644 --- a/helix-db/src/helix_engine/vector_core/item_iter.rs +++ b/helix-db/src/helix_engine/vector_core/item_iter.rs @@ -1,17 +1,17 @@ use 
heed3::RoTxn; use crate::helix_engine::vector_core::{ - CoreDatabase, ItemId, LmdbResult, + CoreDatabase, LmdbResult, distance::Distance, key::{KeyCodec, Prefix, PrefixCodec}, node::{Item, Node, NodeCodec}, + node_id::NodeId, }; // used by the reader pub struct ItemIter<'t, D: Distance> { pub inner: heed3::RoPrefix<'t, KeyCodec, NodeCodec>, dimensions: usize, - arena: &'t bumpalo::Bump, } impl<'t, D: Distance> ItemIter<'t, D> { @@ -20,7 +20,6 @@ impl<'t, D: Distance> ItemIter<'t, D> { index: u16, dimensions: usize, rtxn: &'t RoTxn, - arena: &'t bumpalo::Bump, ) -> heed3::Result { Ok(ItemIter { inner: database @@ -28,24 +27,33 @@ impl<'t, D: Distance> ItemIter<'t, D> { .prefix_iter(rtxn, &Prefix::item(index))? .remap_key_type::(), dimensions, - arena, }) } + + pub fn next_id(&mut self) -> Option> { + match self.inner.next() { + Some(Ok((key, node))) => match node { + Node::Item(_) => Some(Ok(key.node)), + Node::Links(_) => unreachable!("Node must not be a link"), + }, + Some(Err(e)) => Some(Err(e)), + None => None, + } + } } impl<'t, D: Distance> Iterator for ItemIter<'t, D> { - type Item = LmdbResult<(ItemId, bumpalo::collections::Vec<'t, f32>)>; + type Item = LmdbResult<(NodeId, Item<'t, D>)>; fn next(&mut self) -> Option { match self.inner.next() { Some(Ok((key, node))) => match node { - Node::Item(Item { header: _, vector }) => { - let mut vector = vector.to_vec(self.arena); - if vector.len() != self.dimensions { + Node::Item(mut item) => { + if item.vector.len() != self.dimensions { // quantized codecs pad to 8-bytes so we truncate to recover len - vector.truncate(self.dimensions); + item.vector.to_mut().truncate(self.dimensions); } - Some(Ok((key.node.item, vector))) + Some(Ok((key.node, item))) } Node::Links(_) => unreachable!("Node must not be a link"), }, diff --git a/helix-db/src/helix_engine/vector_core/key.rs b/helix-db/src/helix_engine/vector_core/key.rs index 8d32ff318..73b82e0f9 100644 --- a/helix-db/src/helix_engine/vector_core/key.rs +++ b/helix-db/src/helix_engine/vector_core/key.rs @@ -7,7 +7,7 @@ use heed3::BoxedError; use crate::helix_engine::vector_core::node_id::{NodeId, NodeMode}; /// This whole structure must fit in an u64 so we can tell LMDB to optimize its storage. -/// The `index` is specified by the user and is used to differentiate between multiple hannoy indexes. +/// The `index` is specified by the user and is used to differentiate between multiple indexes. /// The `mode` indicates what we're looking at. /// The `item` point to a specific node. /// If the mode is: diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 55738c60d..28aea2b43 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -67,7 +67,8 @@ pub type CoreDatabase = heed3::Database>; pub struct HVector<'arena> { pub id: u128, pub distance: Option, - // TODO: String Interning + // TODO: String Interning. We do a lot of unnecessary string allocations + // for the same set of labels. 
pub label: &'arena str, pub deleted: bool, pub level: Option, @@ -219,12 +220,18 @@ impl Ord for HVector<'_> { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HNSWConfig { - pub m: usize, // max num of bi-directional links per element - pub m_max_0: usize, // max num of links for lower layers - pub ef_construct: usize, // size of the dynamic candidate list for construction - pub m_l: f64, // level generation factor - pub ef: usize, // search param, num of cands to search - pub min_neighbors: usize, // for get_neighbors, always 512 + /// max num of bi-directional links per element + pub m: usize, + /// max num of links for lower layers + pub m_max_0: usize, + /// size of the dynamic candidate list for construction + pub ef_construct: usize, + /// level generation factor + pub m_l: f64, + /// search param, num of cands to search + pub ef: usize, + /// for get_neighbors, always 512 + pub min_neighbors: usize, } impl HNSWConfig { @@ -359,6 +366,11 @@ impl VectorCore { // index hasn't been built yet let writer = self.get_writer_or_create_index(label, data.len(), txn)?; + let idx = self.curr_id.fetch_add(1, atomic::Ordering::SeqCst); + writer.add_item(txn, idx, data).inspect_err(|_| { + self.curr_id.fetch_sub(1, atomic::Ordering::SeqCst); + })?; + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); bump_vec.extend_from_slice(data); let hvector = HVector::from_vec(label, bump_vec); @@ -494,22 +506,127 @@ impl VectorCore { pub fn get_vector_properties<'arena>( &self, _txn: &RoTxn, - _id: u128, - _arena: &bumpalo::Bump, + id: u128, + arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { - todo!() + let global_to_local_id = self.global_to_local_id.read().unwrap(); + let (_, label) = global_to_local_id.get(&id).unwrap(); + + // todo: actually take properties + Ok(Some(HVector { + id, + distance: None, + label: arena.alloc_str(label.as_str()), + deleted: false, + version: 0, + level: None, + properties: None, + data: None, + })) } pub fn num_inserted_vectors(&self) -> usize { self.stats.num_vectors.load(atomic::Ordering::SeqCst) } + pub fn get_all_vectors_by_label<'arena>( + &self, + txn: &RoTxn, + label: &'arena str, + get_vector_data: bool, + arena: &'arena bumpalo::Bump, + ) -> VectorCoreResult>> { + let mut result = bumpalo::collections::Vec::new_in(arena); + let label_to_reader = self.label_to_reader.read().unwrap(); + let local_to_global_id = self.local_to_global_id.read().unwrap(); + + let reader = label_to_reader + .get(label) + .ok_or_else(|| VectorError::VectorCoreError("Label not found".to_string()))?; + + let mut iter = reader.iter(txn)?; + + if get_vector_data { + while let Some((key, item)) = iter.next().transpose()? { + let &id = local_to_global_id.get(&key.item).unwrap(); + result.push(HVector { + id, + label, + distance: None, + deleted: false, + level: Some(key.layer as usize), + version: 0, + properties: None, + data: Some(item.clone_in(arena)), + }); + } + } else { + while let Some(key) = iter.next_id().transpose()? 
{ + let &id = local_to_global_id.get(&key.item).unwrap(); + result.push(HVector { + id, + label, + distance: None, + deleted: false, + level: Some(key.layer as usize), + version: 0, + properties: None, + data: None, + }); + } + } + + Ok(result) + } + pub fn get_all_vectors<'arena>( &self, - _txn: &RoTxn, - _level: Option, - _arena: &bumpalo::Bump, + txn: &RoTxn, + get_vector_data: bool, + arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { - todo!() + let label_to_reader = self.label_to_reader.read().unwrap(); + let local_to_global_id = self.local_to_global_id.read().unwrap(); + let mut result = bumpalo::collections::Vec::new_in(arena); + + for (label, _) in label_to_reader.iter() { + let reader = label_to_reader + .get(label) + .ok_or_else(|| VectorError::VectorCoreError("Label not found".to_string()))?; + + let mut iter = reader.iter(txn)?; + + if get_vector_data { + while let Some((key, item)) = iter.next().transpose()? { + let &id = local_to_global_id.get(&key.item).unwrap(); + result.push(HVector { + id, + label: arena.alloc_str(label), + distance: None, + deleted: false, + level: Some(key.layer as usize), + version: 0, + properties: None, + data: Some(item.clone_in(arena)), + }); + } + } else { + while let Some(key) = iter.next_id().transpose()? { + let &id = local_to_global_id.get(&key.item).unwrap(); + result.push(HVector { + id, + label: arena.alloc_str(label), + distance: None, + deleted: false, + level: Some(key.layer as usize), + version: 0, + properties: None, + data: None, + }); + } + } + } + + Ok(result) } } diff --git a/helix-db/src/helix_engine/vector_core/writer.rs b/helix-db/src/helix_engine/vector_core/writer.rs index 6b4bc5f75..f2ef66cce 100644 --- a/helix-db/src/helix_engine/vector_core/writer.rs +++ b/helix-db/src/helix_engine/vector_core/writer.rs @@ -136,17 +136,12 @@ impl Writer { } /// Returns an iterator over the items vector. - pub fn iter<'t>( - &self, - rtxn: &'t RoTxn, - arena: &'t bumpalo::Bump, - ) -> VectorCoreResult> { + pub fn iter<'t>(&self, rtxn: &'t RoTxn) -> VectorCoreResult> { Ok(ItemIter::new( self.database, self.index, self.dimensions, rtxn, - arena, )?) } diff --git a/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs b/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs index 3215a62d8..2fbc8ed01 100644 --- a/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs +++ b/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs @@ -98,7 +98,7 @@ pub fn nodes_edges_inner(input: HandlerInput) -> Result = vecs .iter() diff --git a/hql-tests/run.sh b/hql-tests/run.sh old mode 100644 new mode 100755 From 5e8753556ec891a399969212007c10f73a9def9b Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Sat, 22 Nov 2025 02:51:20 -0300 Subject: [PATCH 34/48] Temporarily rebuild on every insertion Yeah, this sucks. 
Although the current approach incrementally rebuilds the index, it's way more expensive than it should be --- .../storage_core/graph_visualization.rs | 2 +- helix-db/src/helix_engine/vector_core/mod.rs | 99 ++++++++++--------- .../src/helix_engine/vector_core/reader.rs | 19 ++-- .../src/helix_engine/vector_core/writer.rs | 4 +- 4 files changed, 64 insertions(+), 60 deletions(-) diff --git a/helix-db/src/helix_engine/storage_core/graph_visualization.rs b/helix-db/src/helix_engine/storage_core/graph_visualization.rs index 89eaad5bf..7c9a0924c 100644 --- a/helix-db/src/helix_engine/storage_core/graph_visualization.rs +++ b/helix-db/src/helix_engine/storage_core/graph_visualization.rs @@ -53,7 +53,7 @@ impl GraphVisualization for HelixGraphStorage { let result = json!({ "num_nodes": self.nodes_db.len(txn).unwrap_or(0), "num_edges": self.edges_db.len(txn).unwrap_or(0), - "num_vectors": self.vectors.stats.num_vectors, + "num_vectors": self.vectors.num_inserted_vectors(), }); debug_println!("db stats json: {:?}", result); diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 28aea2b43..a0c41aa92 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -258,20 +258,21 @@ impl HNSWConfig { } } -pub struct VectorCoreStats { +pub struct HnswIndex { + pub id: u16, + pub dimension: usize, pub num_vectors: AtomicUsize, } // TODO: Properties filters // TODO: Support different distances for each database pub struct VectorCore { - pub hsnw_index: CoreDatabase, - pub stats: VectorCoreStats, + pub hsnw: CoreDatabase, pub vector_properties_db: Database, Bytes>, pub config: HNSWConfig, - /// Map labels to a different index and dimension - pub label_to_index: RwLock>, + /// Map labels to a different [HnswIndex] + pub label_to_index: RwLock>, /// Track the last index curr_index: AtomicU16, @@ -291,10 +292,7 @@ impl VectorCore { .create(txn)?; Ok(Self { - hsnw_index: vectors_db, - stats: VectorCoreStats { - num_vectors: AtomicUsize::new(0), - }, + hsnw: vectors_db, vector_properties_db, config, label_to_index: RwLock::new(HashMap::new()), @@ -315,12 +313,12 @@ impl VectorCore { arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { match self.label_to_index.read().unwrap().get(label) { - Some(&(index, dimension)) => { - if dimension != query.len() { + Some(index) => { + if index.dimension != query.len() { return Err(VectorError::InvalidVectorLength); } - let reader = Reader::open(txn, index, self.hsnw_index)?; + let reader = Reader::open(txn, index.id, self.hsnw)?; reader.nns(k).by_vector(txn, query.as_slice(), arena) } None => Ok(Searched::new(bumpalo::vec![in &arena])), @@ -335,16 +333,20 @@ impl VectorCore { dimension: usize, txn: &mut RwTxn, ) -> VectorCoreResult> { - if let Some(&(idx, dimension)) = self.label_to_index.read().unwrap().get(label) { - Ok(Writer::new(self.hsnw_index, idx, dimension)) + if let Some(index) = self.label_to_index.read().unwrap().get(label) { + Ok(Writer::new(self.hsnw, index.id, dimension)) } else { // Index does not exist, we should build it let idx = self.curr_index.fetch_add(1, atomic::Ordering::SeqCst); - self.label_to_index - .write() - .unwrap() - .insert(label.to_string(), (idx, dimension)); + self.label_to_index.write().unwrap().insert( + label.to_string(), + HnswIndex { + id: idx, + dimension, + num_vectors: AtomicUsize::new(0), + }, + ); + let writer = Writer::new(self.hsnw, idx, dimension); let mut rng =
StdRng::from_os_rng(); let mut builder = writer.builder(&mut rng); @@ -363,7 +365,6 @@ impl VectorCore { _properties: Option>, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { - // index hasn't been built yet let writer = self.get_writer_or_create_index(label, data.len(), txn)?; let idx = self.curr_id.fetch_add(1, atomic::Ordering::SeqCst); @@ -384,11 +385,21 @@ impl VectorCore { .write() .unwrap() .insert(idx, hvector.id); - self.stats + self.label_to_index + .read() + .unwrap() + .get(label) + .unwrap() .num_vectors .fetch_add(1, atomic::Ordering::SeqCst); - writer.add_item(txn, idx, data)?; + let mut rng = StdRng::from_os_rng(); + let mut builder = writer.builder(&mut rng); + + // FIXME: We shouldn't rebuild on every insertion + builder + .ef_construction(self.config.ef_construct) + .build(txn)?; Ok(hvector) } @@ -396,17 +407,13 @@ impl VectorCore { pub fn delete(&self, txn: &mut RwTxn, id: u128) -> VectorCoreResult<()> { match self.global_to_local_id.read().unwrap().get(&id) { Some(&(idx, ref label)) => { - let &(index, dimension) = self - .label_to_index - .read() - .unwrap() + let label_to_index = self.label_to_index.read().unwrap(); + let index = label_to_index .get(label) .expect("if index exist label should also exist"); - let writer = Writer::new(self.hsnw_index, index, dimension); + let writer = Writer::new(self.hsnw, index.id, index.dimension); writer.del_item(txn, idx)?; - self.stats - .num_vectors - .fetch_add(1, atomic::Ordering::SeqCst); + index.num_vectors.fetch_add(1, atomic::Ordering::SeqCst); Ok(()) } None => Err(VectorError::VectorNotFound(format!( @@ -438,7 +445,7 @@ impl VectorCore { let (item_id, _) = nns.first().unwrap(); let global_id = local_to_global_id.get(item_id).unwrap(); let (_, label) = global_to_local_id.get(global_id).unwrap(); - let (index, _) = label_to_index.get(label).unwrap(); + let index = label_to_index.get(label).unwrap(); let label = arena.alloc_str(label); if with_data { @@ -453,7 +460,7 @@ impl VectorCore { level: None, version: 0, properties: None, - data: get_item(self.hsnw_index, *index, txn, item_id).unwrap(), + data: get_item(self.hsnw, index.id, txn, item_id).unwrap(), }); } } else { @@ -486,10 +493,10 @@ impl VectorCore { let global_to_local_id = self.global_to_local_id.read().unwrap(); let (item_id, label) = global_to_local_id.get(&id).unwrap(); - let (index, _) = label_to_index.get(label).unwrap(); + let index = label_to_index.get(label).unwrap(); let label = arena.alloc_str(label); - let item = get_item(self.hsnw_index, *index, txn, *item_id)?.map(|i| i.clone_in(arena)); + let item = get_item(self.hsnw, index.id, txn, *item_id)?.map(|i| i.clone_in(arena)); Ok(HVector { id, @@ -526,7 +533,12 @@ impl VectorCore { } pub fn num_inserted_vectors(&self) -> usize { - self.stats.num_vectors.load(atomic::Ordering::SeqCst) + self.label_to_index + .read() + .unwrap() + .iter() + .map(|(_, i)| i.num_vectors.load(atomic::Ordering::SeqCst)) + .sum() } pub fn get_all_vectors_by_label<'arena>( @@ -537,13 +549,11 @@ impl VectorCore { arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { let mut result = bumpalo::collections::Vec::new_in(arena); - let label_to_reader = self.label_to_reader.read().unwrap(); let local_to_global_id = self.local_to_global_id.read().unwrap(); + let label_to_index = self.label_to_index.read().unwrap(); + let index = label_to_index.get(label).unwrap(); - let reader = label_to_reader - .get(label) - .ok_or_else(|| VectorError::VectorCoreError("Label not found".to_string()))?; - + let reader = Reader::open(txn, 
index.id, self.hsnw)?; let mut iter = reader.iter(txn)?; if get_vector_data { @@ -585,15 +595,12 @@ impl VectorCore { get_vector_data: bool, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { - let label_to_reader = self.label_to_reader.read().unwrap(); + let label_to_index = self.label_to_index.read().unwrap(); let local_to_global_id = self.local_to_global_id.read().unwrap(); let mut result = bumpalo::collections::Vec::new_in(arena); - for (label, _) in label_to_reader.iter() { - let reader = label_to_reader - .get(label) - .ok_or_else(|| VectorError::VectorCoreError("Label not found".to_string()))?; - + for (label, index) in label_to_index.iter() { + let reader = Reader::open(txn, index.id, self.hsnw)?; let mut iter = reader.iter(txn)?; if get_vector_data { diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs index 70867848b..85466fccb 100644 --- a/helix-db/src/helix_engine/vector_core/reader.rs +++ b/helix-db/src/helix_engine/vector_core/reader.rs @@ -223,9 +223,10 @@ impl<'a> Visitor<'a> { if res.len() < self.ef || dist < f_max { search_queue.push((Reverse(OrderedFloat(dist)), point)); if let Some(c) = self.candidates - && !c.contains(point) { - continue; - } + && !c.contains(point) + { + continue; + } if res.len() == self.ef { let _ = res.push_pop_max((OrderedFloat(dist), point)); } else { @@ -488,8 +489,8 @@ impl Reader { } /// Returns `true` if the index is empty. - pub fn is_empty(&self, rtxn: &RoTxn, arena: &bumpalo::Bump) -> VectorCoreResult { - self.iter(rtxn, arena).map(|mut iter| iter.next().is_none()) + pub fn is_empty(&self, rtxn: &RoTxn) -> VectorCoreResult { + self.iter(rtxn).map(|mut iter| iter.next().is_none()) } /// Returns `true` if the database contains the given item. @@ -502,12 +503,8 @@ impl Reader { } /// Returns an iterator over the items vector. - pub fn iter<'t>( - &self, - rtxn: &'t RoTxn, - arena: &'t bumpalo::Bump, - ) -> VectorCoreResult> { - ItemIter::new(self.database, self.index, self.dimensions, rtxn, arena).map_err(Into::into) + pub fn iter<'t>(&self, rtxn: &'t RoTxn) -> VectorCoreResult> { + ItemIter::new(self.database, self.index, self.dimensions, rtxn).map_err(Into::into) } /// Return a [`QueryBuilder`] that lets you configure and execute a search request. diff --git a/helix-db/src/helix_engine/vector_core/writer.rs b/helix-db/src/helix_engine/vector_core/writer.rs index f2ef66cce..36a7572dc 100644 --- a/helix-db/src/helix_engine/vector_core/writer.rs +++ b/helix-db/src/helix_engine/vector_core/writer.rs @@ -106,8 +106,8 @@ impl Writer { } /// Returns `true` if the index is empty. - pub fn is_empty(&self, rtxn: &RoTxn, arena: &bumpalo::Bump) -> VectorCoreResult { - self.iter(rtxn, arena).map(|mut iter| iter.next().is_none()) + pub fn is_empty(&self, rtxn: &RoTxn) -> VectorCoreResult { + self.iter(rtxn).map(|mut iter| iter.next().is_none()) } /// Returns `true` if the index needs to be built before being able to read in it. 
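A note on the FIXME above: as of this patch, every call to `insert` pays a full `builder.build(txn)` pass over the label's index, so n insertions trigger n builds. The sketch below is a standalone toy (not HelixDB code; `ToyIndex` and the sort standing in for the HNSW build are illustrative assumptions) showing why deferring the build, as the FIXME suggests, matters:

```rust
// Toy illustration only: a sort stands in for the real HNSW build.
// Rebuilding after every insert makes n insertions do n full builds
// (quadratic total work here); inserting n items and building once
// pays for a single build.
struct ToyIndex {
    items: Vec<f32>,
    built: Vec<f32>, // stands in for the built graph
}

impl ToyIndex {
    fn new() -> Self {
        Self { items: Vec::new(), built: Vec::new() }
    }

    // Shape of this patch: add the item, then rebuild everything.
    fn insert_and_rebuild(&mut self, v: f32) {
        self.items.push(v);
        self.built = self.items.clone();
        self.built.sort_by(|a, b| a.total_cmp(b)); // full rebuild per insert
    }

    // Shape the FIXME asks for: defer the build until inserts are done.
    fn insert(&mut self, v: f32) {
        self.items.push(v);
    }

    fn build(&mut self) {
        self.built = self.items.clone();
        self.built.sort_by(|a, b| a.total_cmp(b)); // one build at the end
    }
}

fn main() {
    let (mut eager, mut lazy) = (ToyIndex::new(), ToyIndex::new());
    for i in 0..1_000 {
        eager.insert_and_rebuild(i as f32); // 1_000 builds
        lazy.insert(i as f32);
    }
    lazy.build(); // single build
    assert_eq!(eager.built, lazy.built);
}
```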
From 576a7dd95ab95405df892beef26c5e2f91d5a59e Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Sat, 22 Nov 2025 11:02:28 -0300 Subject: [PATCH 35/48] Get serialization right --- helix-db/src/helix_engine/vector_core/mod.rs | 54 +++++++++++++++++-- .../custom_serde/vector_serde_tests.rs | 6 +-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index a0c41aa92..e6b60144b 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -31,7 +31,10 @@ use crate::{ custom_serde::vector_serde::{VectoWithoutDataDeSeed, VectorDeSeed}, value::Value, }, - utils::{id::v6_uuid, properties::ImmutablePropertiesMap}, + utils::{ + id::{uuid_str_from_buf, v6_uuid}, + properties::ImmutablePropertiesMap, + }, }; pub mod distance; @@ -63,7 +66,7 @@ pub type LmdbResult = std::result::Result; pub type CoreDatabase = heed3::Database>; -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Clone)] pub struct HVector<'arena> { pub id: u128, pub distance: Option, @@ -77,6 +80,42 @@ pub struct HVector<'arena> { pub data: Option>, } +impl<'arena> serde::Serialize for HVector<'arena> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::SerializeStruct; + + // Check if this is a human-readable format (like JSON) + if serializer.is_human_readable() { + // Include id for JSON serialization + let mut buffer = [0u8; 36]; + let mut state = serializer.serialize_map(Some( + 5 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0), + ))?; + state.serialize_entry("id", uuid_str_from_buf(self.id, &mut buffer))?; + state.serialize_entry("label", &self.label)?; + state.serialize_entry("version", &self.version)?; + state.serialize_entry("deleted", &self.deleted)?; + if let Some(properties) = &self.properties { + for (key, value) in properties.iter() { + state.serialize_entry(key, value)?; + } + } + state.end() + } else { + // Skip id, level, distance, and data for bincode serialization + let mut state = serializer.serialize_struct("HVector", 4)?; + state.serialize_field("label", &self.label)?; + state.serialize_field("version", &self.version)?; + state.serialize_field("deleted", &self.deleted)?; + state.serialize_field("properties", &self.properties)?; + state.end() + } + } +} + impl<'arena> HVector<'arena> { pub fn data_borrowed(&self) -> &[f32] { bytemuck::cast_slice(self.data.as_ref().unwrap().vector.as_bytes()) @@ -101,10 +140,16 @@ impl<'arena> HVector<'arena> { } /// Converts HVector's data to a vec of bytes by accessing the data field directly - /// and converting each f64 to a byte slice + /// and converting each f32 to a byte slice #[inline(always)] pub fn vector_data_to_bytes(&self) -> VectorCoreResult<&[u8]> { - Ok(self.data.as_ref().unwrap().vector.as_ref().as_bytes()) + Ok(self + .data + .as_ref() + .ok_or(VectorError::HasNoData)? 
+ .vector + .as_ref() + .as_bytes()) } /// Deserializes bytes into a vector using a custom deserializer that allocates into the provided arena /// /// Both the properties bytes (if present) and the raw vector data are combined to generate the final vector struct /// /// NOTE: in this method, fixint encoding is used - #[inline] pub fn from_bincode_bytes<'txn>( arena: &'arena bumpalo::Bump, properties: Option<&'txn [u8]>, diff --git a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs index 36fe41c98..474b2755d 100644 ---
@@ impl<'arena> HVector<'arena> { id, label, version: 1, - data: Some(Item::::new(data)), + data: Some(Item::::from_vec(data)), distance: None, properties: None, deleted: false, @@ -233,12 +233,23 @@ impl<'arena> HVector<'arena> { } pub fn from_raw_vector_data<'txn>( - _arena: &'arena bumpalo::Bump, - _raw_vector_data: &'txn [u8], - _label: &'arena str, - _id: u128, - ) -> Result { - todo!() + id: u128, + label: &'arena str, + raw_vector_data: &'txn [u8], + ) -> VectorCoreResult> + where + 'arena: 'txn, + { + Ok(HVector { + id, + label, + data: Some(Item::::from_raw_slice(raw_vector_data)), + properties: None, + distance: None, + deleted: false, + level: Some(0), + version: 1, + }) } } diff --git a/helix-db/src/helix_engine/vector_core/node.rs b/helix-db/src/helix_engine/vector_core/node.rs index 21d0bc698..dc0b7880f 100644 --- a/helix-db/src/helix_engine/vector_core/node.rs +++ b/helix-db/src/helix_engine/vector_core/node.rs @@ -66,7 +66,7 @@ impl Clone for Item<'_, D> { } } -impl Item<'_, D> { +impl<'a, D: Distance> Item<'a, D> { /// Converts the item into an owned version of itself by cloning /// the internal vector. Doing so will make it mutable. pub fn into_owned(self) -> Item<'static, D> { @@ -89,8 +89,25 @@ impl Item<'_, D> { } } + /// Builds a new borrowed item from a `&[u8]` slice. + /// This function do not allocates + pub fn from_raw_slice(slice: &'a [u8]) -> Self { + let vector = UnalignedVector::from_slice(bytemuck::cast_slice(slice)); + let header = D::new_header(&vector); + Self { header, vector } + } + + /// Builds a new borrowed item from a `&[f32]` slice. + /// This function do not allocates + pub fn from_slice(slice: &'a [f32]) -> Self { + let vector = UnalignedVector::from_slice(slice); + let header = D::new_header(&vector); + Self { header, vector } + } + /// Builds a new item from a `Vec`. 
- pub fn new(vec: bumpalo::collections::Vec) -> Self { + /// This function allocates + pub fn from_vec(vec: bumpalo::collections::Vec) -> Self { let vector = UnalignedVector::from_vec(vec); let header = D::new_header(&vector); Self { header, vector } diff --git a/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs b/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs index 2fbc8ed01..303aa5e68 100644 --- a/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs +++ b/helix-db/src/helix_gateway/builtin/all_nodes_and_edges.rs @@ -98,7 +98,7 @@ pub fn nodes_edges_inner(input: HandlerInput) -> Result = vecs .iter() diff --git a/helix-db/src/protocol/custom_serde/test_utils.rs b/helix-db/src/protocol/custom_serde/test_utils.rs index a9218aaec..fc3faa550 100644 --- a/helix-db/src/protocol/custom_serde/test_utils.rs +++ b/helix-db/src/protocol/custom_serde/test_utils.rs @@ -242,7 +242,7 @@ pub fn create_arena_vector<'arena>( version, deleted, distance: None, - data: Some(Item::::new(bump_vec)), + data: Some(Item::::from_vec(bump_vec)), properties: None, level: None, } @@ -260,7 +260,7 @@ pub fn create_arena_vector<'arena>( version, deleted, distance: None, - data: Some(Item::::new(bump_vec)), + data: Some(Item::::from_vec(bump_vec)), properties: Some(props_map), level: None, } diff --git a/helix-db/src/protocol/custom_serde/vector_serde.rs b/helix-db/src/protocol/custom_serde/vector_serde.rs index 42121a958..4e7fb214b 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde.rs @@ -102,7 +102,7 @@ impl<'de, 'txn, 'arena> serde::de::DeserializeSeed<'de> for VectorDeSeed<'txn, ' deleted, version, distance: None, - data: Some(Item::::new(data)), + data: Some(Item::::from_vec(data)), properties, level: None, }) diff --git a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs index 474b2755d..5d0ad19d1 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs @@ -205,7 +205,7 @@ mod vector_serialization_tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let raw_bytes = create_vector_bytes(&data); - let vector = HVector::from_raw_vector_data(&arena, &raw_bytes, label, id).unwrap(); + let vector = HVector::from_raw_vector_data(id, label, &raw_bytes).unwrap(); assert_eq!(vector.id, id); assert_eq!(vector.label, label); From 6844f905e1ed8421684c45a4138fe24123322383 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Sat, 22 Nov 2025 11:13:32 -0300 Subject: [PATCH 37/48] Implement distance_to Only supports Cosine for now, but we definitely want to support more distance methods --- helix-db/src/helix_engine/vector_core/distance/mod.rs | 6 +++--- helix-db/src/helix_engine/vector_core/mod.rs | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/helix-db/src/helix_engine/vector_core/distance/mod.rs b/helix-db/src/helix_engine/vector_core/distance/mod.rs index 6ea16cf88..786726713 100644 --- a/helix-db/src/helix_engine/vector_core/distance/mod.rs +++ b/helix-db/src/helix_engine/vector_core/distance/mod.rs @@ -13,9 +13,9 @@ mod cosine; pub type DistanceValue = f32; -pub const MAX_DISTANCE: f64 = 2.0; -pub const ORTHOGONAL: f64 = 1.0; -pub const MIN_DISTANCE: f64 = 0.0; +pub const MAX_DISTANCE: f32 = 2.0; +pub const ORTHOGONAL: f32 = 1.0; +pub const MIN_DISTANCE: f32 = 0.0; pub trait Distance: Send + Sync + Sized + Clone + fmt::Debug + 'static { /// A header structure with informations 
related to the diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 5f716ef17..d23806910 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -20,7 +20,7 @@ use crate::{ helix_engine::{ types::VectorError, vector_core::{ - distance::Cosine, + distance::{Cosine, Distance}, key::KeyCodec, node::{Item, NodeCodec}, reader::{Reader, Searched, get_item}, @@ -198,8 +198,11 @@ impl<'arena> HVector<'arena> { bincode::serialize(self) } - pub fn distance_to(&self, _rhs: &HVector<'arena>) -> VectorCoreResult { - todo!() + pub fn distance_to(&self, rhs: &HVector<'arena>) -> VectorCoreResult { + match (self.data.as_ref(), rhs.data.as_ref()) { + (None, _) | (_, None) => Err(VectorError::HasNoData), + (Some(a), Some(b)) => Ok(Cosine::distance(a, b)), + } } pub fn set_distance(&mut self, distance: f32) { From 49dc89aafd021300b6f4b852b5a372c34a43d58d Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Sat, 22 Nov 2025 12:24:21 -0300 Subject: [PATCH 38/48] Update compiler to default f32 vector type --- .../analyzer/methods/infer_expr_type.rs | 37 +++--- .../analyzer/methods/query_validation.rs | 107 ++++++++++++------ .../helixc/analyzer/methods/schema_methods.rs | 8 +- .../analyzer/methods/traversal_validation.rs | 4 +- helix-db/src/helixc/analyzer/utils.rs | 6 +- .../src/helixc/generator/math_functions.rs | 99 ++++++++++------ helix-db/src/helixc/generator/queries.rs | 8 +- helix-db/src/helixc/generator/schemas.rs | 4 +- helix-db/src/helixc/generator/source_steps.rs | 2 +- .../src/helixc/generator/traversal_steps.rs | 43 +++++-- helix-db/src/helixc/generator/utils.rs | 2 +- .../src/helixc/parser/object_parse_methods.rs | 6 +- .../helixc/parser/traversal_parse_methods.rs | 4 +- helix-db/src/helixc/parser/types.rs | 62 +++++++--- helix-db/src/helixc/parser/utils.rs | 4 +- 15 files changed, 252 insertions(+), 144 deletions(-) diff --git a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs index 278c20be7..abb8cfd9d 100644 --- a/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs +++ b/helix-db/src/helixc/analyzer/methods/infer_expr_type.rs @@ -87,7 +87,7 @@ pub(crate) fn infer_expr_type<'a>( Some(GeneratedStatement::Literal(GenRef::Literal(i.to_string()))), ), FloatLiteral(f) => ( - Type::Scalar(FieldType::F64), + Type::Scalar(FieldType::F32), Some(GeneratedStatement::Literal(GenRef::Literal(f.to_string()))), ), StringLiteral(s) => ( @@ -616,12 +616,10 @@ pub(crate) fn infer_expr_type<'a>( Some(properties.into_iter().collect()) } - None => { - match default_properties.is_empty() { - true => None, - false => Some(default_properties), - } - } + None => match default_properties.is_empty() { + true => None, + false => Some(default_properties), + }, }; let (to, to_is_plural) = match &add.connection.to_id { @@ -773,7 +771,12 @@ pub(crate) fn infer_expr_type<'a>( } let label = GenRef::Literal(ty.clone()); - let vector_in_schema = match ctx.output.vectors.iter().find(|v| v.name == ty.as_str()) { + let vector_in_schema = match ctx + .output + .vectors + .iter() + .find(|v| v.name == ty.as_str()) + { Some(vector) => vector.clone(), None => { generate_error!(ctx, original_query, add.loc.clone(), E103, ty.as_str()); @@ -971,15 +974,13 @@ pub(crate) fn infer_expr_type<'a>( properties } - None => { - default_properties.into_iter().fold( - HashMap::new(), - |mut acc, (field_name, default_value)| { - acc.insert(field_name, default_value); - acc - 
}, - ) - } + None => default_properties.into_iter().fold( + HashMap::new(), + |mut acc, (field_name, default_value)| { + acc.insert(field_name, default_value); + acc + }, + ), }; if let Some(vec_data) = &add.data { let vec = match vec_data { @@ -1391,7 +1392,7 @@ pub(crate) fn infer_expr_type<'a>( // Math function calls always return f64 // TODO: Add proper type inference and validation for math function arguments ( - Type::Scalar(FieldType::F64), + Type::Scalar(FieldType::F32), None, // Will be handled by generator ) } diff --git a/helix-db/src/helixc/analyzer/methods/query_validation.rs b/helix-db/src/helixc/analyzer/methods/query_validation.rs index c5502a073..4c95cc817 100644 --- a/helix-db/src/helixc/analyzer/methods/query_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/query_validation.rs @@ -14,8 +14,7 @@ use crate::helixc::{ generator::{ queries::{Parameter as GeneratedParameter, Query as GeneratedQuery}, return_values::{ - ReturnFieldInfo, ReturnFieldSource, ReturnFieldType, ReturnValue, - ReturnValueStruct, + ReturnFieldInfo, ReturnFieldSource, ReturnFieldType, ReturnValue, ReturnValueStruct, }, source_steps::SourceStep, statements::Statement as GeneratedStatement, @@ -167,13 +166,13 @@ fn build_return_fields( if should_add_field("data") { fields.push(ReturnFieldInfo::new_implicit( "data".to_string(), - "&'a [f64]".to_string(), + "&'a [f32]".to_string(), )); } if should_add_field("score") { fields.push(ReturnFieldInfo::new_implicit( "score".to_string(), - "f64".to_string(), + "f32".to_string(), )); } } @@ -235,8 +234,8 @@ fn build_return_fields( if is_implicit_field { let rust_type = match *field_name { - "data" => "&'a [f64]".to_string(), - "score" => "f64".to_string(), + "data" => "&'a [f32]".to_string(), + "score" => "f32".to_string(), _ => "&'a str".to_string(), }; fields.push(ReturnFieldInfo::new_implicit( @@ -305,8 +304,8 @@ fn build_return_fields( let rust_type = if is_implicit { // Use the appropriate type based on the implicit field match accessed_field.map(|s| s.as_str()) { - Some("data") => "&'a [f64]".to_string(), - Some("score") => "f64".to_string(), + Some("data") => "&'a [f32]".to_string(), + Some("score") => "f32".to_string(), Some("id") | Some("ID") | Some("label") | Some("Label") | Some("from_node") | Some("to_node") | None => "&'a str".to_string(), _ => "Option<&'a Value>".to_string(), @@ -434,13 +433,17 @@ fn process_object_literal<'a>( // Handle traversal like app::{name} // Extract variable name from start node let var_name = match &trav.start { - crate::helixc::parser::types::StartNode::Identifier(id) => id.clone(), + crate::helixc::parser::types::StartNode::Identifier(id) => { + id.clone() + } _ => "unknown".to_string(), }; // Check if there's an Object step to extract property name if let Some(step) = trav.steps.first() { - if let crate::helixc::parser::types::StepType::Object(obj) = &step.step { + if let crate::helixc::parser::types::StepType::Object(obj) = + &step.step + { // Extract the first field name from the object step if let Some(field) = obj.fields.first() { let prop_name = &field.key; @@ -475,9 +478,7 @@ fn process_object_literal<'a>( format!("json!({})", id) } } - _ => { - "serde_json::Value::Null".to_string() - } + _ => "serde_json::Value::Null".to_string(), } } ReturnType::Object(nested_obj) => { @@ -495,7 +496,11 @@ fn process_object_literal<'a>( ExpressionType::Identifier(id) => { // Look up the variable type and generate property extraction if let Some(var_info) = scope.get(id.as_str()) { - 
array_parts.push(build_identifier_json(ctx, id, &var_info.ty)); + array_parts.push(build_identifier_json( + ctx, + id, + &var_info.ty, + )); } else { // Fallback array_parts.push(format!("json!({})", id)); @@ -504,24 +509,37 @@ fn process_object_literal<'a>( ExpressionType::Traversal(trav) => { // Handle traversal in array let var_name = match &trav.start { - crate::helixc::parser::types::StartNode::Identifier(id) => id.clone(), + crate::helixc::parser::types::StartNode::Identifier( + id, + ) => id.clone(), _ => "unknown".to_string(), }; // Check for object step if let Some(step) = trav.steps.first() { - if let crate::helixc::parser::types::StepType::Object(obj) = &step.step { + if let crate::helixc::parser::types::StepType::Object( + obj, + ) = &step.step + { if let Some(field) = obj.fields.first() { let prop_name = &field.key; if prop_name == "id" { - array_parts.push(format!("uuid_str({}.id(), &arena)", var_name)); + array_parts.push(format!( + "uuid_str({}.id(), &arena)", + var_name + )); } else if prop_name == "label" { - array_parts.push(format!("{}.label()", var_name)); + array_parts + .push(format!("{}.label()", var_name)); } else { - array_parts.push(format!("{}.get_property(\"{}\")", var_name, prop_name)); + array_parts.push(format!( + "{}.get_property(\"{}\")", + var_name, prop_name + )); } } else { - array_parts.push(format!("json!({})", var_name)); + array_parts + .push(format!("json!({})", var_name)); } } else { array_parts.push(format!("json!({})", var_name)); @@ -571,13 +589,19 @@ fn process_object_literal<'a>( if *prop_name == "id" || *prop_name == "label" { continue; } - props.push(format!("\"{}\": {}.get_property(\"{}\")", prop_name, var_name, prop_name)); + props.push(format!( + "\"{}\": {}.get_property(\"{}\")", + prop_name, var_name, prop_name + )); } format!("json!({{\n {}\n }})", props.join(",\n ")) } else { // Fallback if schema not found - format!("json!({{\"id\": uuid_str({}.id(), &arena), \"label\": {}.label()}})", var_name, var_name) + format!( + "json!({{\"id\": uuid_str({}.id(), &arena), \"label\": {}.label()}})", + var_name, var_name + ) } } Type::Edge(Some(label)) => { @@ -593,12 +617,18 @@ fn process_object_literal<'a>( if *prop_name == "id" || *prop_name == "label" { continue; } - props.push(format!("\"{}\": {}.get_property(\"{}\")", prop_name, var_name, prop_name)); + props.push(format!( + "\"{}\": {}.get_property(\"{}\")", + prop_name, var_name, prop_name + )); } format!("json!({{\n {}\n }})", props.join(",\n ")) } else { - format!("json!({{\"id\": uuid_str({}.id(), &arena), \"label\": {}.label()}})", var_name, var_name) + format!( + "json!({{\"id\": uuid_str({}.id(), &arena), \"label\": {}.label()}})", + var_name, var_name + ) } } _ => { @@ -616,7 +646,10 @@ fn process_object_literal<'a>( ReturnValue { name: "serde_json::Value".to_string(), fields: vec![], - literal_value: Some(crate::helixc::generator::utils::GenRef::Std(format!("json!({})", json_code))), + literal_value: Some(crate::helixc::generator::utils::GenRef::Std(format!( + "json!({})", + json_code + ))), }, )); @@ -939,11 +972,20 @@ fn analyze_return_expr<'a>( ShouldCollect::ToVec => { // Collection - generate iteration code let iter_code = if property_name == "id" { - format!("{}.iter().map(|item| uuid_str(item.id(), &arena)).collect::>()", field_name) + format!( + "{}.iter().map(|item| uuid_str(item.id(), &arena)).collect::>()", + field_name + ) } else if property_name == "label" { - format!("{}.iter().map(|item| item.label()).collect::>()", field_name) + format!( + "{}.iter().map(|item| 
item.label()).collect::>()", + field_name + ) } else { - format!("{}.iter().map(|item| item.get_property(\"{}\")).collect::>()", field_name, property_name) + format!( + "{}.iter().map(|item| item.get_property(\"{}\")).collect::>()", + field_name, property_name + ) }; Some(GenRef::Std(iter_code)) } @@ -1233,14 +1275,7 @@ fn analyze_return_expr<'a>( } else { // Complex nested object - use new object literal processing let struct_name = format!("{}ReturnType", capitalize_first(&query.name)); - process_object_literal( - ctx, - original_query, - scope, - query, - values, - struct_name, - ); + process_object_literal(ctx, original_query, scope, query, values, struct_name); // Note: process_object_literal adds to query.return_values // and sets use_struct_returns = false, so no need to push to return_structs diff --git a/helix-db/src/helixc/analyzer/methods/schema_methods.rs b/helix-db/src/helixc/analyzer/methods/schema_methods.rs index b77bab181..da1417b09 100644 --- a/helix-db/src/helixc/analyzer/methods/schema_methods.rs +++ b/helix-db/src/helixc/analyzer/methods/schema_methods.rs @@ -158,7 +158,7 @@ pub(crate) fn build_field_lookups<'a>(src: &'a Source) -> SchemaVersionMap<'a> { prefix: FieldPrefix::Empty, defaults: None, name: "data".to_string(), - field_type: FieldType::Array(Box::new(FieldType::F64)), + field_type: FieldType::Array(Box::new(FieldType::F32)), loc: Loc::empty(), }), ); @@ -168,7 +168,7 @@ pub(crate) fn build_field_lookups<'a>(src: &'a Source) -> SchemaVersionMap<'a> { prefix: FieldPrefix::Empty, defaults: None, name: "score".to_string(), - field_type: FieldType::F64, + field_type: FieldType::F32, loc: Loc::empty(), }), ); @@ -685,7 +685,7 @@ mod tests { N::Person { name: String, age: U32, - score: F64, + score: F32, active: Boolean, user_id: ID, created_at: Date @@ -711,7 +711,7 @@ mod tests { let source = r#" N::Person { tags: [String], - scores: [F64], + scores: [F32], ids: [ID] } diff --git a/helix-db/src/helixc/analyzer/methods/traversal_validation.rs b/helix-db/src/helixc/analyzer/methods/traversal_validation.rs index 614cf2560..c0f943122 100644 --- a/helix-db/src/helixc/analyzer/methods/traversal_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/traversal_validation.rs @@ -77,7 +77,7 @@ fn get_reserved_property_type(prop_name: &str, item_type: &Type) -> Option { // Only valid for vectors match item_type { - Type::Vector(_) | Type::Vectors(_) => Some(FieldType::F64), + Type::Vector(_) | Type::Vectors(_) => Some(FieldType::F32), _ => None, } } @@ -85,7 +85,7 @@ fn get_reserved_property_type(prop_name: &str, item_type: &Type) -> Option { - Some(FieldType::Array(Box::new(FieldType::F64))) + Some(FieldType::Array(Box::new(FieldType::F32))) } _ => None, } diff --git a/helix-db/src/helixc/analyzer/utils.rs b/helix-db/src/helixc/analyzer/utils.rs index 4a66c66a3..44810d418 100644 --- a/helix-db/src/helixc/analyzer/utils.rs +++ b/helix-db/src/helixc/analyzer/utils.rs @@ -5,7 +5,7 @@ use crate::{ helixc::{ analyzer::{Ctx, errors::push_query_err, types::Type}, generator::{ - traversal_steps::{Step, ReservedProp}, + traversal_steps::{ReservedProp, Step}, utils::{GenRef, GeneratedValue}, }, parser::{location::Loc, types::*}, @@ -377,8 +377,8 @@ impl FieldLookup for Type { .map(|fields| match key { "id" | "ID" => Some(FieldType::Uuid), "label" => Some(FieldType::String), - "data" => Some(FieldType::Array(Box::new(FieldType::F64))), - "score" => Some(FieldType::F64), + "data" => Some(FieldType::Array(Box::new(FieldType::F32))), + "score" => Some(FieldType::F32), _ => fields 
.get(key) .map(|field| Some(field.field_type.clone())) diff --git a/helix-db/src/helixc/generator/math_functions.rs b/helix-db/src/helixc/generator/math_functions.rs index c1e086d6b..3aade7204 100644 --- a/helix-db/src/helixc/generator/math_functions.rs +++ b/helix-db/src/helixc/generator/math_functions.rs @@ -31,7 +31,7 @@ pub struct MathFunctionCallGen { #[derive(Debug, Clone)] pub struct NumericLiteral { - pub value: f64, + pub value: f32, } #[derive(Debug, Clone)] @@ -54,10 +54,10 @@ impl Display for MathExpr { impl Display for NumericLiteral { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Handle special formatting for cleaner output - if self.value.fract() == 0.0 && self.value.abs() < i64::MAX as f64 { - write!(f, "{}_f64", self.value as i64) + if self.value.fract() == 0.0 && self.value.abs() < i64::MAX as f32 { + write!(f, "{}_f32", self.value as i64) } else { - write!(f, "{}_f64", self.value) + write!(f, "{}_f32", self.value) } } } @@ -66,16 +66,32 @@ impl Display for PropertyAccess { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.context { PropertyContext::Edge => { - write!(f, "(edge.get_property({}).ok_or(GraphError::Default)?.as_f64())", self.property) + write!( + f, + "(edge.get_property({}).ok_or(GraphError::Default)?.as_f32())", + self.property + ) } PropertyContext::SourceNode => { - write!(f, "(src_node.get_property({}).ok_or(GraphError::Default)?.as_f64())", self.property) + write!( + f, + "(src_node.get_property({}).ok_or(GraphError::Default)?.as_f32())", + self.property + ) } PropertyContext::TargetNode => { - write!(f, "(dst_node.get_property({}).ok_or(GraphError::Default)?.as_f64())", self.property) + write!( + f, + "(dst_node.get_property({}).ok_or(GraphError::Default)?.as_f32())", + self.property + ) } PropertyContext::Current => { - write!(f, "(v.get_property({}).ok_or(GraphError::Default)?.as_f64())", self.property) + write!( + f, + "(v.get_property({}).ok_or(GraphError::Default)?.as_f32())", + self.property + ) } } } @@ -223,8 +239,8 @@ impl Display for MathFunctionCallGen { } // Constants (nullary) - MathFunction::Pi => write!(f, "std::f64::consts::PI"), - MathFunction::E => write!(f, "std::f64::consts::E"), + MathFunction::Pi => write!(f, "std::f32::consts::PI"), + MathFunction::E => write!(f, "std::f32::consts::E"), // Aggregates (special handling needed) MathFunction::Min @@ -270,7 +286,7 @@ pub fn generate_math_expr( })) } ExpressionType::IntegerLiteral(i) => Ok(MathExpr::NumericLiteral(NumericLiteral { - value: *i as f64, + value: *i as f32, })), ExpressionType::FloatLiteral(f) => { Ok(MathExpr::NumericLiteral(NumericLiteral { value: *f })) @@ -307,22 +323,35 @@ fn parse_property_access_from_traversal( } else if traversal.steps.len() == 2 { // Check if first step is FromN or ToN match &traversal.steps[0].step { - StepType::Node(graph_step) => { - match &graph_step.step { - GraphStepType::FromN => (PropertyContext::SourceNode, 1), - GraphStepType::ToN => (PropertyContext::TargetNode, 1), - _ => return Err(format!("Unexpected node step type in property access: {:?}", graph_step.step)), + StepType::Node(graph_step) => match &graph_step.step { + GraphStepType::FromN => (PropertyContext::SourceNode, 1), + GraphStepType::ToN => (PropertyContext::TargetNode, 1), + _ => { + return Err(format!( + "Unexpected node step type in property access: {:?}", + graph_step.step + )); } + }, + _ => { + return Err(format!( + "Expected FromN or ToN step, got: {:?}", + traversal.steps[0].step + )); } - _ => return Err(format!("Expected FromN 
or ToN step, got: {:?}", traversal.steps[0].step)), } } else { - return Err(format!("Invalid traversal length for property access: {}", traversal.steps.len())); + return Err(format!( + "Invalid traversal length for property access: {}", + traversal.steps.len() + )); }; // Extract property name from the Object step if let StepType::Object(obj) = &traversal.steps[property_step_idx].step - && obj.fields.len() == 1 && !obj.should_spread { + && obj.fields.len() == 1 + && !obj.should_spread + { let property_name = obj.fields[0].key.clone(); // Override context if specified by ExpressionContext @@ -347,13 +376,13 @@ mod tests { #[test] fn test_numeric_literal_integer() { let lit = NumericLiteral { value: 5.0 }; - assert_eq!(lit.to_string(), "5_f64"); + assert_eq!(lit.to_string(), "5_f32"); } #[test] fn test_numeric_literal_float() { let lit = NumericLiteral { value: 3.14 }; - assert_eq!(lit.to_string(), "3.14_f64"); + assert_eq!(lit.to_string(), "3.14_f32"); } #[test] @@ -365,7 +394,7 @@ mod tests { MathExpr::NumericLiteral(NumericLiteral { value: 3.0 }), ], }; - assert_eq!(add.to_string(), "(5_f64 + 3_f64)"); + assert_eq!(add.to_string(), "(5_f32 + 3_f32)"); } #[test] @@ -377,7 +406,7 @@ mod tests { MathExpr::NumericLiteral(NumericLiteral { value: 30.0 }), ], }; - assert_eq!(pow.to_string(), "(0.95_f64).powf(30_f64)"); + assert_eq!(pow.to_string(), "(0.95_f32).powf(30_f32)"); } #[test] @@ -395,7 +424,7 @@ mod tests { }), ], }; - assert_eq!(nested.to_string(), "(0.95_f64).powf((10_f64 / 30_f64))"); + assert_eq!(nested.to_string(), "(0.95_f32).powf((10_f32 / 30_f32))"); } #[test] @@ -404,7 +433,7 @@ mod tests { function: MathFunction::Sqrt, args: vec![MathExpr::NumericLiteral(NumericLiteral { value: 16.0 })], }; - assert_eq!(sqrt.to_string(), "(16_f64).sqrt()"); + assert_eq!(sqrt.to_string(), "(16_f32).sqrt()"); } #[test] @@ -413,7 +442,7 @@ mod tests { function: MathFunction::Sin, args: vec![MathExpr::NumericLiteral(NumericLiteral { value: 1.57 })], }; - assert_eq!(sin.to_string(), "(1.57_f64).sin()"); + assert_eq!(sin.to_string(), "(1.57_f32).sin()"); } #[test] @@ -422,13 +451,13 @@ mod tests { function: MathFunction::Pi, args: vec![], }; - assert_eq!(pi.to_string(), "std::f64::consts::PI"); + assert_eq!(pi.to_string(), "std::f32::consts::PI"); let e = MathFunctionCallGen { function: MathFunction::E, args: vec![], }; - assert_eq!(e.to_string(), "std::f64::consts::E"); + assert_eq!(e.to_string(), "std::f32::consts::E"); } #[test] @@ -440,7 +469,7 @@ mod tests { }; assert_eq!( edge_prop.to_string(), - "(edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64())" + "(edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32())" ); // Test SourceNode context @@ -450,7 +479,7 @@ mod tests { }; assert_eq!( src_prop.to_string(), - "(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f64())" + "(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f32())" ); // Test TargetNode context @@ -460,14 +489,14 @@ mod tests { }; assert_eq!( dst_prop.to_string(), - "(dst_node.get_property(\"popularity\").ok_or(GraphError::Default)?.as_f64())" + "(dst_node.get_property(\"popularity\").ok_or(GraphError::Default)?.as_f32())" ); } #[test] fn test_complex_weight_expression() { // Test: MUL(_::{distance}, POW(0.95, DIV(_::{days}, 30))) - // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f64()) * (0.95_f64).powf(((edge.get_property("days").ok_or(GraphError::Default)?.as_f64()) / 30_f64))) + // Should generate: 
((edge.get_property("distance").ok_or(GraphError::Default)?.as_f32()) * (0.95_f32).powf(((edge.get_property("days").ok_or(GraphError::Default)?.as_f32()) / 30_f32))) let expr = MathFunctionCallGen { function: MathFunction::Mul, args: vec![ @@ -496,14 +525,14 @@ mod tests { assert_eq!( expr.to_string(), - "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64()) * (0.95_f64).powf(((edge.get_property(\"days\").ok_or(GraphError::Default)?.as_f64()) / 30_f64)))" + "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32()) * (0.95_f32).powf(((edge.get_property(\"days\").ok_or(GraphError::Default)?.as_f32()) / 30_f32)))" ); } #[test] fn test_multi_context_expression() { // Test: MUL(_::{distance}, _::From::{traffic_factor}) - // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f64()) * (src_node.get_property("traffic_factor").ok_or(GraphError::Default)?.as_f64())) + // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f32()) * (src_node.get_property("traffic_factor").ok_or(GraphError::Default)?.as_f32())) let expr = MathFunctionCallGen { function: MathFunction::Mul, args: vec![ @@ -520,7 +549,7 @@ mod tests { assert_eq!( expr.to_string(), - "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64()) * (src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f64()))" + "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32()) * (src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f32()))" ); } } diff --git a/helix-db/src/helixc/generator/queries.rs b/helix-db/src/helixc/generator/queries.rs index 7f856830e..8c30a1dc4 100644 --- a/helix-db/src/helixc/generator/queries.rs +++ b/helix-db/src/helixc/generator/queries.rs @@ -86,7 +86,7 @@ impl Query { for (i, _) in self.hoisted_embedding_calls.iter().enumerate() { let name = EmbedData::name_from_index(i); - writeln!(f, "let {name}: Vec = {name}?;")?; + writeln!(f, "let {name}: Vec = {name}?;")?; } } Ok(()) @@ -168,8 +168,7 @@ impl Query { writeln!( f, " \"{}\": {}", - struct_def.source_variable, - struct_def.source_variable + struct_def.source_variable, struct_def.source_variable )?; } else if struct_def.source_variable.is_empty() { // Object literal - construct from multiple sources @@ -891,8 +890,7 @@ impl Query { writeln!( f, " \"{}\": {}", - struct_def.source_variable, - struct_def.source_variable + struct_def.source_variable, struct_def.source_variable )?; } else if struct_def.is_collection { // Collection - generate mapping code diff --git a/helix-db/src/helixc/generator/schemas.rs b/helix-db/src/helixc/generator/schemas.rs index a6c72059e..7e5ce7665 100644 --- a/helix-db/src/helixc/generator/schemas.rs +++ b/helix-db/src/helixc/generator/schemas.rs @@ -218,7 +218,7 @@ mod tests { }, SchemaProperty { name: "score".to_string(), - field_type: GeneratedType::RustType(RustType::F64), + field_type: GeneratedType::RustType(RustType::F32), default_value: None, is_index: FieldPrefix::Empty, }, @@ -227,7 +227,7 @@ mod tests { let output = format!("{}", schema); assert!(output.contains("pub count: i32,")); - assert!(output.contains("pub score: f64,")); + assert!(output.contains("pub score: f32,")); } // ============================================================================ diff --git a/helix-db/src/helixc/generator/source_steps.rs b/helix-db/src/helixc/generator/source_steps.rs index 07ce8e951..859eebfda 100644 --- a/helix-db/src/helixc/generator/source_steps.rs +++ 
b/helix-db/src/helixc/generator/source_steps.rs @@ -141,7 +141,7 @@ impl Display for AddV { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "insert_v:: bool>({}, {}, {})", + "insert_v({}, {}, {})", self.vec, self.label, write_properties(&self.properties) diff --git a/helix-db/src/helixc/generator/traversal_steps.rs b/helix-db/src/helixc/generator/traversal_steps.rs index 3fa61f401..356c3f513 100644 --- a/helix-db/src/helixc/generator/traversal_steps.rs +++ b/helix-db/src/helixc/generator/traversal_steps.rs @@ -266,7 +266,9 @@ impl Display for Step { Step::ToV(to_v) => write!(f, "{to_v}"), Step::PropertyFetch(property) => write!(f, "get_property({property})"), Step::ReservedPropertyAccess(prop) => match prop { - ReservedProp::Id => write!(f, "map(|item| Ok(Value::from(uuid_str(item.id, &arena))))"), + ReservedProp::Id => { + write!(f, "map(|item| Ok(Value::from(uuid_str(item.id, &arena))))") + } ReservedProp::Label => write!(f, "map(|item| Ok(Value::from(item.label())))"), // ReservedProp::Version => write!(f, "map(|item| Ok(Value::from(item.version)))"), // ReservedProp::FromNode => write!(f, "map(|item| Ok(Value::from(uuid_str(item.from_node, &arena))))"), @@ -453,7 +455,9 @@ impl Display for WhereRef { | Separator::Empty(Step::PropertyFetch(p)) => prop = Some(p), Separator::Period(Step::ReservedPropertyAccess(rp)) | Separator::Newline(Step::ReservedPropertyAccess(rp)) - | Separator::Empty(Step::ReservedPropertyAccess(rp)) => reserved_prop = Some(rp), + | Separator::Empty(Step::ReservedPropertyAccess(rp)) => { + reserved_prop = Some(rp) + } Separator::Period(Step::BoolOp(op)) | Separator::Newline(Step::BoolOp(op)) | Separator::Empty(Step::BoolOp(op)) => bool_op = Some(op), @@ -720,19 +724,22 @@ impl Display for ShortestPathDijkstras { WeightCalculation::Property(prop) => { write!( f, - "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f64()?) }}", + "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f32()?) }}", prop )?; } WeightCalculation::Expression(expr) => { write!( f, - "|edge, src_node, dst_node| -> Result {{ Ok({}) }}", + "|edge, src_node, dst_node| -> Result {{ Ok({}) }}", expr )?; } WeightCalculation::Default => { - write!(f, "helix_db::helix_engine::traversal_core::ops::util::paths::default_weight_fn")?; + write!( + f, + "helix_db::helix_engine::traversal_core::ops::util::paths::default_weight_fn" + )?; } } @@ -779,19 +786,22 @@ impl Display for ShortestPathAStar { WeightCalculation::Property(prop) => { write!( f, - "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f64()?) }}, ", + "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f32()?) 
}}, ", prop )?; } WeightCalculation::Expression(expr) => { write!( f, - "|edge, src_node, dst_node| -> Result {{ Ok({}) }}, ", + "|edge, src_node, dst_node| -> Result {{ Ok({}) }}, ", expr )?; } WeightCalculation::Default => { - write!(f, "helix_db::helix_engine::traversal_core::ops::util::paths::default_weight_fn, ")?; + write!( + f, + "helix_db::helix_engine::traversal_core::ops::util::paths::default_weight_fn, " + )?; } } @@ -824,7 +834,7 @@ pub struct RerankRRF { impl Display for RerankRRF { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.k { - Some(k) => write!(f, "rerank(RRFReranker::with_k({k} as f64).unwrap(), None)"), + Some(k) => write!(f, "rerank(RRFReranker::with_k({k} as f32).unwrap(), None)"), None => write!(f, "rerank(RRFReranker::new(), None)"), } } @@ -843,7 +853,10 @@ impl Display for MMRDistanceMethod { MMRDistanceMethod::Cosine => write!(f, "DistanceMethod::Cosine"), MMRDistanceMethod::Euclidean => write!(f, "DistanceMethod::Euclidean"), MMRDistanceMethod::DotProduct => write!(f, "DistanceMethod::DotProduct"), - MMRDistanceMethod::Identifier(id) => write!(f, "match {id}.as_str() {{ \"cosine\" => DistanceMethod::Cosine, \"euclidean\" => DistanceMethod::Euclidean, \"dotproduct\" => DistanceMethod::DotProduct, _ => DistanceMethod::Cosine }}"), + MMRDistanceMethod::Identifier(id) => write!( + f, + "match {id}.as_str() {{ \"cosine\" => DistanceMethod::Cosine, \"euclidean\" => DistanceMethod::Euclidean, \"dotproduct\" => DistanceMethod::DotProduct, _ => DistanceMethod::Cosine }}" + ), } } } @@ -855,9 +868,15 @@ pub struct RerankMMR { } impl Display for RerankMMR { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let lambda = self.lambda.as_ref().map_or_else(|| "0.7".to_string(), |l| l.to_string()); + let lambda = self + .lambda + .as_ref() + .map_or_else(|| "0.7".to_string(), |l| l.to_string()); match &self.distance { - Some(dist) => write!(f, "rerank(MMRReranker::with_distance({lambda}, {dist}).unwrap(), None)"), + Some(dist) => write!( + f, + "rerank(MMRReranker::with_distance({lambda}, {dist}).unwrap(), None)" + ), None => write!(f, "rerank(MMRReranker::new({lambda}).unwrap(), None)"), } } diff --git a/helix-db/src/helixc/generator/utils.rs b/helix-db/src/helixc/generator/utils.rs index 0710649fe..6084f0aa4 100644 --- a/helix-db/src/helixc/generator/utils.rs +++ b/helix-db/src/helixc/generator/utils.rs @@ -467,7 +467,7 @@ use helix_db::{ traversal_value::TraversalValue, }, types::GraphError, - vector_core::vector::HVector, + vector_core::HVector, }, helix_gateway::{ embedding_providers::{EmbeddingModel, get_embedding_model}, diff --git a/helix-db/src/helixc/parser/object_parse_methods.rs b/helix-db/src/helixc/parser/object_parse_methods.rs index 297b1e743..705c427b6 100644 --- a/helix-db/src/helixc/parser/object_parse_methods.rs +++ b/helix-db/src/helixc/parser/object_parse_methods.rs @@ -36,7 +36,7 @@ impl HelixParser { Rule::float => value_pair .as_str() .parse() - .map(|f| ValueType::new(Value::F64(f), value_pair.loc())) + .map(|f| ValueType::new(Value::F32(f), value_pair.loc())) .map_err(|_| ParserError::from("Invalid float value")), Rule::boolean => Ok(ValueType::new( Value::Boolean(value_pair.as_str() == "true"), @@ -101,7 +101,7 @@ impl HelixParser { }, Rule::float => FieldValue { loc: value_pair.loc(), - value: FieldValueType::Literal(Value::F64( + value: FieldValueType::Literal(Value::F32( value_pair .as_str() .parse() @@ -193,7 +193,7 @@ impl HelixParser { }, Rule::float => FieldValue { loc: value_pair.loc(), - value: 
FieldValueType::Literal(Value::F64( + value: FieldValueType::Literal(Value::F32( value_pair .as_str() .parse() diff --git a/helix-db/src/helixc/parser/traversal_parse_methods.rs b/helix-db/src/helixc/parser/traversal_parse_methods.rs index 171d399e2..04a59d4e6 100644 --- a/helix-db/src/helixc/parser/traversal_parse_methods.rs +++ b/helix-db/src/helixc/parser/traversal_parse_methods.rs @@ -117,7 +117,7 @@ impl HelixParser { }, Rule::float => ValueType::Literal { value: Value::from( - val.as_str().parse::().map_err(|_| { + val.as_str().parse::().map_err(|_| { ParserError::from("Invalid float value") })?, ), @@ -264,7 +264,7 @@ impl HelixParser { }, Rule::float => ValueType::Literal { value: Value::from( - value_inner.as_str().parse::().map_err(|_| { + value_inner.as_str().parse::().map_err(|_| { ParserError::from("Invalid float value") })?, ), diff --git a/helix-db/src/helixc/parser/types.rs b/helix-db/src/helixc/parser/types.rs index 8125bb8c0..04b8bcf64 100644 --- a/helix-db/src/helixc/parser/types.rs +++ b/helix-db/src/helixc/parser/types.rs @@ -1,5 +1,8 @@ use super::location::Loc; -use crate::{helixc::parser::{errors::ParserError, HelixParser}, protocol::value::Value}; +use crate::{ + helixc::parser::{HelixParser, errors::ParserError}, + protocol::value::Value, +}; use chrono::{DateTime, NaiveDate, Utc}; use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -477,7 +480,7 @@ pub enum MathFunction { Sqrt, Ln, Log10, - Log, // Binary: LOG(x, base) + Log, // Binary: LOG(x, base) Exp, Ceil, Floor, @@ -490,7 +493,7 @@ pub enum MathFunction { Asin, Acos, Atan, - Atan2, // Binary: ATAN2(y, x) + Atan2, // Binary: ATAN2(y, x) // Constants (nullary) Pi, @@ -509,16 +512,33 @@ impl MathFunction { pub fn arity(&self) -> usize { match self { MathFunction::Pi | MathFunction::E => 0, - MathFunction::Abs | MathFunction::Sqrt | MathFunction::Ln | - MathFunction::Log10 | MathFunction::Exp | MathFunction::Ceil | - MathFunction::Floor | MathFunction::Round | MathFunction::Sin | - MathFunction::Cos | MathFunction::Tan | MathFunction::Asin | - MathFunction::Acos | MathFunction::Atan | MathFunction::Min | - MathFunction::Max | MathFunction::Sum | MathFunction::Avg | - MathFunction::Count => 1, - MathFunction::Add | MathFunction::Sub | MathFunction::Mul | - MathFunction::Div | MathFunction::Pow | MathFunction::Mod | - MathFunction::Atan2 | MathFunction::Log => 2, + MathFunction::Abs + | MathFunction::Sqrt + | MathFunction::Ln + | MathFunction::Log10 + | MathFunction::Exp + | MathFunction::Ceil + | MathFunction::Floor + | MathFunction::Round + | MathFunction::Sin + | MathFunction::Cos + | MathFunction::Tan + | MathFunction::Asin + | MathFunction::Acos + | MathFunction::Atan + | MathFunction::Min + | MathFunction::Max + | MathFunction::Sum + | MathFunction::Avg + | MathFunction::Count => 1, + MathFunction::Add + | MathFunction::Sub + | MathFunction::Mul + | MathFunction::Div + | MathFunction::Pow + | MathFunction::Mod + | MathFunction::Atan2 + | MathFunction::Log => 2, } } @@ -572,7 +592,7 @@ pub enum ExpressionType { Identifier(String), StringLiteral(String), IntegerLiteral(i32), - FloatLiteral(f64), + FloatLiteral(f32), BooleanLiteral(bool), ArrayLiteral(Vec), Exists(ExistsExpression), @@ -639,7 +659,9 @@ impl Display for ExpressionType { ExpressionType::Or(exprs) => write!(f, "Or({exprs:?})"), ExpressionType::SearchVector(sv) => write!(f, "SearchVector({sv:?})"), ExpressionType::BM25Search(bm25) => write!(f, "BM25Search({bm25:?})"), - ExpressionType::MathFunctionCall(mfc) => write!(f, "{}({:?})", 
mfc.function.name(), mfc.args), + ExpressionType::MathFunctionCall(mfc) => { + write!(f, "{}({:?})", mfc.function.name(), mfc.args) + } ExpressionType::Empty => write!(f, "Empty"), } } @@ -701,13 +723,13 @@ pub struct OrderBy { #[derive(Debug, Clone)] pub struct Aggregate { pub loc: Loc, - pub properties: Vec + pub properties: Vec, } #[derive(Debug, Clone)] pub struct GroupBy { pub loc: Loc, - pub properties: Vec + pub properties: Vec, } #[derive(Debug, Clone)] @@ -911,7 +933,7 @@ pub enum BooleanOpType { #[derive(Debug, Clone)] pub enum VectorData { - Vector(Vec), + Vector(Vec), Identifier(String), Embed(Embed), } @@ -1069,6 +1091,10 @@ impl From for ValueType { value: Value::I32(i), loc: Loc::empty(), }, + Value::F32(f) => ValueType::Literal { + value: Value::F32(f), + loc: Loc::empty(), + }, Value::F64(f) => ValueType::Literal { value: Value::F64(f), loc: Loc::empty(), diff --git a/helix-db/src/helixc/parser/utils.rs b/helix-db/src/helixc/parser/utils.rs index 49d3e8f29..b34b29464 100644 --- a/helix-db/src/helixc/parser/utils.rs +++ b/helix-db/src/helixc/parser/utils.rs @@ -26,13 +26,13 @@ impl HelixParser { ))), } } - pub(super) fn parse_vec_literal(&self, pair: Pair) -> Result, ParserError> { + pub(super) fn parse_vec_literal(&self, pair: Pair) -> Result, ParserError> { let pairs = pair.into_inner(); let mut vec = Vec::new(); for p in pairs { vec.push( p.as_str() - .parse::() + .parse::() .map_err(|_| ParserError::from("Invalid float value"))?, ); } From 96998ad8fdc091b9d72d2a0bb30a08fb9022d689 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Sat, 22 Nov 2025 13:18:21 -0300 Subject: [PATCH 39/48] Cargo fmt --- helix-cli/src/commands/check.rs | 2 - helix-cli/src/commands/metrics.rs | 3 +- helix-cli/src/errors.rs | 19 +- helix-cli/src/main.rs | 4 +- helix-cli/src/metrics_sender.rs | 2 +- helix-cli/src/project.rs | 10 +- helix-cli/src/tests/check_tests.rs | 27 +-- helix-cli/src/tests/compile_tests.rs | 43 ++-- helix-cli/src/tests/init_tests.rs | 5 +- helix-cli/src/tests/mod.rs | 4 +- helix-cli/src/tests/project_tests.rs | 17 +- helix-container/src/main.rs | 2 +- helix-db/benches/bm25_benches.rs | 3 +- helix-db/src/helix_engine/bm25/bm25.rs | 2 +- helix-db/src/helix_engine/bm25/mod.rs | 2 +- helix-db/src/helix_engine/mod.rs | 2 +- .../src/helix_engine/reranker/fusion/mod.rs | 2 +- .../reranker/fusion/score_normalizer.rs | 13 +- .../storage_core/storage_concurrent_tests.rs | 48 +++-- .../storage_core/storage_methods.rs | 6 +- .../concurrency_tests/hnsw_loom_tests.rs | 11 +- .../integration_stress_tests.rs | 6 +- .../tests/concurrency_tests/mod.rs | 2 +- .../traversal_concurrent_tests.rs | 162 ++++++++++----- helix-db/src/helix_engine/tests/mod.rs | 2 +- .../src/helix_engine/tests/storage_tests.rs | 4 +- .../tests/traversal_tests/count_tests.rs | 94 ++++----- .../tests/traversal_tests/filter_tests.rs | 63 ++++-- .../traversal_tests/node_traversal_tests.rs | 122 +++++++---- .../tests/traversal_tests/range_tests.rs | 56 ++--- .../traversal_tests/secondary_index_tests.rs | 30 ++- .../tests/traversal_tests/test_utils.rs | 5 +- .../tests/traversal_tests/update_tests.rs | 29 ++- .../src/helix_engine/traversal_core/config.rs | 2 +- .../ops/bm25/hybrid_search_bm25.rs | 1 - .../traversal_core/ops/bm25/mod.rs | 3 +- .../traversal_core/ops/in_/mod.rs | 2 +- .../traversal_core/ops/source/mod.rs | 2 +- .../traversal_core/ops/source/n_from_index.rs | 18 +- .../traversal_core/ops/source/n_from_type.rs | 8 +- .../traversal_core/ops/util/paths.rs | 25 ++- .../traversal_core/traversal_iter.rs | 10 +- 
.../vector_core/spaces/simple_avx.rs | 182 +++++++++-------- .../vector_core/spaces/simple_neon.rs | 191 +++++++++--------- .../vector_core/spaces/simple_sse.rs | 176 ++++++++-------- .../src/helix_gateway/builtin/node_by_id.rs | 32 ++- .../helix_gateway/builtin/node_connections.rs | 48 +++-- .../helix_gateway/builtin/nodes_by_label.rs | 37 ++-- .../src/helix_gateway/introspect_schema.rs | 1 - helix-db/src/helix_gateway/router/router.rs | 2 - .../src/helix_gateway/tests/gateway_tests.rs | 4 +- helix-db/src/helix_gateway/tests/mod.rs | 2 +- .../tests/worker_pool_concurrency_tests.rs | 39 ++-- helix-db/src/helixc/analyzer/diagnostic.rs | 6 +- helix-db/src/helixc/analyzer/error_codes.rs | 3 +- .../analyzer/methods/exclude_validation.rs | 12 +- .../analyzer/methods/graph_step_validation.rs | 68 +++---- .../analyzer/methods/migration_validation.rs | 119 +++++++---- .../analyzer/methods/object_validation.rs | 16 +- .../analyzer/methods/statement_validation.rs | 85 ++++++-- helix-db/src/helixc/analyzer/mod.rs | 8 +- helix-db/src/helixc/analyzer/types.rs | 14 +- helix-db/src/helixc/generator/migrations.rs | 8 +- helix-db/src/helixc/generator/statements.rs | 18 +- .../parser/creation_step_parse_methods.rs | 2 +- helix-db/src/helixc/parser/errors.rs | 2 +- .../helixc/parser/expression_parse_methods.rs | 16 +- .../helixc/parser/graph_step_parse_methods.rs | 93 +++++---- .../src/helixc/parser/query_parse_methods.rs | 43 ++-- .../parser/return_value_parse_methods.rs | 3 +- .../src/helixc/parser/schema_parse_methods.rs | 32 ++- helix-db/src/lib.rs | 2 +- .../src/protocol/custom_serde/edge_serde.rs | 2 +- .../custom_serde/error_handling_tests.rs | 2 +- .../src/protocol/custom_serde/node_serde.rs | 8 +- helix-db/src/protocol/custom_serde/tests.rs | 175 +++++++++++----- helix-db/src/protocol/format.rs | 3 +- helix-db/src/protocol/mod.rs | 2 +- helix-db/src/protocol/request.rs | 5 +- helix-db/src/utils/id.rs | 18 +- helix-db/src/utils/items.rs | 12 +- helix-db/src/utils/label_hash.rs | 56 +++-- helix-db/src/utils/tqdm.rs | 5 +- helix-macros/src/lib.rs | 23 +-- hql-tests/src/main.rs | 11 +- metrics/src/events.rs | 2 +- metrics/src/lib.rs | 2 - 87 files changed, 1426 insertions(+), 1037 deletions(-) diff --git a/helix-cli/src/commands/check.rs b/helix-cli/src/commands/check.rs index 217209c0e..f8f7fbb2e 100644 --- a/helix-cli/src/commands/check.rs +++ b/helix-cli/src/commands/check.rs @@ -46,8 +46,6 @@ async fn check_all_instances(project: &ProjectContext) -> Result<()> { Ok(()) } - - /// Validate project syntax by parsing queries and schema (similar to build.rs but without generating files) fn validate_project_syntax(project: &ProjectContext) -> Result<()> { print_status("VALIDATE", "Parsing and validating Helix queries"); diff --git a/helix-cli/src/commands/metrics.rs b/helix-cli/src/commands/metrics.rs index 51dbf68df..4413b9e98 100644 --- a/helix-cli/src/commands/metrics.rs +++ b/helix-cli/src/commands/metrics.rs @@ -90,7 +90,8 @@ async fn show_metrics_status() -> Result<()> { Ok(()) } -static EMAIL_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap()); +static EMAIL_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap()); fn ask_for_email() -> String { print_line("Please enter your email address:"); diff --git a/helix-cli/src/errors.rs b/helix-cli/src/errors.rs index 38d0aaeb5..6b5879f3a 100644 --- a/helix-cli/src/errors.rs +++ b/helix-cli/src/errors.rs @@ -92,7 +92,6 @@ impl CliError { self } - 
pub fn render(&self) -> String { let mut output = String::new(); @@ -186,36 +185,30 @@ pub type CliResult<T> = Result<T, CliError>; // Convenience functions for common error patterns with error codes #[allow(unused)] pub fn config_error<S: Into<String>>(message: S) -> CliError { - CliError::new(message) - .with_hint("run `helix init` if you need to create a new project") + CliError::new(message).with_hint("run `helix init` if you need to create a new project") } #[allow(unused)] pub fn file_error<S: Into<String>>(message: S, file_path: S) -> CliError { - CliError::new(message) - .with_file_path(file_path) + CliError::new(message).with_file_path(file_path) } #[allow(unused)] pub fn docker_error<S: Into<String>>(message: S) -> CliError { - CliError::new(message) - .with_hint("ensure Docker is running and accessible") + CliError::new(message).with_hint("ensure Docker is running and accessible") } #[allow(unused)] pub fn network_error<S: Into<String>>(message: S) -> CliError { - CliError::new(message) - .with_hint("check your internet connection and try again") + CliError::new(message).with_hint("check your internet connection and try again") } #[allow(unused)] pub fn project_error<S: Into<String>>(message: S) -> CliError { - CliError::new(message) - .with_hint("ensure you're in a valid helix project directory") + CliError::new(message).with_hint("ensure you're in a valid helix project directory") } #[allow(unused)] pub fn cloud_error<S: Into<String>>(message: S) -> CliError { - CliError::new(message) - .with_hint("run `helix auth login` to authenticate with Helix Cloud") + CliError::new(message).with_hint("run `helix auth login` to authenticate with Helix Cloud") } diff --git a/helix-cli/src/main.rs b/helix-cli/src/main.rs index f4ad38b9d..3b35245e4 100644 --- a/helix-cli/src/main.rs +++ b/helix-cli/src/main.rs @@ -204,7 +204,9 @@ async fn main() -> Result<()> { port, dry_run, no_backup, - } => commands::migrate::run(path, queries_dir, instance_name, port, dry_run, no_backup).await, + } => { + commands::migrate::run(path, queries_dir, instance_name, port, dry_run, no_backup).await + } }; // Shutdown metrics sender diff --git a/helix-cli/src/metrics_sender.rs b/helix-cli/src/metrics_sender.rs index d184b4d23..3c4fd0f48 100644 --- a/helix-cli/src/metrics_sender.rs +++ b/helix-cli/src/metrics_sender.rs @@ -1,6 +1,6 @@ use chrono::{Local, NaiveDate}; use dirs::home_dir; -use eyre::{eyre, OptionExt, Result}; +use eyre::{OptionExt, Result, eyre}; use flume::{Receiver, Sender, unbounded}; use helix_metrics::events::{ CompileEvent, DeployCloudEvent, DeployLocalEvent, EventData, EventType, RawEvent, diff --git a/helix-cli/src/project.rs b/helix-cli/src/project.rs index 737b21666..e2418c078 100644 --- a/helix-cli/src/project.rs +++ b/helix-cli/src/project.rs @@ -93,7 +93,10 @@ fn find_project_root(start: &Path) -> Result<PathBuf> { let error = crate::errors::config_error("found v1 project configuration") .with_file_path(v1_config_path.display().to_string()) .with_context("This project uses the old v1 configuration format") - .with_hint(format!("Run 'helix migrate --path \"{}\"' to migrate this project to v2 format", current.display())); + .with_hint(format!( "Run 'helix migrate --path \"{}\"' to migrate this project to v2 format", current.display() )); return Err(eyre!("{}", error.render())); } @@ -105,7 +108,10 @@ fn find_project_root(start: &Path) -> Result<PathBuf> { let error = crate::errors::config_error("project configuration not found") .with_file_path(start.display().to_string()) - .with_context(format!("searched from {} up to filesystem root", start.display())); + .with_context(format!( "searched from {} up to
filesystem root", + start.display() + )); Err(eyre!("{}", error.render())) } diff --git a/helix-cli/src/tests/check_tests.rs b/helix-cli/src/tests/check_tests.rs index f87cdb2a7..65eb725f1 100644 --- a/helix-cli/src/tests/check_tests.rs +++ b/helix-cli/src/tests/check_tests.rs @@ -47,8 +47,7 @@ E::Likes { To: Post, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); // Create valid queries.hx let queries_content = r#" @@ -60,8 +59,7 @@ QUERY GetUserPosts(user_id: ID) => posts <- N(user_id)::Out RETURN posts "#; - fs::write(queries_dir.join("queries.hx"), queries_content) - .expect("Failed to write queries.hx"); + fs::write(queries_dir.join("queries.hx"), queries_content).expect("Failed to write queries.hx"); (temp_dir, project_path) } @@ -91,8 +89,7 @@ QUERY GetUser(user_id: ID) => user <- N(user_id) RETURN user "#; - fs::write(queries_dir.join("queries.hx"), queries_content) - .expect("Failed to write queries.hx"); + fs::write(queries_dir.join("queries.hx"), queries_content).expect("Failed to write queries.hx"); (temp_dir, project_path) } @@ -122,8 +119,7 @@ N::User { name: String, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); // Create queries.hx with invalid syntax let invalid_queries = r#" @@ -131,8 +127,7 @@ QUERY InvalidQuery { this is not valid helix syntax!!! } "#; - fs::write(queries_dir.join("queries.hx"), invalid_queries) - .expect("Failed to write queries.hx"); + fs::write(queries_dir.join("queries.hx"), invalid_queries).expect("Failed to write queries.hx"); (temp_dir, project_path) } @@ -268,8 +263,7 @@ E::Follows { To: User, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); let _guard = std::env::set_current_dir(&project_path); @@ -349,8 +343,7 @@ E::Follows { To: User, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); let _guard = std::env::set_current_dir(&project_path); @@ -387,8 +380,7 @@ N::User { name: String, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); // Create additional schema in another file let more_schema = r#" @@ -447,8 +439,7 @@ N::User { name: String, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); let _guard = std::env::set_current_dir(&project_path); diff --git a/helix-cli/src/tests/compile_tests.rs b/helix-cli/src/tests/compile_tests.rs index ec4140834..5186f8238 100644 --- a/helix-cli/src/tests/compile_tests.rs +++ b/helix-cli/src/tests/compile_tests.rs @@ -40,8 +40,7 @@ E::Authored { To: Post, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); // Create valid queries.hx let queries_content = r#" @@ -53,8 +52,7 
@@ QUERY GetUserPosts(user_id: ID) => posts <- N(user_id)::Out RETURN posts "#; - fs::write(queries_dir.join("queries.hx"), queries_content) - .expect("Failed to write queries.hx"); + fs::write(queries_dir.join("queries.hx"), queries_content).expect("Failed to write queries.hx"); (temp_dir, project_path) } @@ -115,10 +113,7 @@ async fn test_compile_with_explicit_project_path() { // Check that compiled output files were created let query_file = project_path.join("queries.rs"); - assert!( - query_file.exists(), - "Compiled queries.rs should be created" - ); + assert!(query_file.exists(), "Compiled queries.rs should be created"); } #[tokio::test] @@ -145,8 +140,7 @@ QUERY GetUser(user_id: ID) => user <- N(user_id) RETURN user "#; - fs::write(queries_dir.join("queries.hx"), queries_content) - .expect("Failed to write queries.hx"); + fs::write(queries_dir.join("queries.hx"), queries_content).expect("Failed to write queries.hx"); let _guard = std::env::set_current_dir(&project_path); @@ -184,16 +178,14 @@ N::User { name: String, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); // Create queries with invalid syntax let invalid_queries = r#" QUERY InvalidQuery this is not valid helix syntax!!! "#; - fs::write(queries_dir.join("queries.hx"), invalid_queries) - .expect("Failed to write queries.hx"); + fs::write(queries_dir.join("queries.hx"), invalid_queries).expect("Failed to write queries.hx"); let _guard = std::env::set_current_dir(&project_path); @@ -249,8 +241,7 @@ E::Follows { To: User, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); let _guard = std::env::set_current_dir(&project_path); @@ -294,8 +285,7 @@ N::User { name: String, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); // Create additional schema in another file let more_schema = r#" @@ -330,10 +320,7 @@ QUERY GetUser(id: ID) => // Check that compiled output files were created let query_file = project_path.join("queries.rs"); - assert!( - query_file.exists(), - "Compiled queries.rs should be created" - ); + assert!(query_file.exists(), "Compiled queries.rs should be created"); } #[tokio::test] @@ -361,8 +348,7 @@ N::User { name: String, } "#; - fs::write(queries_dir.join("schema.hx"), schema_content) - .expect("Failed to write schema.hx"); + fs::write(queries_dir.join("schema.hx"), schema_content).expect("Failed to write schema.hx"); let _guard = std::env::set_current_dir(&project_path); @@ -375,10 +361,7 @@ N::User { // Check that compiled output files were created let query_file = project_path.join("queries.rs"); - assert!( - query_file.exists(), - "Compiled queries.rs should be created" - ); + assert!(query_file.exists(), "Compiled queries.rs should be created"); } #[tokio::test] @@ -400,7 +383,9 @@ async fn test_compile_creates_all_required_files() { "Generated queries.rs should not be empty" ); assert!( - query_content.contains("pub") || query_content.contains("use") || query_content.contains("impl"), + query_content.contains("pub") + || query_content.contains("use") + || query_content.contains("impl"), "Generated queries.rs should contain Rust code" ); } diff --git 
a/helix-cli/src/tests/init_tests.rs b/helix-cli/src/tests/init_tests.rs index 7f0e11d1f..54b09e8bc 100644 --- a/helix-cli/src/tests/init_tests.rs +++ b/helix-cli/src/tests/init_tests.rs @@ -166,7 +166,10 @@ async fn test_init_creates_directory_if_not_exists() { let project_path = temp_dir.path().join("new_project_dir"); // Directory should not exist yet - assert!(!project_path.exists(), "Project directory should not exist initially"); + assert!( + !project_path.exists(), + "Project directory should not exist initially" + ); let result = run( Some(project_path.to_str().unwrap().to_string()), diff --git a/helix-cli/src/tests/mod.rs b/helix-cli/src/tests/mod.rs index c321e6c36..f7511ed98 100644 --- a/helix-cli/src/tests/mod.rs +++ b/helix-cli/src/tests/mod.rs @@ -1,10 +1,10 @@ // CLI test modules #[cfg(test)] -pub mod init_tests; -#[cfg(test)] pub mod check_tests; #[cfg(test)] pub mod compile_tests; +#[cfg(test)] +pub mod init_tests; // #[cfg(test)] // pub mod build_tests; // #[cfg(test)] diff --git a/helix-cli/src/tests/project_tests.rs b/helix-cli/src/tests/project_tests.rs index 12d2df016..58648e553 100644 --- a/helix-cli/src/tests/project_tests.rs +++ b/helix-cli/src/tests/project_tests.rs @@ -1,5 +1,5 @@ use crate::config::HelixConfig; -use crate::project::{get_helix_cache_dir, ProjectContext}; +use crate::project::{ProjectContext, get_helix_cache_dir}; use std::fs; use std::path::PathBuf; use tempfile::TempDir; @@ -180,7 +180,10 @@ fn test_project_context_ensure_instance_dirs() { assert!(!workspace.exists(), "Workspace should not exist initially"); assert!(!volume.exists(), "Volume should not exist initially"); - assert!(!container.exists(), "Container dir should not exist initially"); + assert!( + !container.exists(), + "Container dir should not exist initially" + ); let result = context.ensure_instance_dirs("test-instance"); assert!(result.is_ok(), "Should create instance directories"); @@ -242,7 +245,10 @@ fn test_project_context_with_custom_queries_path() { fs::create_dir_all(project_path.join(".helix")).expect("Failed to create .helix"); let result = ProjectContext::find_and_load(Some(&project_path)); - assert!(result.is_ok(), "Should load project with custom queries path"); + assert!( + result.is_ok(), + "Should load project with custom queries path" + ); let context = result.unwrap(); assert_eq!( @@ -297,5 +303,8 @@ fn test_find_project_root_stops_at_filesystem_root() { fs::create_dir_all(&deep_path).expect("Failed to create deep path"); let result = ProjectContext::find_and_load(Some(&deep_path)); - assert!(result.is_err(), "Should fail after reaching filesystem root"); + assert!( + result.is_err(), + "Should fail after reaching filesystem root" + ); } diff --git a/helix-container/src/main.rs b/helix-container/src/main.rs index 666aeb52b..8da264c17 100644 --- a/helix-container/src/main.rs +++ b/helix-container/src/main.rs @@ -11,7 +11,7 @@ use helix_db::helix_gateway::{ }; use std::{collections::HashMap, sync::Arc}; use tracing::info; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; +use tracing_subscriber::{Layer, layer::SubscriberExt, util::SubscriberInitExt}; mod queries; diff --git a/helix-db/benches/bm25_benches.rs b/helix-db/benches/bm25_benches.rs index d6b86ad64..83a6ceb61 100644 --- a/helix-db/benches/bm25_benches.rs +++ b/helix-db/benches/bm25_benches.rs @@ -3,7 +3,7 @@ mod tests { use helix_db::{ debug_println, - helix_engine::bm25::bm25::{HBM25Config, BM25}, + helix_engine::bm25::bm25::{BM25, HBM25Config}, utils::{id::v6_uuid, 
tqdm::tqdm}, }; @@ -155,4 +155,3 @@ mod tests { } } } - diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs index f3145a216..88ca28474 100644 --- a/helix-db/src/helix_engine/bm25/bm25.rs +++ b/helix-db/src/helix_engine/bm25/bm25.rs @@ -445,7 +445,7 @@ impl HybridSearch for HelixGraphStorage { // correct_score = alpha * bm25_score + (1.0 - alpha) * vector_score if let Some(vector_results) = vector_results? { for (doc_id, score) in vector_results { - let similarity = 1.0 / (1.0 + score) ; + let similarity = 1.0 / (1.0 + score); combined_scores .entry(doc_id) .and_modify(|existing_score| *existing_score += (1.0 - alpha) * similarity) diff --git a/helix-db/src/helix_engine/bm25/mod.rs b/helix-db/src/helix_engine/bm25/mod.rs index 74a06cb09..b033aefcb 100644 --- a/helix-db/src/helix_engine/bm25/mod.rs +++ b/helix-db/src/helix_engine/bm25/mod.rs @@ -1,4 +1,4 @@ pub mod bm25; #[cfg(test)] -pub mod bm25_tests; \ No newline at end of file +pub mod bm25_tests; diff --git a/helix-db/src/helix_engine/mod.rs b/helix-db/src/helix_engine/mod.rs index 516c631ce..621590efd 100644 --- a/helix-db/src/helix_engine/mod.rs +++ b/helix-db/src/helix_engine/mod.rs @@ -1,8 +1,8 @@ pub mod bm25; -pub mod traversal_core; pub mod macros; pub mod reranker; pub mod storage_core; +pub mod traversal_core; pub mod types; pub mod vector_core; diff --git a/helix-db/src/helix_engine/reranker/fusion/mod.rs b/helix-db/src/helix_engine/reranker/fusion/mod.rs index 7e291f285..2d379ae47 100644 --- a/helix-db/src/helix_engine/reranker/fusion/mod.rs +++ b/helix-db/src/helix_engine/reranker/fusion/mod.rs @@ -9,4 +9,4 @@ pub mod score_normalizer; pub use mmr::{DistanceMethod, MMRReranker}; pub use rrf::RRFReranker; -pub use score_normalizer::{normalize_scores, NormalizationMethod}; +pub use score_normalizer::{NormalizationMethod, normalize_scores}; diff --git a/helix-db/src/helix_engine/reranker/fusion/score_normalizer.rs b/helix-db/src/helix_engine/reranker/fusion/score_normalizer.rs index 4c3e60cb6..6dbae3dae 100644 --- a/helix-db/src/helix_engine/reranker/fusion/score_normalizer.rs +++ b/helix-db/src/helix_engine/reranker/fusion/score_normalizer.rs @@ -39,12 +39,8 @@ pub fn normalize_scores(scores: &[f64], method: NormalizationMethod) -> Reranker /// Min-Max normalization: scales scores to [0, 1] range. fn normalize_minmax(scores: &[f64]) -> RerankerResult> { - let min = scores - .iter() - .fold(f64::INFINITY, |a, &b| a.min(b)); - let max = scores - .iter() - .fold(f64::NEG_INFINITY, |a, &b| a.max(b)); + let min = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b)); + let max = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)); let range = max - min; @@ -53,10 +49,7 @@ fn normalize_minmax(scores: &[f64]) -> RerankerResult> { return Ok(vec![0.5; scores.len()]); } - Ok(scores - .iter() - .map(|&score| (score - min) / range) - .collect()) + Ok(scores.iter().map(|&score| (score - min) / range).collect()) } /// Z-score normalization: centers scores around mean with unit variance. 
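The min-max pass reformatted above is easiest to read outside the diff context. Below is a minimal standalone sketch of the same technique; it returns a plain Vec<f64> instead of the crate's RerankerResult, so the error type and the surrounding reranker module are deliberately left out, and this is an illustration of the pattern rather than the crate's API.

// Standalone sketch of the min-max normalization refactored above.
// The real code wraps the result in RerankerResult; this version
// returns the normalized scores directly.
fn normalize_minmax(scores: &[f64]) -> Vec<f64> {
    let min = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let range = max - min;

    // All scores identical: map everything to the midpoint, as the
    // crate code does, rather than dividing by zero.
    if range == 0.0 {
        return vec![0.5; scores.len()];
    }

    scores.iter().map(|&score| (score - min) / range).collect()
}

fn main() {
    // 3.0 is the minimum, 7.0 the maximum, and 5.0 sits exactly halfway.
    assert_eq!(normalize_minmax(&[3.0, 7.0, 5.0]), vec![0.0, 1.0, 0.5]);
}
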
diff --git a/helix-db/src/helix_engine/storage_core/storage_concurrent_tests.rs b/helix-db/src/helix_engine/storage_core/storage_concurrent_tests.rs index 3035028f0..fb4ea50ac 100644 --- a/helix-db/src/helix_engine/storage_core/storage_concurrent_tests.rs +++ b/helix-db/src/helix_engine/storage_core/storage_concurrent_tests.rs @@ -11,15 +11,14 @@ /// - Drop operations are multi-step (not atomic) - could leave orphans /// - LMDB provides single-writer guarantee but needs validation /// - MVCC snapshot isolation needs verification - use std::sync::{Arc, Barrier}; use std::thread; use tempfile::TempDir; use crate::helix_engine::storage_core::HelixGraphStorage; -use crate::helix_engine::traversal_core::config::Config; use crate::helix_engine::storage_core::version_info::VersionInfo; -use crate::utils::items::{Node, Edge}; +use crate::helix_engine::traversal_core::config::Config; +use crate::utils::items::{Edge, Node}; use bumpalo::Bump; use uuid::Uuid; @@ -70,7 +69,10 @@ fn test_concurrent_node_creation() { properties: None, }; - storage.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + storage + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); wtxn.commit().unwrap(); } }) @@ -114,7 +116,10 @@ fn test_concurrent_edge_creation() { version: 1, properties: None, }; - storage.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + storage + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); } wtxn.commit().unwrap(); } @@ -122,7 +127,8 @@ fn test_concurrent_edge_creation() { // Get node IDs let node_ids: Vec = { let rtxn = storage.graph_env.read_txn().unwrap(); - storage.nodes_db + storage + .nodes_db .iter(&rtxn) .unwrap() .map(|result| { @@ -164,7 +170,10 @@ fn test_concurrent_edge_creation() { properties: None, }; - storage.edges_db.put(&mut wtxn, &edge.id, &edge.to_bincode_bytes().unwrap()).unwrap(); + storage + .edges_db + .put(&mut wtxn, &edge.id, &edge.to_bincode_bytes().unwrap()) + .unwrap(); wtxn.commit().unwrap(); } }) @@ -209,7 +218,10 @@ fn test_concurrent_node_reads() { version: 1, properties: None, }; - storage.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + storage + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); } wtxn.commit().unwrap(); } @@ -269,7 +281,10 @@ fn test_concurrent_node_reads() { version: 1, properties: None, }; - storage.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + storage + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); wtxn.commit().unwrap(); thread::sleep(std::time::Duration::from_millis(2)); @@ -314,7 +329,10 @@ fn test_transaction_isolation_storage() { version: 1, properties: None, }; - storage.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + storage + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); } wtxn.commit().unwrap(); } @@ -338,7 +356,10 @@ fn test_transaction_isolation_storage() { version: 1, properties: None, }; - storage_clone.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + storage_clone + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); wtxn.commit().unwrap(); } }); @@ -394,7 +415,10 @@ fn test_write_transaction_serialization() { properties: None, }; - storage.nodes_db.put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()).unwrap(); + 
storage + .nodes_db + .put(&mut wtxn, &node.id, &node.to_bincode_bytes().unwrap()) + .unwrap(); // Simulate some work during transaction thread::sleep(std::time::Duration::from_micros(100)); diff --git a/helix-db/src/helix_engine/storage_core/storage_methods.rs b/helix-db/src/helix_engine/storage_core/storage_methods.rs index 1d009ca69..5f10e1c7a 100644 --- a/helix-db/src/helix_engine/storage_core/storage_methods.rs +++ b/helix-db/src/helix_engine/storage_core/storage_methods.rs @@ -34,12 +34,12 @@ pub trait StorageMethods { fn drop_node(&self, txn: &mut RwTxn, id: &u128) -> Result<(), GraphError>; /// Removes the following from the storage engine: - /// - The given edge + /// - The given edge /// - All incoming and outgoing mappings for that edge fn drop_edge(&self, txn: &mut RwTxn, id: &u128) -> Result<(), GraphError>; /// Sets the `deleted` field of a vector to true - /// - /// NOTE: The vector is not ACTUALLY deleted and is still present in the db. + /// + /// NOTE: The vector is not ACTUALLY deleted and is still present in the db. fn drop_vector(&self, txn: &mut RwTxn, id: &u128) -> Result<(), GraphError>; } diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs index 709c0b9f5..f542bdf58 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_loom_tests.rs @@ -10,9 +10,6 @@ /// /// NOTE: Loom tests are expensive - they explore all possible execution orderings. /// Keep the problem space small (few operations, few threads). - - - use loom::sync::Arc; use loom::sync::atomic::{AtomicU64, Ordering}; use loom::thread; @@ -84,9 +81,7 @@ fn loom_entry_point_read_write_race() { }); // Reader thread: Reads entry point (might see 0 or 12345) - let reader = thread::spawn(move || { - reader_entry.load(Ordering::SeqCst) - }); + let reader = thread::spawn(move || reader_entry.load(Ordering::SeqCst)); writer.join().unwrap(); let read_value = reader.join().unwrap(); @@ -295,9 +290,7 @@ fn loom_two_writers_one_reader() { }); // Reader: Read value (should see 0, 1, or 2) - let reader = thread::spawn(move || { - r_value.load(Ordering::SeqCst) - }); + let reader = thread::spawn(move || r_value.load(Ordering::SeqCst)); w1.join().unwrap(); w2.join().unwrap(); diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/integration_stress_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/integration_stress_tests.rs index 233cadb5c..a2b6108aa 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/integration_stress_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/integration_stress_tests.rs @@ -86,7 +86,8 @@ fn test_stress_mixed_read_write_operations() { G::new_mut(&storage, &arena, &mut wtxn) .add_edge("connects", None, id1, id2, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); write_ops.fetch_add(1, Ordering::Relaxed); @@ -196,7 +197,8 @@ fn test_stress_rapid_graph_growth() { let root_idx = local_count % root_ids.len(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("child_of", None, root_ids[root_idx], new_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); write_count.fetch_add(1, Ordering::Relaxed); diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/mod.rs b/helix-db/src/helix_engine/tests/concurrency_tests/mod.rs index 801e05b7d..6d2706c6e 100644 --- 
a/helix-db/src/helix_engine/tests/concurrency_tests/mod.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/mod.rs @@ -2,4 +2,4 @@ pub mod hnsw_concurrent_tests; pub mod hnsw_loom_tests; pub mod integration_stress_tests; -pub mod traversal_concurrent_tests; \ No newline at end of file +pub mod traversal_concurrent_tests; diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs index 03ece0217..7cfdaa7bf 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/traversal_concurrent_tests.rs @@ -1,3 +1,4 @@ +use bumpalo::Bump; /// Concurrent access tests for Traversal Operations /// /// This test suite validates thread safety and concurrent operation correctness @@ -13,22 +14,18 @@ /// - MVCC ensures readers see consistent graph snapshots /// - Edge creation/deletion doesn't corrupt graph topology /// - No race conditions in neighbor list updates - use std::sync::{Arc, Barrier}; use std::thread; -use bumpalo::Bump; use tempfile::TempDir; use crate::helix_engine::storage_core::HelixGraphStorage; use crate::helix_engine::traversal_core::config::Config; use crate::helix_engine::traversal_core::ops::g::G; +use crate::helix_engine::traversal_core::ops::in_::in_::InAdapter; +use crate::helix_engine::traversal_core::ops::out::out::OutAdapter; use crate::helix_engine::traversal_core::ops::source::{ - add_n::AddNAdapter, - add_e::AddEAdapter, - n_from_id::NFromIdAdapter, + add_e::AddEAdapter, add_n::AddNAdapter, n_from_id::NFromIdAdapter, }; -use crate::helix_engine::traversal_core::ops::out::out::OutAdapter; -use crate::helix_engine::traversal_core::ops::in_::in_::InAdapter; /// Setup storage for concurrent testing fn setup_concurrent_storage() -> (TempDir, Arc) { @@ -69,7 +66,8 @@ fn test_concurrent_node_additions() { let label = format!("person_t{}_n{}", thread_id, i); G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); wtxn.commit().unwrap(); } @@ -112,7 +110,8 @@ fn test_concurrent_edge_additions() { let label = format!("node_{}", i); G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id() }) .collect(); @@ -144,8 +143,15 @@ fn test_concurrent_edge_additions() { let label = format!("knows_t{}_e{}", thread_id, i); G::new_mut(&storage, &arena, &mut wtxn) - .add_edge(&label, None, node_ids[source_idx], node_ids[target_idx], false) - .collect_to_obj().unwrap(); + .add_edge( + &label, + None, + node_ids[source_idx], + node_ids[target_idx], + false, + ) + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); } @@ -184,7 +190,8 @@ fn test_concurrent_reads_during_writes() { let root = G::new_mut(&storage, &arena, &mut wtxn) .add_n("root", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); // Add initial neighbors @@ -192,12 +199,14 @@ fn test_concurrent_reads_during_writes() { let label = format!("initial_{}", i); let neighbor_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("connects", None, root, neighbor_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); } wtxn.commit().unwrap(); @@ -228,7 +237,8 @@ fn 
test_concurrent_reads_during_writes() { let neighbors = G::new(&storage, &rtxn, &arena) .n_from_id(&root_id) .out_node("connects") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); // Should see at least initial neighbors assert!( @@ -261,12 +271,14 @@ fn test_concurrent_reads_during_writes() { let label = format!("writer_{}_node_{}", writer_id, i); let new_node_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("connects", None, root_id, new_node_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); @@ -286,7 +298,8 @@ fn test_concurrent_reads_during_writes() { let final_neighbors = G::new(&storage, &rtxn, &arena) .n_from_id(&root_id) .out_node("connects") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let expected_count = 5 + (num_writers * 10); assert_eq!( @@ -313,19 +326,22 @@ fn test_traversal_snapshot_isolation() { let root = G::new_mut(&storage, &arena, &mut wtxn) .add_n("root", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); for i in 0..5 { let label = format!("node_{}", i); let node_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("links", None, root, node_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); } wtxn.commit().unwrap(); @@ -338,7 +354,8 @@ fn test_traversal_snapshot_isolation() { let initial_neighbors = G::new(&storage, &rtxn, &arena) .n_from_id(&root_id) .out_node("links") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let initial_count = initial_neighbors.len(); assert_eq!(initial_count, 5); @@ -352,12 +369,14 @@ fn test_traversal_snapshot_isolation() { let label = format!("new_node_{}", i); let new_id = G::new_mut(&storage_clone, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage_clone, &arena, &mut wtxn) .add_edge("links", None, root_id, new_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); } @@ -370,7 +389,8 @@ fn test_traversal_snapshot_isolation() { let current_neighbors = G::new(&storage, &rtxn, &arena2) .n_from_id(&root_id) .out_node("links") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!( current_neighbors.len(), @@ -388,7 +408,8 @@ fn test_traversal_snapshot_isolation() { let final_neighbors = G::new(&storage, &rtxn_new, &arena3) .n_from_id(&root_id) .out_node("links") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(final_neighbors.len(), 15); } @@ -410,7 +431,8 @@ fn test_concurrent_bidirectional_traversals() { let label = format!("source_{}", i); G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id() }) .collect(); @@ -420,7 +442,8 @@ fn test_concurrent_bidirectional_traversals() { let label = format!("target_{}", i); G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id() }) .collect(); @@ -430,7 +453,8 @@ fn test_concurrent_bidirectional_traversals() { for target_id in &targets { G::new_mut(&storage, &arena, &mut wtxn) .add_edge("points_to", None, *source_id, *target_id, false) - 
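test_traversal_snapshot_isolation above rests on one property of the storage layer: a read transaction pins the version of the graph that existed when it was opened, regardless of what commits later. A toy copy-on-write rendering of that guarantee, with Arc swapping standing in for LMDB's MVCC pages (hypothetical names, not the engine's types):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Writers replace the whole map; readers pin the Arc that was current
// when their "transaction" began.
struct SnapshotStore {
    current: Mutex<Arc<HashMap<u128, &'static str>>>,
}

impl SnapshotStore {
    fn read_snapshot(&self) -> Arc<HashMap<u128, &'static str>> {
        Arc::clone(&self.current.lock().unwrap())
    }

    fn write(&self, id: u128, label: &'static str) {
        let mut guard = self.current.lock().unwrap();
        let mut next = (**guard).clone(); // copy-on-write
        next.insert(id, label);
        *guard = Arc::new(next);
    }
}

fn main() {
    let store = SnapshotStore {
        current: Mutex::new(Arc::new(HashMap::new())),
    };
    store.write(1, "root");

    let snapshot = store.read_snapshot(); // like holding a read_txn open
    store.write(2, "added_later");

    // The pinned snapshot is frozen; a fresh snapshot sees the new write.
    assert_eq!(snapshot.len(), 1);
    assert_eq!(store.read_snapshot().len(), 2);
}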
.collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); } } @@ -463,7 +487,8 @@ fn test_concurrent_bidirectional_traversals() { let neighbors = G::new(&storage, &rtxn, &arena) .n_from_id(source_id) .out_node("points_to") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(neighbors.len(), 5, "Source should have 5 outgoing edges"); } } else { @@ -472,7 +497,8 @@ fn test_concurrent_bidirectional_traversals() { let neighbors = G::new(&storage, &rtxn, &arena) .n_from_id(target_id) .in_node("points_to") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(neighbors.len(), 5, "Target should have 5 incoming edges"); } } @@ -503,7 +529,8 @@ fn test_concurrent_multi_hop_traversals() { let root = G::new_mut(&storage, &arena, &mut wtxn) .add_n("root", None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); // Create level 1 nodes @@ -512,12 +539,14 @@ fn test_concurrent_multi_hop_traversals() { let label = format!("level1_{}", i); let id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("to_l1", None, root, id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); id }) @@ -529,12 +558,14 @@ fn test_concurrent_multi_hop_traversals() { let label = format!("level2_{}", i); let l2_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("to_l2", None, l1_id, l2_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); } } @@ -561,7 +592,8 @@ fn test_concurrent_multi_hop_traversals() { let level1 = G::new(&storage, &rtxn, &arena) .n_from_id(&root_id) .out_node("to_l1") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(level1.len(), 3, "Should have 3 level1 nodes"); // For each level1, traverse to level2 @@ -570,7 +602,8 @@ fn test_concurrent_multi_hop_traversals() { let level2 = G::new(&storage, &rtxn, &arena2) .n_from_id(&l1_node.id()) .out_node("to_l2") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(level2.len(), 2, "Each level1 should have 2 level2 nodes"); } @@ -615,17 +648,20 @@ fn test_concurrent_graph_topology_consistency() { let node1_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label1, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let node2_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label2, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("connects", None, node1_id, node2_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); } @@ -654,17 +690,26 @@ fn test_concurrent_graph_topology_consistency() { // Verify all edges point to valid nodes for result in storage.edges_db.iter(&rtxn).unwrap() { let (edge_id, edge_bytes) = result.unwrap(); - let edge = crate::utils::items::Edge::from_bincode_bytes(edge_id, edge_bytes, &arena).unwrap(); + let edge = + crate::utils::items::Edge::from_bincode_bytes(edge_id, edge_bytes, &arena).unwrap(); // Verify source exists assert!( - storage.nodes_db.get(&rtxn, &edge.from_node).unwrap().is_some(), + storage + .nodes_db + .get(&rtxn, &edge.from_node) + .unwrap() + .is_some(), "Edge source node not found" ); // Verify target exists assert!( - 
storage.nodes_db.get(&rtxn, &edge.to_node).unwrap().is_some(), + storage + .nodes_db + .get(&rtxn, &edge.to_node) + .unwrap() + .is_some(), "Edge target node not found" ); } @@ -688,7 +733,8 @@ fn test_stress_concurrent_mixed_operations() { let label = format!("root_{}", i); G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id() }) .collect(); @@ -717,13 +763,15 @@ fn test_stress_concurrent_mixed_operations() { let label = format!("w{}_n{}", writer_id, write_count); let new_id = G::new_mut(&storage, &arena, &mut wtxn) .add_n(&label, None, None) - .collect::,_>>().unwrap()[0] + .collect::, _>>() + .unwrap()[0] .id(); let root_idx = write_count % root_ids.len(); G::new_mut(&storage, &arena, &mut wtxn) .add_edge("links", None, root_ids[root_idx], new_id, false) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); wtxn.commit().unwrap(); write_count += 1; @@ -747,7 +795,8 @@ fn test_stress_concurrent_mixed_operations() { let _neighbors = G::new(&storage, &rtxn, &arena) .n_from_id(root_id) .out_node("links") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); read_count += 1; } } @@ -770,9 +819,20 @@ fn test_stress_concurrent_mixed_operations() { let total_writes: usize = write_counts.iter().sum(); let total_reads: usize = read_counts.iter().sum(); - println!("Stress test: {} writes, {} reads in {:?}", total_writes, total_reads, duration); + println!( + "Stress test: {} writes, {} reads in {:?}", + total_writes, total_reads, duration + ); // Should process many operations - assert!(total_writes > 50, "Should perform many writes, got {}", total_writes); - assert!(total_reads > 100, "Should perform many reads, got {}", total_reads); + assert!( + total_writes > 50, + "Should perform many writes, got {}", + total_writes + ); + assert!( + total_reads > 100, + "Should perform many reads, got {}", + total_reads + ); } diff --git a/helix-db/src/helix_engine/tests/mod.rs b/helix-db/src/helix_engine/tests/mod.rs index 0ceecf9d0..5a7915aab 100644 --- a/helix-db/src/helix_engine/tests/mod.rs +++ b/helix-db/src/helix_engine/tests/mod.rs @@ -1,6 +1,6 @@ pub mod traversal_tests; pub mod vector_tests; // pub mod bm25_tests; +pub mod concurrency_tests; pub mod hnsw_tests; pub mod storage_tests; -pub mod concurrency_tests; \ No newline at end of file diff --git a/helix-db/src/helix_engine/tests/storage_tests.rs b/helix-db/src/helix_engine/tests/storage_tests.rs index 8fd061c29..b730feb64 100644 --- a/helix-db/src/helix_engine/tests/storage_tests.rs +++ b/helix-db/src/helix_engine/tests/storage_tests.rs @@ -1,5 +1,7 @@ use crate::helix_engine::{ - storage_core::{HelixGraphStorage, storage_methods::DBMethods, version_info::VersionInfo, StorageConfig}, + storage_core::{ + HelixGraphStorage, StorageConfig, storage_methods::DBMethods, version_info::VersionInfo, + }, traversal_core::config::Config, }; use tempfile::TempDir; diff --git a/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs index f895da874..7aa752bc3 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/count_tests.rs @@ -1,27 +1,21 @@ use std::{sync::Arc, time::Duration}; -use crate::{ - helix_engine::{ - storage_core::HelixGraphStorage, - traversal_core::{ - ops::{ - g::G, - out::out::OutAdapter, - source::{ - add_e::AddEAdapter, - add_n::AddNAdapter, - n_from_id::NFromIdAdapter, - 
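The topology-consistency walk just above boils down to a single invariant: every edge's from_node and to_node must still resolve in the node table, because multi-step drops could otherwise strand orphan edges. The same check in standalone form (minimal hypothetical types, not the helix-db structs):

use std::collections::{HashMap, HashSet};

struct Edge {
    from_node: u128,
    to_node: u128,
}

// Returns true when no edge dangles, i.e. both endpoints exist.
fn topology_is_consistent(nodes: &HashSet<u128>, edges: &HashMap<u128, Edge>) -> bool {
    edges
        .values()
        .all(|e| nodes.contains(&e.from_node) && nodes.contains(&e.to_node))
}

fn main() {
    let mut nodes = HashSet::new();
    nodes.extend([1u128, 2]);

    let mut edges = HashMap::new();
    edges.insert(10u128, Edge { from_node: 1, to_node: 2 });
    assert!(topology_is_consistent(&nodes, &edges));

    // An edge referencing a missing node is exactly the orphan the test guards against.
    edges.insert(11, Edge { from_node: 1, to_node: 99 });
    assert!(!topology_is_consistent(&nodes, &edges));
}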
n_from_type::NFromTypeAdapter, - }, - util::{count::CountAdapter, filter_ref::FilterRefAdapter, range::RangeAdapter}, - }, +use crate::helix_engine::{ + storage_core::HelixGraphStorage, + traversal_core::ops::{ + g::G, + out::out::OutAdapter, + source::{ + add_e::AddEAdapter, add_n::AddNAdapter, n_from_id::NFromIdAdapter, + n_from_type::NFromTypeAdapter, }, + util::{count::CountAdapter, filter_ref::FilterRefAdapter, range::RangeAdapter}, }, }; +use bumpalo::Bump; use rand::Rng; use tempfile::TempDir; -use bumpalo::Bump; fn setup_test_db() -> (TempDir, Arc) { let temp_dir = TempDir::new().unwrap(); let db_path = temp_dir.path().to_str().unwrap(); @@ -41,7 +35,8 @@ fn test_count_single_node() { let mut txn = storage.graph_env.write_txn().unwrap(); let person = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person = person.first().unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -59,13 +54,16 @@ fn test_count_node_array() { let mut txn = storage.graph_env.write_txn().unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -84,35 +82,28 @@ fn test_count_mixed_steps() { // Create a graph with multiple paths let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person1 = person1.first().unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person2 = person2.first().unwrap(); let person3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person3 = person3.first().unwrap(); G::new_mut(&storage, &arena, &mut txn) - .add_edge( - "knows", - None, - person1.id(), - person2.id(), - false, - ) - .collect::,_>>().unwrap(); + .add_edge("knows", None, person1.id(), person2.id(), false) + .collect::, _>>() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) - .add_edge( - "knows", - None, - person1.id(), - person3.id(), - false, - ) - .collect::,_>>().unwrap(); + .add_edge("knows", None, person1.id(), person3.id(), false) + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); println!("person1: {person1:?},\nperson2: {person2:?},\nperson3: {person3:?}"); @@ -148,7 +139,8 @@ fn test_count_filter_ref() { for _ in 0..100 { let node = G::new_mut(&storage, &arena, &mut txn) .add_n("Country", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); nodes.push(node); } let mut num_countries = 0; @@ -157,16 +149,12 @@ fn test_count_filter_ref() { for _ in 0..rand_num { let city = G::new_mut(&storage, &arena, &mut txn) .add_n("City", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) - .add_edge( - "Country_to_City", - None, - node.id(), - city.id(), - false, - ) - .collect::,_>>().unwrap(); + .add_edge("Country_to_City", None, node.id(), city.id(), false) + .collect::, _>>() + .unwrap(); // sleep for 
one microsecond std::thread::sleep(Duration::from_micros(1)); } @@ -185,17 +173,15 @@ fn test_count_filter_ref() { .out_node("Country_to_City") .count_to_val() .map_value_or(false, |v| { - println!( - "v: {v:?}, res: {:?}", - *v > 10 - ); + println!("v: {v:?}, res: {:?}", *v > 10); *v > 10 })?) } else { Ok(false) } }) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); println!("count: {count:?}, num_countries: {num_countries}"); diff --git a/helix-db/src/helix_engine/tests/traversal_tests/filter_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/filter_tests.rs index 20a259e9a..31ab18811 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/filter_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/filter_tests.rs @@ -39,13 +39,16 @@ fn test_filter_nodes() { // Create nodes with different properties let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 25 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let person3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 35 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -67,7 +70,8 @@ fn test_filter_nodes() { Ok(false) } }) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 1); assert_eq!(traversal[0].id(), person3.id()); } @@ -84,14 +88,16 @@ fn test_filter_macro_single_argument() { props_option(&arena, props! { "name" => "Alice" }), None, ) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n( "person", props_option(&arena, props! { "name" => "Bob" }), None, ) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); fn has_name(val: &Result) -> Result { if let Ok(TraversalValue::Node(node)) = val { @@ -106,7 +112,8 @@ fn test_filter_macro_single_argument() { let traversal = G::new(&storage, &txn, &arena) .n_from_type("person") .filter_ref(|val, _| has_name(val)) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 2); assert!( traversal @@ -131,10 +138,12 @@ fn test_filter_macro_multiple_arguments() { let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 25 }), None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! 
{ "age" => 30 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); fn age_greater_than( @@ -160,7 +169,8 @@ fn test_filter_macro_multiple_arguments() { let traversal = G::new(&storage, &txn, &arena) .n_from_type("person") .filter_ref(|val, _| age_greater_than(val, 27)) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 1); assert_eq!(traversal[0].id(), person2.id()); @@ -174,10 +184,12 @@ fn test_filter_edges() { let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_edge( @@ -187,7 +199,8 @@ fn test_filter_edges() { person2.id(), false, ) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let edge2 = G::new_mut(&storage, &arena, &mut txn) .add_edge( "knows", @@ -196,7 +209,8 @@ fn test_filter_edges() { person1.id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -223,7 +237,8 @@ fn test_filter_edges() { let traversal = G::new(&storage, &txn, &arena) .e_from_type("knows") .filter_ref(|val, _| recent_edge(val, 2021)) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 1); assert_eq!(traversal[0].id(), edge2.id()); @@ -237,7 +252,8 @@ fn test_filter_empty_result() { let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "age" => 25 }), None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -258,7 +274,8 @@ fn test_filter_empty_result() { Ok(false) } }) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert!(traversal.is_empty()); } @@ -274,17 +291,20 @@ fn test_filter_chain() { props_option(&arena, props! { "age" => 25, "name" => "Alice" }), None, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n( "person", props_option(&arena, props! { "age" => 30, "name" => "Bob" }), None, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! 
{ "age" => 35 }), None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -320,7 +340,8 @@ fn test_filter_chain() { .n_from_type("person") .filter_ref(|val, _| has_name(val)) .filter_ref(|val, _| age_greater_than(val, 27)) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 1); assert_eq!(traversal[0].id(), person2.id()); diff --git a/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs index 72ff80068..0d78b5441 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/node_traversal_tests.rs @@ -13,7 +13,7 @@ use crate::{ add_e::AddEAdapter, add_n::AddNAdapter, e_from_type::EFromTypeAdapter, n_from_id::NFromIdAdapter, n_from_type::NFromTypeAdapter, }, - util::{filter_ref::FilterRefAdapter, drop::Drop}, + util::{drop::Drop, filter_ref::FilterRefAdapter}, }, traversal_value::TraversalValue, }, @@ -88,23 +88,28 @@ fn test_out() { // Create graph: (person1)-[knows]->(person2)-[knows]->(person3) let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person1 = person1.first().unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person2 = person2.first().unwrap(); let person3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person3 = person3.first().unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, person1.id(), person2.id(), false) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, person2.id(), person3.id(), false) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -130,22 +135,26 @@ fn test_in() { // Create graph: (person1)-[knows]->(person2) let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person1 = person1.first().unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person2 = person2.first().unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, person1.id(), person2.id(), false) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); let nodes = G::new(&storage, &txn, &arena) .n_from_id(&person2.id()) .in_node("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); // Check that current step is at person1 @@ -167,26 +176,32 @@ fn test_complex_traversal() { let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person1 = person1.first().unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let person2 = person2.first().unwrap(); let person3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - 
.collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); let person3 = person3.first().unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, person1.id(), person2.id(), false) - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("likes", None, person2.id(), person3.id(), false) - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("follows", None, person3.id(), person1.id(), false) - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -194,7 +209,8 @@ let nodes = G::new(&storage, &txn, &arena) .n_from_id(&person1.id()) .out_node("knows") - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); // Check that current step is at person2 assert_eq!(nodes.len(), 1); @@ -205,7 +221,8 @@ let nodes = G::new(&storage, &txn, &arena) .n_from_id(&node_id) .out_node("likes") - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); // Check that current step is at person3 assert_eq!(nodes.len(), 1); @@ -216,7 +233,8 @@ let nodes = G::new(&storage, &txn, &arena) .n_from_id(&node_id) .out_node("follows") - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); // Check that current step is at person1 assert_eq!(nodes.len(), 1); @@ -232,14 +250,16 @@ // Create a test node let person = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node_id = person.id(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); let count = G::new(&storage, &txn, &arena) .n_from_id(&node_id) - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); assert_eq!(count.len(), 1); } @@ -253,20 +273,24 @@ // Create test graph: (person1)-[knows]->(person2) let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, person1.id(), person2.id(), true) - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); let count = G::new(&storage, &txn, &arena) .n_from_id(&person1.id()) .out_node("knows") - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); // Check that traversal reaches person2 assert_eq!(count.len(), 1); @@ -281,7 +305,8 @@ let txn = storage.graph_env.read_txn().unwrap(); G::new(&storage, &txn, &arena) .n_from_id(&100) - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); } #[test] @@ -293,23 +318,29 @@ // Create test graph: (person1)-[knows]->(person2)-[likes]->(person3) let person1 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let person2 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let _ = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); 
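The chain tests in this file all walk the same way: resolve a start node, then expand repeatedly along one labelled out-edge set. The equivalent semantics on a plain adjacency list, as a reference for what n_from_id(..).out_node("knows").out_node("likes") is asserted to return (a toy representation, not the traversal_core API):

use std::collections::HashMap;

// Toy labelled digraph: node id -> list of (edge label, target id).
type Graph = HashMap<u128, Vec<(&'static str, u128)>>;

fn out(graph: &Graph, from: &[u128], label: &str) -> Vec<u128> {
    from.iter()
        .flat_map(|id| graph.get(id).into_iter().flatten())
        .filter(|(l, _)| *l == label)
        .map(|(_, to)| *to)
        .collect()
}

fn main() {
    let mut g: Graph = HashMap::new();
    g.insert(1, vec![("knows", 2)]);
    g.insert(2, vec![("likes", 3)]);
    g.insert(3, vec![("follows", 1)]);

    // Mirrors the (person1)-[knows]->(person2)-[likes]->(person3) assertions.
    let step1 = out(&g, &[1], "knows");
    let step2 = out(&g, &step1, "likes");
    assert_eq!(step2, vec![3]);
}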
let person3 = G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("knows", None, person1.id(), person2.id(), false) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) .add_edge("likes", None, person2.id(), person3.id(), false) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); @@ -317,7 +348,8 @@ fn test_n_from_id_chain_operations() { .n_from_id(&person1.id()) .out_node("knows") .out_node("likes") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); // Check that the chain of traversals reaches person3 assert_eq!(nodes.len(), 1); @@ -336,7 +368,8 @@ fn test_with_id_type() { props_option(&arena, props! { "name" => "test" }), None, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); #[derive(Serialize, Deserialize, Debug)] struct Input { @@ -355,7 +388,8 @@ fn test_with_id_type() { let txn = storage.graph_env.read_txn().unwrap(); let traversal = G::new(&storage, &txn, &arena) .n_from_id(&input.id) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(traversal.len(), 1); assert_eq!(traversal[0].id(), input.id.inner()); @@ -374,7 +408,8 @@ fn test_double_add_and_double_fetch() { props_option(&arena, props! { "entity_name" => "person1" }), None, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let original_node2 = G::new_mut(db, &arena, &mut txn) .add_n( @@ -382,7 +417,8 @@ fn test_double_add_and_double_fetch() { props_option(&arena, props! { "entity_name" => "person2" }), None, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); @@ -398,7 +434,8 @@ fn test_double_add_and_double_fetch() { Ok(false) } }) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); let node2 = G::new(db, &txn, &arena) .n_from_type("person") @@ -411,7 +448,8 @@ fn test_double_add_and_double_fetch() { Ok(false) } }) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(node1.len(), 1); assert_eq!(node1[0].id(), original_node1.id()); @@ -426,14 +464,16 @@ fn test_double_add_and_double_fetch() { node2.first().unwrap().id(), false, ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let txn = db.graph_env.read_txn().unwrap(); let e = G::new(db, &txn, &arena) .e_from_type("knows") - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(e.len(), 1); assert_eq!(e[0].id(), e.first().unwrap().id()); if let TraversalValue::Edge(e) = &e[0] { diff --git a/helix-db/src/helix_engine/tests/traversal_tests/range_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/range_tests.rs index 503a03503..63fe75f3a 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/range_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/range_tests.rs @@ -1,26 +1,20 @@ -use std::sync::Arc; use super::test_utils::props_option; +use std::sync::Arc; -use tempfile::TempDir; -use bumpalo::Bump; use crate::{ helix_engine::{ storage_core::HelixGraphStorage, - traversal_core::{ - ops::{ - g::G, - out::out::OutAdapter, - source::{ - add_e::AddEAdapter, - add_n::AddNAdapter, - n_from_type::NFromTypeAdapter, - }, - util::range::RangeAdapter, - }, + traversal_core::ops::{ + g::G, + out::out::OutAdapter, + source::{add_e::AddEAdapter, add_n::AddNAdapter, n_from_type::NFromTypeAdapter}, + 
util::range::RangeAdapter, }, }, props, }; +use bumpalo::Bump; +use tempfile::TempDir; fn setup_test_db() -> (TempDir, Arc<HelixGraphStorage>) { let temp_dir = TempDir::new().unwrap(); @@ -45,7 +39,8 @@ .map(|_| { G::new_mut(&storage, &arena, &mut txn) .add_n("person", None, None) - .collect::<Result<Vec<_>,_>>().unwrap() + .collect::<Result<Vec<_>, _>>() + .unwrap() .first() .unwrap(); }) @@ -72,7 +67,8 @@ .map(|i| { G::new_mut(&storage, &arena, &mut txn) .add_n("person", props_option(&arena, props! { "name" => i }), None) - .collect::<Result<Vec<_>,_>>().unwrap() + .collect::<Result<Vec<_>, _>>() + .unwrap() .first() .unwrap() .clone() @@ -82,32 +78,23 @@ // Create edges connecting nodes sequentially for i in 0..4 { G::new_mut(&storage, &arena, &mut txn) - .add_edge( - "knows", - None, - nodes[i].id(), - nodes[i + 1].id(), - false, - ) - .collect::<Result<Vec<_>,_>>().unwrap(); + .add_edge("knows", None, nodes[i].id(), nodes[i + 1].id(), false) + .collect::<Result<Vec<_>, _>>() + .unwrap(); } G::new_mut(&storage, &arena, &mut txn) - .add_edge( - "knows", - None, - nodes[4].id(), - nodes[0].id(), - false, - ) - .collect::<Result<Vec<_>,_>>().unwrap(); + .add_edge("knows", None, nodes[4].id(), nodes[0].id(), false) + .collect::<Result<Vec<_>, _>>() + .unwrap(); txn.commit().unwrap(); let txn = storage.graph_env.read_txn().unwrap(); let count = G::new(&storage, &txn, &arena) .n_from_type("person") // Get all nodes .range(0, 3) // Take first 3 nodes .out_node("knows") // Get their outgoing nodes - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); assert_eq!(count.len(), 3); } @@ -121,7 +108,8 @@ let count = G::new(&storage, &txn, &arena) .n_from_type("person") // Get all nodes .range(0, 0) // Take no nodes - .collect::<Result<Vec<_>,_>>().unwrap(); + .collect::<Result<Vec<_>, _>>() + .unwrap(); assert_eq!(count.len(), 0); } diff --git a/helix-db/src/helix_engine/tests/traversal_tests/secondary_index_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/secondary_index_tests.rs index 1b371d697..9abbb9a51 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/secondary_index_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/secondary_index_tests.rs @@ -50,25 +50,29 @@ fn test_delete_node_with_secondary_index() { props_option(&arena, props! 
{ "name" => "John" }), Some(&["name"]), ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); let node_id = node.id(); G::new_mut_from_iter(&storage, &mut txn, std::iter::once(node), &arena) .update(&[("name", Value::from("Jane"))]) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let jane_nodes = G::new(&storage, &txn, &arena) .n_from_index("person", "name", &"Jane".to_string()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(jane_nodes.len(), 1); assert_eq!(jane_nodes[0].id(), node_id); let john_nodes = G::new(&storage, &txn, &arena) .n_from_index("person", "name", &"John".to_string()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert!(john_nodes.is_empty()); drop(txn); @@ -76,7 +80,8 @@ fn test_delete_node_with_secondary_index() { let txn = storage.graph_env.read_txn().unwrap(); let traversal = G::new(&storage, &txn, &arena) .n_from_id(&node_id) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); drop(txn); let mut txn = storage.graph_env.write_txn().unwrap(); @@ -87,7 +92,8 @@ fn test_delete_node_with_secondary_index() { let txn = storage.graph_env.read_txn().unwrap(); let node = G::new(&storage, &txn, &arena) .n_from_index("person", "name", &"Jane".to_string()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert!(node.is_empty()); } @@ -103,21 +109,24 @@ fn test_update_of_secondary_indices() { props_option(&arena, props! { "name" => "John" }), Some(&["name"]), ) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); let mut txn = storage.graph_env.write_txn().unwrap(); G::new_mut_from_iter(&storage, &mut txn, std::iter::once(node), &arena) .update(&[("name", Value::from("Jane"))]) - .collect_to_obj().unwrap(); + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let nodes = G::new(&storage, &txn, &arena) .n_from_index("person", "name", &"Jane".to_string()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(nodes.len(), 1); if let TraversalValue::Node(node) = &nodes[0] { match node.properties.as_ref().unwrap().get("name").unwrap() { @@ -130,6 +139,7 @@ fn test_update_of_secondary_indices() { let john_nodes = G::new(&storage, &txn, &arena) .n_from_index("person", "name", &"John".to_string()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert!(john_nodes.is_empty()); } diff --git a/helix-db/src/helix_engine/tests/traversal_tests/test_utils.rs b/helix-db/src/helix_engine/tests/traversal_tests/test_utils.rs index 21c8094e8..1a4cff1a2 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/test_utils.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/test_utils.rs @@ -1,7 +1,4 @@ -use crate::{ - protocol::value::Value, - utils::properties::ImmutablePropertiesMap, -}; +use crate::{protocol::value::Value, utils::properties::ImmutablePropertiesMap}; use bumpalo::Bump; pub fn props_map<'arena>( diff --git a/helix-db/src/helix_engine/tests/traversal_tests/update_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/update_tests.rs index 4e02c02fc..81b3479a5 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/update_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/update_tests.rs @@ -16,8 +16,8 @@ use crate::{ traversal_value::TraversalValue, }, }, - protocol::value::Value, props, + 
protocol::value::Value, }; fn setup_test_db() -> (TempDir, Arc) { @@ -39,32 +39,45 @@ fn test_update_node() { let mut txn = storage.graph_env.write_txn().unwrap(); let node = G::new_mut(&storage, &arena, &mut txn) - .add_n("person", props_option(&arena, props!("name" => "test")), None) - .collect_to_obj().unwrap(); + .add_n( + "person", + props_option(&arena, props!("name" => "test")), + None, + ) + .collect_to_obj() + .unwrap(); G::new_mut(&storage, &arena, &mut txn) - .add_n("person", props_option(&arena, props!("name" => "test2")), None) - .collect_to_obj().unwrap(); + .add_n( + "person", + props_option(&arena, props!("name" => "test2")), + None, + ) + .collect_to_obj() + .unwrap(); txn.commit().unwrap(); let arena_read = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let traversal = G::new(&storage, &txn, &arena_read) .n_from_id(&node.id()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); drop(txn); let arena = Bump::new(); let mut txn = storage.graph_env.write_txn().unwrap(); G::new_mut_from_iter(&storage, &mut txn, traversal.into_iter(), &arena) .update(&[("name", Value::from("john"))]) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); txn.commit().unwrap(); let arena = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let updated = G::new(&storage, &txn, &arena) .n_from_id(&node.id()) - .collect::,_>>().unwrap(); + .collect::, _>>() + .unwrap(); assert_eq!(updated.len(), 1); match &updated[0] { diff --git a/helix-db/src/helix_engine/traversal_core/config.rs b/helix-db/src/helix_engine/traversal_core/config.rs index 2ffd3db3a..3897dbcd0 100644 --- a/helix-db/src/helix_engine/traversal_core/config.rs +++ b/helix-db/src/helix_engine/traversal_core/config.rs @@ -96,7 +96,7 @@ impl Config { let config = std::fs::read_to_string(config_path)?; let mut config = sonic_rs::from_str::(&config)?; - + // Schema will be populated from INTROSPECTION_DATA during code generation config.schema = None; diff --git a/helix-db/src/helix_engine/traversal_core/ops/bm25/hybrid_search_bm25.rs b/helix-db/src/helix_engine/traversal_core/ops/bm25/hybrid_search_bm25.rs index b7af7680c..c1006b27c 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/bm25/hybrid_search_bm25.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/bm25/hybrid_search_bm25.rs @@ -79,4 +79,3 @@ impl<'a, I: Iterator>> HybridSearchBM2 } } */ - diff --git a/helix-db/src/helix_engine/traversal_core/ops/bm25/mod.rs b/helix-db/src/helix_engine/traversal_core/ops/bm25/mod.rs index 6395c339a..2e709c9bf 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/bm25/mod.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/bm25/mod.rs @@ -1,3 +1,2 @@ -pub mod search_bm25; pub mod hybrid_search_bm25; - +pub mod search_bm25; diff --git a/helix-db/src/helix_engine/traversal_core/ops/in_/mod.rs b/helix-db/src/helix_engine/traversal_core/ops/in_/mod.rs index 13057d6d4..afdee8206 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/in_/mod.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/in_/mod.rs @@ -1,4 +1,4 @@ pub mod in_; pub mod in_e; pub mod to_n; -pub mod to_v; \ No newline at end of file +pub mod to_v; diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/mod.rs b/helix-db/src/helix_engine/traversal_core/ops/source/mod.rs index 774d93065..75293c678 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/mod.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/mod.rs @@ -7,4 +7,4 @@ pub mod n_from_id; pub mod n_from_index; pub mod n_from_type; pub mod 
v_from_id; -pub mod v_from_type; \ No newline at end of file +pub mod v_from_type; diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/n_from_index.rs b/helix-db/src/helix_engine/traversal_core/ops/source/n_from_index.rs index e720fd6a1..064312f55 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/n_from_index.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/n_from_index.rs @@ -1,9 +1,13 @@ use crate::{ helix_engine::{ - traversal_core::{traversal_iter::RoTraversalIterator, traversal_value::TraversalValue, LMDB_STRING_HEADER_LENGTH}, + traversal_core::{ + LMDB_STRING_HEADER_LENGTH, traversal_iter::RoTraversalIterator, + traversal_value::TraversalValue, + }, types::GraphError, }, - protocol::value::Value, utils::items::Node, + protocol::value::Value, + utils::items::Node, }; use serde::Serialize; @@ -79,18 +83,18 @@ impl< ); let length_of_label_in_lmdb = u64::from_le_bytes(value[..LMDB_STRING_HEADER_LENGTH].try_into().unwrap()) as usize; - + if length_of_label_in_lmdb != label.len() { return None; } - + assert!( value.len() >= length_of_label_in_lmdb + LMDB_STRING_HEADER_LENGTH, "value length is not at least the header length plus the label length meaning there has been a corruption on node insertion" ); let label_in_lmdb = &value[LMDB_STRING_HEADER_LENGTH ..LMDB_STRING_HEADER_LENGTH + length_of_label_in_lmdb]; - + if label_in_lmdb == label_as_bytes { match Node::<'arena>::from_bincode_bytes(node_id, value, self.arena) { Ok(node) => { @@ -104,10 +108,10 @@ impl< } else { return None; } - + } None - + }); diff --git a/helix-db/src/helix_engine/traversal_core/ops/source/n_from_type.rs b/helix-db/src/helix_engine/traversal_core/ops/source/n_from_type.rs index 90c9fdc95..dfd49f242 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/source/n_from_type.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/source/n_from_type.rs @@ -15,7 +15,7 @@ pub trait NFromTypeAdapter<'db, 'arena, 'txn, 's>: /// Returns an iterator containing the nodes with the given label. /// /// Note that the `label` cannot be empty and must be a valid, existing node label.' - /// + /// /// The label is stored before the node properties in LMDB. /// Bincode assures that the fields of a struct are stored in the same order as they are defined in the struct (first to last). 
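The doc comment above describes the layout this fast path depends on: bincode serializes the label field first, as a length header followed by the label bytes, so a candidate value can be rejected on label length alone before deserializing the whole node. The header must be 8 bytes here, since the code feeds it to u64::from_le_bytes. A standalone sketch of that check, under exactly that layout assumption:

const LMDB_STRING_HEADER_LENGTH: usize = 8;

fn label_matches(value: &[u8], label: &str) -> bool {
    if value.len() < LMDB_STRING_HEADER_LENGTH {
        return false;
    }
    // Little-endian u64 length header, as in the hunk above.
    let len =
        u64::from_le_bytes(value[..LMDB_STRING_HEADER_LENGTH].try_into().unwrap()) as usize;
    if len != label.len() || value.len() < LMDB_STRING_HEADER_LENGTH + len {
        return false;
    }
    &value[LMDB_STRING_HEADER_LENGTH..LMDB_STRING_HEADER_LENGTH + len] == label.as_bytes()
}

fn main() {
    let mut value = Vec::new();
    value.extend_from_slice(&6u64.to_le_bytes()); // length of "person"
    value.extend_from_slice(b"person");
    value.extend_from_slice(b"...remaining node fields...");

    assert!(label_matches(&value, "person"));
    // A different length short-circuits without touching the label bytes.
    assert!(!label_matches(&value, "person2"));
}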
/// @@ -58,18 +58,18 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr ); let length_of_label_in_lmdb = u64::from_le_bytes(value[..LMDB_STRING_HEADER_LENGTH].try_into().unwrap()) as usize; - + if length_of_label_in_lmdb != label.len() { return None; } - + assert!( value.len() >= length_of_label_in_lmdb + LMDB_STRING_HEADER_LENGTH, "value length is not at least the header length plus the label length meaning there has been a corruption on node insertion" ); let label_in_lmdb = &value[LMDB_STRING_HEADER_LENGTH ..LMDB_STRING_HEADER_LENGTH + length_of_label_in_lmdb]; - + if label_in_lmdb == label_as_bytes { match Node::<'arena>::from_bincode_bytes(id, value, self.arena) { Ok(node) => { diff --git a/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs b/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs index d562a91fd..afb2d8d21 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs @@ -5,7 +5,10 @@ use crate::{ types::GraphError, }, protocol::value::Value, - utils::{items::{Edge, Node}, label_hash::hash_label}, + utils::{ + items::{Edge, Node}, + label_hash::hash_label, + }, }; use heed3::RoTxn; use std::{ @@ -85,8 +88,14 @@ pub enum PathAlgorithm { AStar, } -pub struct ShortestPathIterator<'db, 'arena, 'txn, I, F, H = fn(&Node<'arena>) -> Result> -where +pub struct ShortestPathIterator< + 'db, + 'arena, + 'txn, + I, + F, + H = fn(&Node<'arena>) -> Result, +> where 'db: 'arena, 'arena: 'txn, F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, @@ -391,7 +400,7 @@ where None => { return Some(Err(GraphError::TraversalError( "A* algorithm requires a heuristic function".to_string(), - ))) + ))); } }; @@ -590,7 +599,13 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, >, > { - self.shortest_path_with_algorithm(edge_label, from, to, PathAlgorithm::BFS, default_weight_fn) + self.shortest_path_with_algorithm( + edge_label, + from, + to, + PathAlgorithm::BFS, + default_weight_fn, + ) } #[inline] diff --git a/helix-db/src/helix_engine/traversal_core/traversal_iter.rs b/helix-db/src/helix_engine/traversal_core/traversal_iter.rs index 5cab22e6f..22b58fae0 100644 --- a/helix-db/src/helix_engine/traversal_core/traversal_iter.rs +++ b/helix-db/src/helix_engine/traversal_core/traversal_iter.rs @@ -49,7 +49,9 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE } pub fn collect_to_obj(mut self) -> Result, GraphError> { - self.inner.next().unwrap_or(Err(GraphError::New("No value found".to_string()))) + self.inner + .next() + .unwrap_or(Err(GraphError::New("No value found".to_string()))) } pub fn collect_to_value(self) -> Value { @@ -64,7 +66,6 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE default: bool, f: impl Fn(&Value) -> bool, ) -> Result { - match &self.inner.next() { Some(Ok(TraversalValue::Value(val))) => Ok(f(val)), Some(Ok(_)) => Err(GraphError::ConversionError( @@ -130,7 +131,9 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE } pub fn collect_to_obj(mut self) -> Result, GraphError> { - self.inner.next().unwrap_or(Err(GraphError::New("No value found".to_string()))) + self.inner + .next() + .unwrap_or(Err(GraphError::New("No value found".to_string()))) } pub fn map_value_or( @@ -138,7 +141,6 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE default: bool, f: impl Fn(&Value) -> bool, ) -> Result { - match &self.inner.next() { Some(Ok(TraversalValue::Value(val))) => Ok(f(val)), Some(Ok(_)) => Err(GraphError::ConversionError( diff --git 
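On the paths.rs changes above: plain shortest_path now forwards to shortest_path_with_algorithm with PathAlgorithm::BFS and default_weight_fn, while requesting A* without a heuristic returns a TraversalError. For the BFS case, the behaviour being delegated to is ordinary predecessor-map search, sketched here over a toy adjacency list rather than the iterator-based implementation:

use std::collections::{HashMap, HashSet, VecDeque};

fn bfs_path(adj: &HashMap<u128, Vec<u128>>, from: u128, to: u128) -> Option<Vec<u128>> {
    let mut prev: HashMap<u128, u128> = HashMap::new();
    let mut seen: HashSet<u128> = HashSet::from([from]);
    let mut queue = VecDeque::from([from]);

    while let Some(node) = queue.pop_front() {
        if node == to {
            // Walk predecessors back to the start to materialize the path.
            let mut path = vec![to];
            let mut cur = to;
            while let Some(&p) = prev.get(&cur) {
                path.push(p);
                cur = p;
            }
            path.reverse();
            return Some(path);
        }
        for &next in adj.get(&node).into_iter().flatten() {
            if seen.insert(next) {
                prev.insert(next, node);
                queue.push_back(next);
            }
        }
    }
    None
}

fn main() {
    let adj = HashMap::from([(1u128, vec![2]), (2, vec![3]), (3, vec![])]);
    assert_eq!(bfs_path(&adj, 1, 3), Some(vec![1, 2, 3]));
}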
a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs index 1efc3edca..ee18236a5 100644 --- a/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_avx.rs @@ -17,106 +17,110 @@ unsafe fn hsum256_ps_avx(x: __m256) -> f32 { pub(crate) unsafe fn euclid_similarity_avx( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { unsafe { - // It is safe to load unaligned floats from a pointer. - // - - let n = v1.len(); - let m = n - (n % 32); - let mut ptr1 = v1.as_ptr() as *const f32; - let mut ptr2 = v2.as_ptr() as *const f32; - let mut sum256_1: __m256 = _mm256_setzero_ps(); - let mut sum256_2: __m256 = _mm256_setzero_ps(); - let mut sum256_3: __m256 = _mm256_setzero_ps(); - let mut sum256_4: __m256 = _mm256_setzero_ps(); - let mut i: usize = 0; - while i < m { - let sub256_1: __m256 = - _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0))); - sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1); - - let sub256_2: __m256 = - _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8))); - sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2); - - let sub256_3: __m256 = - _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16))); - sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3); - - let sub256_4: __m256 = - _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24))); - sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4); - - ptr1 = ptr1.add(32); - ptr2 = ptr2.add(32); - i += 32; - } +) -> f32 { + unsafe { + // It is safe to load unaligned floats from a pointer. + // + + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub256_1: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0))); + sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1); + + let sub256_2: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8))); + sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2); + + let sub256_3: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16))); + sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3); + + let sub256_4: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24))); + sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } - let mut result = hsum256_ps_avx(sum256_1) - + hsum256_ps_avx(sum256_2) - + hsum256_ps_avx(sum256_3) - + hsum256_ps_avx(sum256_4); - for i in 0..n - m { - let a = read_unaligned(ptr1.add(i)); - let b = read_unaligned(ptr2.add(i)); - result += (a - b).powi(2); + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += (a - b).powi(2); + } + result } - result -}} +} #[target_feature(enable = "avx")] #[target_feature(enable = "fma")] pub(crate) unsafe fn dot_similarity_avx( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { unsafe { - // It is safe to load 
unaligned floats from a pointer. - // - - let n = v1.len(); - let m = n - (n % 32); - let mut ptr1 = v1.as_ptr() as *const f32; - let mut ptr2 = v2.as_ptr() as *const f32; - let mut sum256_1: __m256 = _mm256_setzero_ps(); - let mut sum256_2: __m256 = _mm256_setzero_ps(); - let mut sum256_3: __m256 = _mm256_setzero_ps(); - let mut sum256_4: __m256 = _mm256_setzero_ps(); - let mut i: usize = 0; - while i < m { - sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1); - sum256_2 = _mm256_fmadd_ps( - _mm256_loadu_ps(ptr1.add(8)), - _mm256_loadu_ps(ptr2.add(8)), - sum256_2, - ); - sum256_3 = _mm256_fmadd_ps( - _mm256_loadu_ps(ptr1.add(16)), - _mm256_loadu_ps(ptr2.add(16)), - sum256_3, - ); - sum256_4 = _mm256_fmadd_ps( - _mm256_loadu_ps(ptr1.add(24)), - _mm256_loadu_ps(ptr2.add(24)), - sum256_4, - ); - - ptr1 = ptr1.add(32); - ptr2 = ptr2.add(32); - i += 32; - } +) -> f32 { + unsafe { + // It is safe to load unaligned floats from a pointer. + // + + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1); + sum256_2 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(8)), + _mm256_loadu_ps(ptr2.add(8)), + sum256_2, + ); + sum256_3 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(16)), + _mm256_loadu_ps(ptr2.add(16)), + sum256_3, + ); + sum256_4 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(24)), + _mm256_loadu_ps(ptr2.add(24)), + sum256_4, + ); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } - let mut result = hsum256_ps_avx(sum256_1) - + hsum256_ps_avx(sum256_2) - + hsum256_ps_avx(sum256_3) - + hsum256_ps_avx(sum256_4); + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); - for i in 0..n - m { - let a = read_unaligned(ptr1.add(i)); - let b = read_unaligned(ptr2.add(i)); - result += a * b; + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += a * b; + } + result } - result -}} +} #[cfg(test)] mod tests { diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs index 1894fadd7..05665fa7e 100644 --- a/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs +++ b/helix-db/src/helix_engine/vector_core/spaces/simple_neon.rs @@ -8,53 +8,55 @@ pub(crate) unsafe fn euclid_similarity_neon( v1: &UnalignedVector, v2: &UnalignedVector, ) -> f32 { - // We use the unaligned_float32x4_t helper function to read f32x4 NEON SIMD types - // from potentially unaligned memory locations safely. 
- // https://github.com/meilisearch/arroy/pull/13 - - let n = v1.len(); - let m = n - (n % 16); - let mut ptr1 = v1.as_ptr() as *const f32; - let mut ptr2 = v2.as_ptr() as *const f32; - let mut sum1 = vdupq_n_f32(0.); - let mut sum2 = vdupq_n_f32(0.); - let mut sum3 = vdupq_n_f32(0.); - let mut sum4 = vdupq_n_f32(0.); - - let mut i: usize = 0; - while i < m { - let sub1 = vsubq_f32(unaligned_float32x4_t(ptr1), unaligned_float32x4_t(ptr2)); - sum1 = vfmaq_f32(sum1, sub1, sub1); - - let sub2 = vsubq_f32( - unaligned_float32x4_t(ptr1.add(4)), - unaligned_float32x4_t(ptr2.add(4)), - ); - sum2 = vfmaq_f32(sum2, sub2, sub2); - - let sub3 = vsubq_f32( - unaligned_float32x4_t(ptr1.add(8)), - unaligned_float32x4_t(ptr2.add(8)), - ); - sum3 = vfmaq_f32(sum3, sub3, sub3); - - let sub4 = vsubq_f32( - unaligned_float32x4_t(ptr1.add(12)), - unaligned_float32x4_t(ptr2.add(12)), - ); - sum4 = vfmaq_f32(sum4, sub4, sub4); - - ptr1 = ptr1.add(16); - ptr2 = ptr2.add(16); - i += 16; - } - let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); - for i in 0..n - m { - let a = read_unaligned(ptr1.add(i)); - let b = read_unaligned(ptr2.add(i)); - result += (a - b).powi(2); + unsafe { + // We use the unaligned_float32x4_t helper function to read f32x4 NEON SIMD types + // from potentially unaligned memory locations safely. + // https://github.com/meilisearch/arroy/pull/13 + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + let sub1 = vsubq_f32(unaligned_float32x4_t(ptr1), unaligned_float32x4_t(ptr2)); + sum1 = vfmaq_f32(sum1, sub1, sub1); + + let sub2 = vsubq_f32( + unaligned_float32x4_t(ptr1.add(4)), + unaligned_float32x4_t(ptr2.add(4)), + ); + sum2 = vfmaq_f32(sum2, sub2, sub2); + + let sub3 = vsubq_f32( + unaligned_float32x4_t(ptr1.add(8)), + unaligned_float32x4_t(ptr2.add(8)), + ); + sum3 = vfmaq_f32(sum3, sub3, sub3); + + let sub4 = vsubq_f32( + unaligned_float32x4_t(ptr1.add(12)), + unaligned_float32x4_t(ptr2.add(12)), + ); + sum4 = vfmaq_f32(sum4, sub4, sub4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += (a - b).powi(2); + } + result } - result } #[cfg(target_feature = "neon")] @@ -62,57 +64,59 @@ pub(crate) unsafe fn dot_similarity_neon( v1: &UnalignedVector, v2: &UnalignedVector, ) -> f32 { - // We use the unaligned_float32x4_t helper function to read f32x4 NEON SIMD types - // from potentially unaligned memory locations safely. 
- // https://github.com/meilisearch/arroy/pull/13
-
-    let n = v1.len();
-    let m = n - (n % 16);
-    let mut ptr1 = v1.as_ptr() as *const f32;
-    let mut ptr2 = v2.as_ptr() as *const f32;
-    let mut sum1 = vdupq_n_f32(0.);
-    let mut sum2 = vdupq_n_f32(0.);
-    let mut sum3 = vdupq_n_f32(0.);
-    let mut sum4 = vdupq_n_f32(0.);
-
-    let mut i: usize = 0;
-    while i < m {
-        sum1 = vfmaq_f32(
-            sum1,
-            unaligned_float32x4_t(ptr1),
-            unaligned_float32x4_t(ptr2),
-        );
-        sum2 = vfmaq_f32(
-            sum2,
-            unaligned_float32x4_t(ptr1.add(4)),
-            unaligned_float32x4_t(ptr2.add(4)),
-        );
-        sum3 = vfmaq_f32(
-            sum3,
-            unaligned_float32x4_t(ptr1.add(8)),
-            unaligned_float32x4_t(ptr2.add(8)),
-        );
-        sum4 = vfmaq_f32(
-            sum4,
-            unaligned_float32x4_t(ptr1.add(12)),
-            unaligned_float32x4_t(ptr2.add(12)),
-        );
-        ptr1 = ptr1.add(16);
-        ptr2 = ptr2.add(16);
-        i += 16;
-    }
-    let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
-    for i in 0..n - m {
-        let a = read_unaligned(ptr1.add(i));
-        let b = read_unaligned(ptr2.add(i));
-        result += a * b;
+    unsafe {
+        // We use the unaligned_float32x4_t helper function to read f32x4 NEON SIMD types
+        // from potentially unaligned memory locations safely.
+        // https://github.com/meilisearch/arroy/pull/13
+
+        let n = v1.len();
+        let m = n - (n % 16);
+        let mut ptr1 = v1.as_ptr() as *const f32;
+        let mut ptr2 = v2.as_ptr() as *const f32;
+        let mut sum1 = vdupq_n_f32(0.);
+        let mut sum2 = vdupq_n_f32(0.);
+        let mut sum3 = vdupq_n_f32(0.);
+        let mut sum4 = vdupq_n_f32(0.);
+
+        let mut i: usize = 0;
+        while i < m {
+            sum1 = vfmaq_f32(
+                sum1,
+                unaligned_float32x4_t(ptr1),
+                unaligned_float32x4_t(ptr2),
+            );
+            sum2 = vfmaq_f32(
+                sum2,
+                unaligned_float32x4_t(ptr1.add(4)),
+                unaligned_float32x4_t(ptr2.add(4)),
+            );
+            sum3 = vfmaq_f32(
+                sum3,
+                unaligned_float32x4_t(ptr1.add(8)),
+                unaligned_float32x4_t(ptr2.add(8)),
+            );
+            sum4 = vfmaq_f32(
+                sum4,
+                unaligned_float32x4_t(ptr1.add(12)),
+                unaligned_float32x4_t(ptr2.add(12)),
+            );
+            ptr1 = ptr1.add(16);
+            ptr2 = ptr2.add(16);
+            i += 16;
+        }
+        let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
+        for i in 0..n - m {
+            let a = read_unaligned(ptr1.add(i));
+            let b = read_unaligned(ptr2.add(i));
+            result += a * b;
+        }
+        result
     }
-    result
 }
 
 /// Reads 4xf32 in a stack-located array aligned on a f32 and reads a `float32x4_t` from it.
 unsafe fn unaligned_float32x4_t(ptr: *const f32) -> float32x4_t {
-    vld1q_f32(read_unaligned(ptr as *const [f32; 4]).as_ptr())
+    unsafe { vld1q_f32(read_unaligned(ptr as *const [f32; 4]).as_ptr()) }
 }
 
 #[cfg(test)]
@@ -125,9 +129,6 @@ mod tests {
     #[cfg(target_feature = "neon")]
     #[test]
     fn test_spaces_neon() {
-        use super::*;
-        use crate::spaces::simple::*;
-
         if std::arch::is_aarch64_feature_detected!("neon") {
            let v1: Vec<f32> = vec![
                 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
diff --git a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs
index ec018afa4..c24b4e263 100644
--- a/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs
+++ b/helix-db/src/helix_engine/vector_core/spaces/simple_sse.rs
@@ -17,101 +17,105 @@ unsafe fn hsum128_ps_sse(x: __m128) -> f32 {
 pub(crate) unsafe fn euclid_similarity_sse(
     v1: &UnalignedVector,
     v2: &UnalignedVector,
-) -> f32 { unsafe {
-    // It is safe to load unaligned floats from a pointer.
- // - - let n = v1.len(); - let m = n - (n % 16); - let mut ptr1 = v1.as_ptr() as *const f32; - let mut ptr2 = v2.as_ptr() as *const f32; - let mut sum128_1: __m128 = _mm_setzero_ps(); - let mut sum128_2: __m128 = _mm_setzero_ps(); - let mut sum128_3: __m128 = _mm_setzero_ps(); - let mut sum128_4: __m128 = _mm_setzero_ps(); - let mut i: usize = 0; - while i < m { - let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)); - sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1); - - let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))); - sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2); - - let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))); - sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3); - - let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))); - sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4); - - ptr1 = ptr1.add(16); - ptr2 = ptr2.add(16); - i += 16; - } +) -> f32 { + unsafe { + // It is safe to load unaligned floats from a pointer. + // + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)); + sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1); + + let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))); + sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2); + + let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))); + sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3); + + let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))); + sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } - let mut result = hsum128_ps_sse(sum128_1) - + hsum128_ps_sse(sum128_2) - + hsum128_ps_sse(sum128_3) - + hsum128_ps_sse(sum128_4); - for i in 0..n - m { - let a = read_unaligned(ptr1.add(i)); - let b = read_unaligned(ptr2.add(i)); - result += (a - b).powi(2); + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += (a - b).powi(2); + } + result } - result -}} +} #[target_feature(enable = "sse")] pub(crate) unsafe fn dot_similarity_sse( v1: &UnalignedVector, v2: &UnalignedVector, -) -> f32 { unsafe { - // It is safe to load unaligned floats from a pointer. 
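Stripped of intrinsics, the SSE kernels being rewritten here are plain blocked reductions: 16 floats per iteration spread over four independent accumulators so the multiply-add chains do not serialize on one register, then a scalar pass over the `n % 16` tail. A scalar sketch of what `dot_similarity_sse` computes, assuming two equal-length `&[f32]` slices in place of the unaligned-vector wrapper:

    // Scalar reference for the blocked dot product: four partial sums over
    // 16-element chunks, then the remainder handled one float at a time.
    fn dot_scalar(v1: &[f32], v2: &[f32]) -> f32 {
        let n = v1.len();
        let m = n - (n % 16);
        let mut sums = [0.0f32; 4];
        for i in (0..m).step_by(16) {
            for lane in 0..4 {
                for j in 0..4 {
                    let k = i + 4 * lane + j;
                    sums[lane] += v1[k] * v2[k];
                }
            }
        }
        let mut result: f32 = sums.iter().sum();
        for k in m..n {
            result += v1[k] * v2[k];
        }
        result
    }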
- // - - let n = v1.len(); - let m = n - (n % 16); - let mut ptr1 = v1.as_ptr() as *const f32; - let mut ptr2 = v2.as_ptr() as *const f32; - let mut sum128_1: __m128 = _mm_setzero_ps(); - let mut sum128_2: __m128 = _mm_setzero_ps(); - let mut sum128_3: __m128 = _mm_setzero_ps(); - let mut sum128_4: __m128 = _mm_setzero_ps(); - - let mut i: usize = 0; - while i < m { - sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1); - - sum128_2 = _mm_add_ps( - _mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))), - sum128_2, - ); - - sum128_3 = _mm_add_ps( - _mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))), - sum128_3, - ); - - sum128_4 = _mm_add_ps( - _mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))), - sum128_4, - ); - - ptr1 = ptr1.add(16); - ptr2 = ptr2.add(16); - i += 16; - } +) -> f32 { + unsafe { + // It is safe to load unaligned floats from a pointer. + // + + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1 = v1.as_ptr() as *const f32; + let mut ptr2 = v2.as_ptr() as *const f32; + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + + let mut i: usize = 0; + while i < m { + sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1); + + sum128_2 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))), + sum128_2, + ); + + sum128_3 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))), + sum128_3, + ); + + sum128_4 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))), + sum128_4, + ); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } - let mut result = hsum128_ps_sse(sum128_1) - + hsum128_ps_sse(sum128_2) - + hsum128_ps_sse(sum128_3) - + hsum128_ps_sse(sum128_4); - for i in 0..n - m { - let a = read_unaligned(ptr1.add(i)); - let b = read_unaligned(ptr2.add(i)); - result += a * b; + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + let a = read_unaligned(ptr1.add(i)); + let b = read_unaligned(ptr2.add(i)); + result += a * b; + } + result } - result -}} +} #[cfg(test)] mod tests { diff --git a/helix-db/src/helix_gateway/builtin/node_by_id.rs b/helix-db/src/helix_gateway/builtin/node_by_id.rs index 3cd06f42c..455e58c26 100644 --- a/helix-db/src/helix_gateway/builtin/node_by_id.rs +++ b/helix-db/src/helix_gateway/builtin/node_by_id.rs @@ -128,25 +128,22 @@ inventory::submit! 
{ #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; - use tempfile::TempDir; - use axum::body::Bytes; use crate::{ helix_engine::{ storage_core::version_info::VersionInfo, traversal_core::{ HelixGraphEngine, HelixGraphEngineOpts, config::Config, - ops::{ - g::G, - source::add_n::AddNAdapter, - }, + ops::{g::G, source::add_n::AddNAdapter}, }, }, - protocol::{request::Request, request::RequestType, Format, value::Value}, helix_gateway::router::router::HandlerInput, + protocol::{Format, request::Request, request::RequestType, value::Value}, utils::id::ID, }; + use axum::body::Bytes; + use std::sync::Arc; + use tempfile::TempDir; fn setup_test_engine() -> (HelixGraphEngine, TempDir) { let temp_dir = TempDir::new().unwrap(); @@ -171,7 +168,9 @@ mod tests { let props = [("name", Value::String("Alice".to_string()))]; let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -196,7 +195,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_details_inner(input); @@ -227,7 +225,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_details_inner(input); @@ -256,7 +253,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_details_inner(input); @@ -279,7 +275,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_details_inner(input); @@ -294,11 +289,15 @@ mod tests { let mut txn = engine.storage.graph_env.write_txn().unwrap(); let arena = bumpalo::Bump::new(); - let props = [("name", Value::String("Alice".to_string())), - ("age", Value::I64(30))]; + let props = [ + ("name", Value::String("Alice".to_string())), + ("age", Value::I64(30)), + ]; let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -323,7 +322,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_details_inner(input); diff --git a/helix-db/src/helix_gateway/builtin/node_connections.rs b/helix-db/src/helix_gateway/builtin/node_connections.rs index 5937b3791..1e972a361 100644 --- a/helix-db/src/helix_gateway/builtin/node_connections.rs +++ b/helix-db/src/helix_gateway/builtin/node_connections.rs @@ -95,9 +95,10 @@ pub fn node_connections_inner(input: HandlerInput) -> Result match HelixGraphStorage::unpack_adj_edge_data(value) { Ok((edge_id, from_node)) => { if connected_node_ids.insert(from_node) - && let Ok(node) = db.get_node(&txn, &from_node, &arena) { - connected_nodes.push(TraversalValue::Node(node)); - } + && let Ok(node) = db.get_node(&txn, &from_node, &arena) + { + connected_nodes.push(TraversalValue::Node(node)); + } match db.get_edge(&txn, &edge_id, &arena) { Ok(edge) => Some(TraversalValue::Edge(edge)), @@ -117,9 +118,10 @@ pub fn node_connections_inner(input: HandlerInput) -> Result match HelixGraphStorage::unpack_adj_edge_data(value) { Ok((edge_id, to_node)) => { if connected_node_ids.insert(to_node) - && let Ok(node) = db.get_node(&txn, &to_node, &arena) { - connected_nodes.push(TraversalValue::Node(node)); - } + && let Ok(node) = db.get_node(&txn, &to_node, &arena) + { + connected_nodes.push(TraversalValue::Node(node)); + } match db.get_edge(&txn, &edge_id, &arena) { 
Ok(edge) => Some(TraversalValue::Edge(edge)), @@ -208,9 +210,6 @@ inventory::submit! { #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; - use tempfile::TempDir; - use axum::body::Bytes; use crate::{ helix_engine::{ storage_core::version_info::VersionInfo, @@ -219,17 +218,17 @@ mod tests { config::Config, ops::{ g::G, - source::{ - add_e::AddEAdapter, - add_n::AddNAdapter, - }, + source::{add_e::AddEAdapter, add_n::AddNAdapter}, }, }, }, - protocol::{request::Request, request::RequestType, Format}, helix_gateway::router::router::HandlerInput, + protocol::{Format, request::Request, request::RequestType}, utils::id::ID, }; + use axum::body::Bytes; + use std::sync::Arc; + use tempfile::TempDir; fn setup_test_engine() -> (HelixGraphEngine, TempDir) { let temp_dir = TempDir::new().unwrap(); @@ -258,7 +257,13 @@ mod tests { .collect_to_obj()?; let _edge = G::new_mut(&engine.storage, &arena, &mut txn) - .add_edge(arena.alloc_str("knows"), None, node1.id(), node2.id(), false) + .add_edge( + arena.alloc_str("knows"), + None, + node1.id(), + node2.id(), + false, + ) .collect_to_obj()?; txn.commit().unwrap(); @@ -278,7 +283,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_connections_inner(input); @@ -306,7 +310,13 @@ mod tests { .collect_to_obj()?; let _edge = G::new_mut(&engine.storage, &arena, &mut txn) - .add_edge(arena.alloc_str("knows"), None, node1.id(), node2.id(), false) + .add_edge( + arena.alloc_str("knows"), + None, + node1.id(), + node2.id(), + false, + ) .collect_to_obj()?; txn.commit().unwrap(); @@ -326,7 +336,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_connections_inner(input); @@ -365,7 +374,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_connections_inner(input); @@ -397,7 +405,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_connections_inner(input); @@ -420,7 +427,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = node_connections_inner(input); diff --git a/helix-db/src/helix_gateway/builtin/nodes_by_label.rs b/helix-db/src/helix_gateway/builtin/nodes_by_label.rs index 820eba754..44caea5b2 100644 --- a/helix-db/src/helix_gateway/builtin/nodes_by_label.rs +++ b/helix-db/src/helix_gateway/builtin/nodes_by_label.rs @@ -108,9 +108,10 @@ pub fn nodes_by_label_inner(input: HandlerInput) -> Result= limit_count { - break; - } + && count >= limit_count + { + break; + } } } Err(_) => continue, @@ -137,24 +138,21 @@ inventory::submit! 
{ #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; - use tempfile::TempDir; - use axum::body::Bytes; use crate::{ helix_engine::{ storage_core::version_info::VersionInfo, traversal_core::{ HelixGraphEngine, HelixGraphEngineOpts, config::Config, - ops::{ - g::G, - source::add_n::AddNAdapter, - }, + ops::{g::G, source::add_n::AddNAdapter}, }, }, - protocol::{request::Request, request::RequestType, Format, value::Value}, helix_gateway::router::router::HandlerInput, + protocol::{Format, request::Request, request::RequestType, value::Value}, }; + use axum::body::Bytes; + use std::sync::Arc; + use tempfile::TempDir; fn setup_test_engine() -> (HelixGraphEngine, TempDir) { let temp_dir = TempDir::new().unwrap(); @@ -179,7 +177,9 @@ mod tests { let props1 = [("name", Value::String("Alice".to_string()))]; let props_map1 = ImmutablePropertiesMap::new( props1.len(), - props1.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props1 + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -190,7 +190,9 @@ mod tests { let props2 = [("name", Value::String("Bob".to_string()))]; let props_map2 = ImmutablePropertiesMap::new( props2.len(), - props2.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props2 + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -214,7 +216,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_by_label_inner(input); @@ -238,7 +239,9 @@ mod tests { let props = [("index", Value::I64(i))]; let props_map = ImmutablePropertiesMap::new( props.len(), - props.iter().map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), + props + .iter() + .map(|(k, v)| (arena.alloc_str(k) as &str, v.clone())), &arena, ); @@ -263,7 +266,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_by_label_inner(input); @@ -293,7 +295,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_by_label_inner(input); @@ -320,7 +321,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_by_label_inner(input); @@ -357,7 +357,6 @@ mod tests { let input = HandlerInput { graph: Arc::new(engine), request, - }; let result = nodes_by_label_inner(input); diff --git a/helix-db/src/helix_gateway/introspect_schema.rs b/helix-db/src/helix_gateway/introspect_schema.rs index 8eff78e61..ad52a2e3c 100644 --- a/helix-db/src/helix_gateway/introspect_schema.rs +++ b/helix-db/src/helix_gateway/introspect_schema.rs @@ -18,4 +18,3 @@ pub async fn introspect_schema_handler( _ => (StatusCode::INTERNAL_SERVER_ERROR, "Could not find schema").into_response(), } } - diff --git a/helix-db/src/helix_gateway/router/router.rs b/helix-db/src/helix_gateway/router/router.rs index 7197a6b89..7ab6911d1 100644 --- a/helix-db/src/helix_gateway/router/router.rs +++ b/helix-db/src/helix_gateway/router/router.rs @@ -47,8 +47,6 @@ impl Debug for IoContFn { } } - - // basic type for function pointer pub type BasicHandlerFn = fn(HandlerInput) -> Result; diff --git a/helix-db/src/helix_gateway/tests/gateway_tests.rs b/helix-db/src/helix_gateway/tests/gateway_tests.rs index ddfbec850..cfd0f25e0 100644 --- a/helix-db/src/helix_gateway/tests/gateway_tests.rs +++ b/helix-db/src/helix_gateway/tests/gateway_tests.rs @@ -336,11 +336,11 @@ fn test_gateway_opts_default_workers_per_core() { #[cfg(feature = "api-key")] mod api_key_tests { - + use crate::helix_gateway::key_verification::verify_key; 
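The import-only hunks in this file and in the parser files further down all apply one rule, matching rustfmt's default ordering: items in a `use` list sort case-sensitively, so ASCII-uppercase type names come before lowercase module paths and functions. One example drawn from the query-parser hunk below:

    // Uppercase items first, then lowercase nested paths.
    use crate::helixc::parser::{HelixParser, ParserError, Rule, location::HasLoc};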
+    use crate::protocol::Format;
     use crate::protocol::{HelixError, request::Request};
     use axum::body::Bytes;
-    use crate::protocol::Format;
 
     #[test]
     fn test_verify_key_integration_success() {
diff --git a/helix-db/src/helix_gateway/tests/mod.rs b/helix-db/src/helix_gateway/tests/mod.rs
index 4c6449d42..c42965735 100644
--- a/helix-db/src/helix_gateway/tests/mod.rs
+++ b/helix-db/src/helix_gateway/tests/mod.rs
@@ -3,5 +3,5 @@ pub mod gateway_tests;
 pub mod introspect_schema_tests;
 pub mod mcp_tests;
 pub mod router_tests;
-pub mod worker_pool_tests;
 pub mod worker_pool_concurrency_tests;
+pub mod worker_pool_tests;
diff --git a/helix-db/src/helix_gateway/tests/worker_pool_concurrency_tests.rs b/helix-db/src/helix_gateway/tests/worker_pool_concurrency_tests.rs
index e4379e486..5a847edb5 100644
--- a/helix-db/src/helix_gateway/tests/worker_pool_concurrency_tests.rs
+++ b/helix-db/src/helix_gateway/tests/worker_pool_concurrency_tests.rs
@@ -1,3 +1,4 @@
+use crate::helix_engine::traversal_core::HelixGraphEngine;
 /// Concurrency-specific tests for WorkerPool
 ///
 /// This test suite focuses on concurrent behavior and race conditions in the WorkerPool.
@@ -16,20 +17,15 @@
 /// - Worker fairness under load
 /// - No coordination between workers accessing shared graph
 /// - No deadlocks or livelocks under high concurrency
-
 use crate::helix_engine::traversal_core::HelixGraphEngineOpts;
 use crate::helix_engine::traversal_core::config::Config;
-use crate::helix_engine::{traversal_core::HelixGraphEngine};
 use crate::helix_gateway::worker_pool::WorkerPool;
-use crate::helix_gateway::{
-    gateway::CoreSetter,
-    router::router::HelixRouter,
-};
+use crate::helix_gateway::{gateway::CoreSetter, router::router::HelixRouter};
 use crate::protocol::Format;
 use crate::protocol::{Request, request::RequestType};
 use axum::body::Bytes;
-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::time::Duration;
 use tempfile::TempDir;
 use tokio::time::timeout;
@@ -56,7 +52,10 @@ fn create_request(name: &str) -> Request {
     }
 }
 
-fn create_test_pool(num_cores: usize, threads_per_core: usize) -> (WorkerPool, Arc<HelixGraphEngine>, TempDir) {
+fn create_test_pool(
+    num_cores: usize,
+    threads_per_core: usize,
+) -> (WorkerPool, Arc<HelixGraphEngine>, TempDir) {
     let (graph, temp_dir) = create_test_graph();
     let router = Arc::new(HelixRouter::new(None, None));
     let rt = Arc::new(
@@ -106,8 +105,11 @@ async fn test_concurrent_requests_high_load() {
     }
 
     // All should complete (no panics or hangs)
-    assert_eq!(completed, num_concurrent,
-        "All requests should complete, got {}/{}", completed, num_concurrent);
+    assert_eq!(
+        completed, num_concurrent,
+        "All requests should complete, got {}/{}",
+        completed, num_concurrent
+    );
 
     println!("High load test: {} requests completed", num_concurrent);
 }
@@ -223,7 +225,10 @@ async fn test_parity_mechanism_both_workers() {
 
     // All should complete if parity mechanism allows all workers to participate
     assert_eq!(completed, num_requests);
-    println!("Parity test: {} requests completed across even/odd workers", completed);
+    println!(
+        "Parity test: {} requests completed across even/odd workers",
+        completed
+    );
 }
 
 #[tokio::test]
@@ -326,7 +331,10 @@ async fn test_concurrent_different_request_types() {
 
     let expected = 25 * request_types.len();
     assert_eq!(completed, expected);
-    println!("Different request types: {}/{} completed", completed, expected);
+    println!(
+        "Different request types: {}/{} completed",
+        completed, expected
+    );
 }
 
 #[tokio::test]
@@ -396,6 +404,9 @@
async fn test_worker_distribution_fairness() { println!("Fairness test: 100 requests completed in {:?}", elapsed); // Basic sanity: should complete in reasonable time - assert!(elapsed < Duration::from_secs(10), - "Requests took {:?}, may indicate poor distribution", elapsed); + assert!( + elapsed < Duration::from_secs(10), + "Requests took {:?}, may indicate poor distribution", + elapsed + ); } diff --git a/helix-db/src/helixc/analyzer/diagnostic.rs b/helix-db/src/helixc/analyzer/diagnostic.rs index d3a33f673..86dc0da99 100644 --- a/helix-db/src/helixc/analyzer/diagnostic.rs +++ b/helix-db/src/helixc/analyzer/diagnostic.rs @@ -1,9 +1,5 @@ use crate::helixc::{ - analyzer::{ - error_codes::ErrorCode, - fix::Fix, - pretty, - }, + analyzer::{error_codes::ErrorCode, fix::Fix, pretty}, parser::location::Loc, }; diff --git a/helix-db/src/helixc/analyzer/error_codes.rs b/helix-db/src/helixc/analyzer/error_codes.rs index 188c46937..ddd12d966 100644 --- a/helix-db/src/helixc/analyzer/error_codes.rs +++ b/helix-db/src/helixc/analyzer/error_codes.rs @@ -1,5 +1,5 @@ use paste::paste; -use std::fmt::{Debug}; +use std::fmt::Debug; #[allow(dead_code)] #[derive(Debug, Clone, PartialEq)] @@ -119,7 +119,6 @@ pub enum ErrorCode { /// `E653` - `inner type of in variable is not an object` E653, - /// `W101` - `query has no return` W101, } diff --git a/helix-db/src/helixc/analyzer/methods/exclude_validation.rs b/helix-db/src/helixc/analyzer/methods/exclude_validation.rs index b0f63bfe1..0297c1352 100644 --- a/helix-db/src/helixc/analyzer/methods/exclude_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/exclude_validation.rs @@ -8,7 +8,7 @@ use crate::{ fix::Fix, types::Type, }, - parser::{types::*, location::Loc}, + parser::{location::Loc, types::*}, }, }; use paste::paste; @@ -127,7 +127,13 @@ pub(crate) fn validate_exclude<'a>( validate_exclude(ctx, ty, tr, ex, excluded, original_query); } _ => { - generate_error!(ctx, original_query, ex.fields[0].0.clone(), E203, cur_ty.kind_str()); + generate_error!( + ctx, + original_query, + ex.fields[0].0.clone(), + E203, + cur_ty.kind_str() + ); } } } @@ -135,7 +141,7 @@ pub(crate) fn validate_exclude<'a>( #[cfg(test)] mod tests { use crate::helixc::analyzer::error_codes::ErrorCode; - use crate::helixc::parser::{write_to_temp_file, HelixParser}; + use crate::helixc::parser::{HelixParser, write_to_temp_file}; // ============================================================================ // Field Exclusion Tests diff --git a/helix-db/src/helixc/analyzer/methods/graph_step_validation.rs b/helix-db/src/helixc/analyzer/methods/graph_step_validation.rs index f453c1a14..7e525c5f3 100644 --- a/helix-db/src/helixc/analyzer/methods/graph_step_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/graph_step_validation.rs @@ -13,12 +13,13 @@ use crate::{ utils::{gen_identifier_or_param, is_valid_identifier}, }, generator::{ - math_functions::{generate_math_expr, ExpressionContext}, + math_functions::{ExpressionContext, generate_math_expr}, queries::Query as GeneratedQuery, traversal_steps::{ FromV as GeneratedFromV, In as GeneratedIn, InE as GeneratedInE, Out as GeneratedOut, OutE as GeneratedOutE, SearchVectorStep, - ShortestPath as GeneratedShortestPath, ShortestPathAStar as GeneratedShortestPathAStar, + ShortestPath as GeneratedShortestPath, + ShortestPathAStar as GeneratedShortestPathAStar, ShortestPathBFS as GeneratedShortestPathBFS, ShortestPathDijkstras as GeneratedShortestPathDijkstras, ShouldCollect, Step as GeneratedStep, ToV as GeneratedToV, Traversal as 
GeneratedTraversal, @@ -404,9 +405,7 @@ pub(crate) fn apply_graph_step<'a>( Some(WeightExpression::Expression(expr)) => { // Generate Rust code for the math expression match generate_math_expr(expr, ExpressionContext::WeightCalculation) { - Ok(math_expr) => { - WeightCalculation::Expression(format!("{}", math_expr)) - } + Ok(math_expr) => WeightCalculation::Expression(format!("{}", math_expr)), Err(e) => { generate_error!( ctx, @@ -421,9 +420,7 @@ pub(crate) fn apply_graph_step<'a>( } } } - Some(WeightExpression::Default) | None => { - WeightCalculation::Default - } + Some(WeightExpression::Default) | None => WeightCalculation::Default, }; // Extract weight property for validation (if it's a simple property) @@ -555,9 +552,7 @@ pub(crate) fn apply_graph_step<'a>( } Some(WeightExpression::Expression(expr)) => { match generate_math_expr(expr, ExpressionContext::WeightCalculation) { - Ok(math_expr) => { - WeightCalculation::Expression(format!("{}", math_expr)) - } + Ok(math_expr) => WeightCalculation::Expression(format!("{}", math_expr)), Err(e) => { generate_error!( ctx, @@ -579,32 +574,33 @@ pub(crate) fn apply_graph_step<'a>( traversal .steps - .push(Separator::Period(GeneratedStep::ShortestPathAStar( - match (sp.from.clone(), sp.to.clone()) { - (Some(from), Some(to)) => GeneratedShortestPathAStar { - label: type_arg, - from: Some(GenRef::from(from)), - to: Some(GenRef::from(to)), - weight_calculation, - heuristic_property, - }, - (Some(from), None) => GeneratedShortestPathAStar { - label: type_arg, - from: Some(GenRef::from(from)), - to: None, - weight_calculation, - heuristic_property, - }, - (None, Some(to)) => GeneratedShortestPathAStar { - label: type_arg, - from: None, - to: Some(GenRef::from(to)), - weight_calculation, - heuristic_property, - }, - (None, None) => panic!("Invalid shortest path astar"), + .push(Separator::Period(GeneratedStep::ShortestPathAStar(match ( + sp.from.clone(), + sp.to.clone(), + ) { + (Some(from), Some(to)) => GeneratedShortestPathAStar { + label: type_arg, + from: Some(GenRef::from(from)), + to: Some(GenRef::from(to)), + weight_calculation, + heuristic_property, }, - ))); + (Some(from), None) => GeneratedShortestPathAStar { + label: type_arg, + from: Some(GenRef::from(from)), + to: None, + weight_calculation, + heuristic_property, + }, + (None, Some(to)) => GeneratedShortestPathAStar { + label: type_arg, + from: None, + to: Some(GenRef::from(to)), + weight_calculation, + heuristic_property, + }, + (None, None) => panic!("Invalid shortest path astar"), + }))); traversal.should_collect = ShouldCollect::ToVec; Some(Type::Unknown) } diff --git a/helix-db/src/helixc/analyzer/methods/migration_validation.rs b/helix-db/src/helixc/analyzer/methods/migration_validation.rs index 02aa49cc9..7b6d8729a 100644 --- a/helix-db/src/helixc/analyzer/methods/migration_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/migration_validation.rs @@ -8,9 +8,7 @@ use crate::{ }, utils::{GenRef, GeneratedValue, Separator}, }, - parser::types::{ - FieldValueType, Migration, MigrationItem, MigrationPropertyMapping, - }, + parser::types::{FieldValueType, Migration, MigrationItem, MigrationPropertyMapping}, }, protocol::value::Value, }; @@ -26,7 +24,10 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, migration.from_version.0.clone(), ErrorCode::E108, - format!("Migration references non-existent schema version: {}", migration.from_version.1), + format!( + "Migration references non-existent schema version: {}", + migration.from_version.1 + ), 
Some("Ensure the schema version exists before referencing it in a migration".into()), ); return; @@ -41,7 +42,10 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, migration.to_version.0.clone(), ErrorCode::E108, - format!("Migration references non-existent schema version: {}", migration.to_version.1), + format!( + "Migration references non-existent schema version: {}", + migration.to_version.1 + ), Some("Ensure the schema version exists before referencing it in a migration".into()), ); return; @@ -80,8 +84,13 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, item.from_item.0.clone(), ErrorCode::E201, - format!("Migration item '{item_name}' does not exist in schema version {}", migration.from_version.1), - Some(format!("Ensure '{item_name}' is defined in the source schema")), + format!( + "Migration item '{item_name}' does not exist in schema version {}", + migration.from_version.1 + ), + Some(format!( + "Ensure '{item_name}' is defined in the source schema" + )), ); continue; } @@ -99,8 +108,13 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, item.to_item.0.clone(), ErrorCode::E201, - format!("Migration item '{item_name}' does not exist in schema version {}", migration.to_version.1), - Some(format!("Ensure '{item_name}' is defined in the target schema")), + format!( + "Migration item '{item_name}' does not exist in schema version {}", + migration.to_version.1 + ), + Some(format!( + "Ensure '{item_name}' is defined in the target schema" + )), ); continue; } @@ -113,8 +127,11 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, item.loc.clone(), ErrorCode::E205, - format!("Migration item types do not match: '{}' to '{}'", - item.from_item.1.inner(), item.to_item.1.inner()), + format!( + "Migration item types do not match: '{}' to '{}'", + item.from_item.1.inner(), + item.to_item.1.inner() + ), Some("Migration between different item types is not yet supported".into()), ); continue; @@ -143,9 +160,15 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, property_name.0.clone(), ErrorCode::E202, - format!("Property '{}' does not exist in target schema for '{}'", - property_name.1, item.to_item.1.inner()), - Some(format!("Ensure property '{}' is defined in the target schema", property_name.1)), + format!( + "Property '{}' does not exist in target schema for '{}'", + property_name.1, + item.to_item.1.inner() + ), + Some(format!( + "Ensure property '{}' is defined in the target schema", + property_name.1 + )), ); continue; } @@ -175,9 +198,14 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { ctx, property_value.loc.clone(), ErrorCode::E202, - format!("Identifier '{}' does not exist in source schema for '{}'", - identifier, item.from_item.1.inner()), - Some(format!("Ensure '{identifier}' is a valid field in the source schema")), + format!( + "Identifier '{}' does not exist in source schema for '{}'", + identifier, + item.from_item.1.inner() + ), + Some(format!( + "Ensure '{identifier}' is a valid field in the source schema" + )), ); continue; } @@ -188,7 +216,10 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { property_value.loc.clone(), ErrorCode::E206, "Unsupported property value type in migration".into(), - Some("Only literal values and identifiers are supported in migrations".into()), + Some( + "Only literal values and identifiers are supported in migrations" + .into(), + ), ); continue; } @@ 
-196,30 +227,42 @@ pub(crate) fn validate_migration(ctx: &mut Ctx, migration: &Migration) { // check default value is valid for the new field type if let Some(default) = &default - && to_property_field.field_type != *default { - push_schema_err( - ctx, - property_value.loc.clone(), - ErrorCode::E205, - format!("Default value type mismatch: expected '{}' but got '{:?}'", - to_property_field.field_type, default), - Some("Ensure the default value type matches the field type in the target schema".into()), - ); - continue; + && to_property_field.field_type != *default + { + push_schema_err( + ctx, + property_value.loc.clone(), + ErrorCode::E205, + format!( + "Default value type mismatch: expected '{}' but got '{:?}'", + to_property_field.field_type, default + ), + Some( + "Ensure the default value type matches the field type in the target schema" + .into(), + ), + ); + continue; } // check the cast is valid for the new field type if let Some(cast) = &cast - && to_property_field.field_type != cast.cast_to { - push_schema_err( - ctx, - cast.loc.clone(), - ErrorCode::E205, - format!("Cast target type mismatch: expected '{}' but got '{}'", - to_property_field.field_type, cast.cast_to), - Some("Ensure the cast target type matches the field type in the target schema".into()), - ); - continue; + && to_property_field.field_type != cast.cast_to + { + push_schema_err( + ctx, + cast.loc.clone(), + ErrorCode::E205, + format!( + "Cast target type mismatch: expected '{}' but got '{}'", + to_property_field.field_type, cast.cast_to + ), + Some( + "Ensure the cast target type matches the field type in the target schema" + .into(), + ), + ); + continue; } // // warnings if name is same diff --git a/helix-db/src/helixc/analyzer/methods/object_validation.rs b/helix-db/src/helixc/analyzer/methods/object_validation.rs index 9635b7891..42f486067 100644 --- a/helix-db/src/helixc/analyzer/methods/object_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/object_validation.rs @@ -345,9 +345,11 @@ fn validate_property_access<'a>( ); // Check if this nested traversal ends with a Closure step - let own_closure_param = tr.steps.last() - .and_then(|step| match &step.step { - crate::helixc::parser::types::StepType::Closure(cl) => Some(cl.identifier.clone()), + let own_closure_param = + tr.steps.last().and_then(|step| match &step.step { + crate::helixc::parser::types::StepType::Closure(cl) => { + Some(cl.identifier.clone()) + } _ => None, }); @@ -386,9 +388,11 @@ fn validate_property_access<'a>( ); // Check if this nested traversal ends with a Closure step - let own_closure_param = tr.steps.last() - .and_then(|step| match &step.step { - crate::helixc::parser::types::StepType::Closure(cl) => Some(cl.identifier.clone()), + let own_closure_param = + tr.steps.last().and_then(|step| match &step.step { + crate::helixc::parser::types::StepType::Closure(cl) => { + Some(cl.identifier.clone()) + } _ => None, }); diff --git a/helix-db/src/helixc/analyzer/methods/statement_validation.rs b/helix-db/src/helixc/analyzer/methods/statement_validation.rs index 76211bd64..b7f0a6554 100644 --- a/helix-db/src/helixc/analyzer/methods/statement_validation.rs +++ b/helix-db/src/helixc/analyzer/methods/statement_validation.rs @@ -5,8 +5,11 @@ use crate::{ generate_error, helixc::{ analyzer::{ - Ctx, errors::push_query_err, methods::infer_expr_type::infer_expr_type, types::Type, - utils::{is_valid_identifier, VariableInfo}, + Ctx, + errors::push_query_err, + methods::infer_expr_type::infer_expr_type, + types::Type, + utils::{VariableInfo, 
is_valid_identifier}, }, generator::{ queries::Query as GeneratedQuery, @@ -63,14 +66,17 @@ pub(crate) fn validate_statements<'a>( // Determine if the variable is single or collection based on type let is_single = if let Some(GeneratedStatement::Traversal(ref tr)) = stmt { // Check if should_collect is ToObj, or if the type is a single value - matches!(tr.should_collect, ShouldCollect::ToObj) || - matches!(rhs_ty, Type::Node(_) | Type::Edge(_) | Type::Vector(_)) + matches!(tr.should_collect, ShouldCollect::ToObj) + || matches!(rhs_ty, Type::Node(_) | Type::Edge(_) | Type::Vector(_)) } else { // Non-traversal: check if type is single matches!(rhs_ty, Type::Node(_) | Type::Edge(_) | Type::Vector(_)) }; - scope.insert(assign.variable.as_str(), VariableInfo::new(rhs_ty, is_single)); + scope.insert( + assign.variable.as_str(), + VariableInfo::new(rhs_ty, is_single), + ); stmt.as_ref()?; @@ -200,8 +206,14 @@ pub(crate) fn validate_statements<'a>( .unwrap() .clone(), ); - body_scope.insert(field_name.as_str(), VariableInfo::new(field_type.clone(), true)); - scope.insert(field_name.as_str(), VariableInfo::new(field_type, true)); + body_scope.insert( + field_name.as_str(), + VariableInfo::new(field_type.clone(), true), + ); + scope.insert( + field_name.as_str(), + VariableInfo::new(field_type, true), + ); } for_variable = ForVariable::ObjectDestructure( fields @@ -238,18 +250,25 @@ pub(crate) fn validate_statements<'a>( Type::Array(object_arr) => { match object_arr.as_ref() { Type::Object(object) => { - let mut obj_dest_fields = Vec::with_capacity(fields.len()); + let mut obj_dest_fields = + Vec::with_capacity(fields.len()); let object = object.clone(); for (_, field_name) in fields { let name = field_name.as_str(); // adds non-param fields to scope let field_type = object.get(name).unwrap().clone(); - body_scope.insert(name, VariableInfo::new(field_type.clone(), true)); - scope.insert(name, VariableInfo::new(field_type, true)); - obj_dest_fields.push(GenRef::Std(name.to_string())); - } - for_variable = - ForVariable::ObjectDestructure(obj_dest_fields); + body_scope.insert( + name, + VariableInfo::new(field_type.clone(), true), + ); + scope.insert( + name, + VariableInfo::new(field_type, true), + ); + obj_dest_fields.push(GenRef::Std(name.to_string())); + } + for_variable = + ForVariable::ObjectDestructure(obj_dest_fields); } _ => { generate_error!( @@ -273,7 +292,7 @@ pub(crate) fn validate_statements<'a>( [&fl.in_variable.1] ); } - } + }, _ => { generate_error!( ctx, @@ -311,7 +330,7 @@ pub(crate) fn validate_statements<'a>( #[cfg(test)] mod tests { use super::*; - use crate::helixc::parser::{write_to_temp_file, HelixParser}; + use crate::helixc::parser::{HelixParser, write_to_temp_file}; // ============================================================================ // Assignment Validation Tests @@ -335,7 +354,11 @@ mod tests { assert!(result.is_ok()); let (diagnostics, _) = result.unwrap(); assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E302)); - assert!(diagnostics.iter().any(|d| d.message.contains("previously declared"))); + assert!( + diagnostics + .iter() + .any(|d| d.message.contains("previously declared")) + ); } #[test] @@ -381,7 +404,11 @@ mod tests { assert!(result.is_ok()); let (diagnostics, _) = result.unwrap(); assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E301)); - assert!(diagnostics.iter().any(|d| d.message.contains("not in scope") && d.message.contains("unknownList"))); + assert!( + diagnostics + .iter() + .any(|d| d.message.contains("not in 
scope") && d.message.contains("unknownList")) + ); } #[test] @@ -425,7 +452,11 @@ mod tests { assert!(result.is_ok()); let (diagnostics, _) = result.unwrap(); assert!(diagnostics.iter().any(|d| d.error_code == ErrorCode::E651)); - assert!(diagnostics.iter().any(|d| d.message.contains("not iterable"))); + assert!( + diagnostics + .iter() + .any(|d| d.message.contains("not iterable")) + ); } #[test] @@ -514,7 +545,9 @@ mod tests { assert!(result.is_ok()); let (diagnostics, _) = result.unwrap(); // Expression statements should not produce errors - assert!(diagnostics.is_empty() || !diagnostics.iter().any(|d| d.error_code == ErrorCode::E301)); + assert!( + diagnostics.is_empty() || !diagnostics.iter().any(|d| d.error_code == ErrorCode::E301) + ); } #[test] @@ -586,7 +619,11 @@ mod tests { assert!(result.is_ok()); let (diagnostics, _) = result.unwrap(); - assert!(!diagnostics.iter().any(|d| d.error_code == ErrorCode::E301 || d.error_code == ErrorCode::E302)); + assert!( + !diagnostics + .iter() + .any(|d| d.error_code == ErrorCode::E301 || d.error_code == ErrorCode::E302) + ); } #[test] @@ -609,6 +646,10 @@ mod tests { assert!(result.is_ok()); let (diagnostics, _) = result.unwrap(); - assert!(!diagnostics.iter().any(|d| d.error_code == ErrorCode::E301 || d.error_code == ErrorCode::E302)); + assert!( + !diagnostics + .iter() + .any(|d| d.error_code == ErrorCode::E301 || d.error_code == ErrorCode::E302) + ); } } diff --git a/helix-db/src/helixc/analyzer/mod.rs b/helix-db/src/helixc/analyzer/mod.rs index df891e26e..2260e4d1c 100644 --- a/helix-db/src/helixc/analyzer/mod.rs +++ b/helix-db/src/helixc/analyzer/mod.rs @@ -12,12 +12,15 @@ use crate::helixc::{ methods::{ migration_validation::validate_migration, query_validation::validate_query, - schema_methods::{build_field_lookups, check_schema, SchemaVersionMap}, + schema_methods::{SchemaVersionMap, build_field_lookups, check_schema}, }, types::Type, }, generator::Source as GeneratedSource, - parser::{errors::ParserError, types::{EdgeSchema, ExpressionType, Field, Query, ReturnType, Source}}, + parser::{ + errors::ParserError, + types::{EdgeSchema, ExpressionType, Field, Query, ReturnType, Source}, + }, }; use itertools::Itertools; use serde::Serialize; @@ -44,7 +47,6 @@ pub mod pretty; pub mod types; pub mod utils; - /// Internal working context shared by all passes. 
pub(crate) struct Ctx<'a> {
     pub(super) src: &'a Source,
diff --git a/helix-db/src/helixc/analyzer/types.rs b/helix-db/src/helixc/analyzer/types.rs
index da2b9233d..8a2f9d07b 100644
--- a/helix-db/src/helixc/analyzer/types.rs
+++ b/helix-db/src/helixc/analyzer/types.rs
@@ -229,10 +229,10 @@ impl From for GeneratedValue {
 /// Metadata for GROUPBY and AGGREGATE_BY operations
 #[derive(Debug, Clone)]
 pub struct AggregateInfo {
-    pub source_type: Box<Type>,  // Original type being aggregated (Node, Edge, Vector)
-    pub properties: Vec<String>,     // Properties being grouped by
-    pub is_count: bool,          // true for COUNT mode
-    pub is_group_by: bool,       // true for GROUP_BY, false for AGGREGATE_BY
+    pub source_type: Box<Type>, // Original type being aggregated (Node, Edge, Vector)
+    pub properties: Vec<String>, // Properties being grouped by
+    pub is_count: bool, // true for COUNT mode
+    pub is_group_by: bool, // true for GROUP_BY, false for AGGREGATE_BY
 }
 
 #[derive(Debug, Clone)]
@@ -408,7 +408,11 @@ impl From<&FieldType> for Type {
             String | Boolean | F32 | F64 | I8 | I16 | I32 | I64 | U8 | U16 | U32 | U64
             | U128 | Uuid | Date => Type::Scalar(ft.clone()),
             Array(inner_ft) => Type::Array(Box::new(Type::from(*inner_ft.clone()))),
-            Object(obj) => Type::Object(obj.iter().map(|(k, v)| (k.clone(), Type::from(v))).collect()),
+            Object(obj) => Type::Object(
+                obj.iter()
+                    .map(|(k, v)| (k.clone(), Type::from(v)))
+                    .collect(),
+            ),
             Identifier(id) => Type::Scalar(FieldType::Identifier(id.clone())),
         }
     }
diff --git a/helix-db/src/helixc/generator/migrations.rs b/helix-db/src/helixc/generator/migrations.rs
index b1a7b5221..d5ffcd286 100644
--- a/helix-db/src/helixc/generator/migrations.rs
+++ b/helix-db/src/helixc/generator/migrations.rs
@@ -153,7 +153,9 @@ mod tests {
             remappings: vec![Separator::Semicolon(
                 GeneratedMigrationPropertyMapping::FieldAdditionFromOldField {
                     old_field: GeneratedValue::Literal(GenRef::Literal("name".to_string())),
-                    new_field: GeneratedValue::Literal(GenRef::Literal("full_name".to_string())),
+                    new_field: GeneratedValue::Literal(GenRef::Literal(
+                        "full_name".to_string(),
+                    )),
                 },
             )],
             should_spread: false,
@@ -203,7 +205,9 @@
                 ),
                 Separator::Semicolon(
                     GeneratedMigrationPropertyMapping::FieldAdditionFromValue {
-                        new_field_name: GeneratedValue::Literal(GenRef::Literal("c".to_string())),
+                        new_field_name: GeneratedValue::Literal(GenRef::Literal(
+                            "c".to_string(),
+                        )),
                         new_field_type: FieldType::Boolean,
                         value: GeneratedValue::Primitive(GenRef::Std("true".to_string())),
                     },
diff --git a/helix-db/src/helixc/generator/statements.rs b/helix-db/src/helixc/generator/statements.rs
index fdfda778e..828bfa48e 100644
--- a/helix-db/src/helixc/generator/statements.rs
+++ b/helix-db/src/helixc/generator/statements.rs
@@ -3,8 +3,6 @@ use std::fmt::Display;
 
 use crate::helixc::generator::{bool_ops::BoExp, traversal_steps::Traversal, utils::GenRef};
 
-
-
 #[derive(Clone)]
 pub enum Statement {
     Assignment(Assignment),
@@ -27,13 +25,20 @@ impl Display for Statement {
             Statement::Literal(literal) => write!(f, "{literal}"),
             Statement::Identifier(identifier) => write!(f, "{identifier}"),
             Statement::BoExp(bo) => write!(f, "{bo}"),
-            Statement::Array(array) => write!(f, "[{}]", array.iter().map(|s| s.to_string()).collect::<Vec<_>>().join(", ")),
+            Statement::Array(array) => write!(
+                f,
+                "[{}]",
+                array
+                    .iter()
+                    .map(|s| s.to_string())
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            ),
             Statement::Empty => write!(f, ""),
         }
     }
 }
 
-
 #[derive(Clone)]
 pub enum IdentifierType {
     Primitive,
@@ -197,7 +202,9 @@ mod tests {
     fn test_assignment_statement() {
        let
assignment = Statement::Assignment(Assignment { variable: GenRef::Std("result".to_string()), - value: Box::new(Statement::Identifier(GenRef::Std("computation".to_string()))), + value: Box::new(Statement::Identifier(GenRef::Std( + "computation".to_string(), + ))), }); let output = format!("{}", assignment); assert!(output.contains("let result = computation")); @@ -227,4 +234,3 @@ mod tests { assert_eq!(var.inner(), ""); } } - diff --git a/helix-db/src/helixc/parser/creation_step_parse_methods.rs b/helix-db/src/helixc/parser/creation_step_parse_methods.rs index 3775dd6b2..6786d69cb 100644 --- a/helix-db/src/helixc/parser/creation_step_parse_methods.rs +++ b/helix-db/src/helixc/parser/creation_step_parse_methods.rs @@ -152,7 +152,7 @@ impl HelixParser { #[cfg(test)] mod tests { - use crate::helixc::parser::{write_to_temp_file, HelixParser}; + use crate::helixc::parser::{HelixParser, write_to_temp_file}; // ============================================================================ // AddNode Tests diff --git a/helix-db/src/helixc/parser/errors.rs b/helix-db/src/helixc/parser/errors.rs index 7e019b0c1..7a462e1e5 100644 --- a/helix-db/src/helixc/parser/errors.rs +++ b/helix-db/src/helixc/parser/errors.rs @@ -48,4 +48,4 @@ impl std::fmt::Debug for ParserError { } } } -} \ No newline at end of file +} diff --git a/helix-db/src/helixc/parser/expression_parse_methods.rs b/helix-db/src/helixc/parser/expression_parse_methods.rs index 9d6f8b615..cb3a73bde 100644 --- a/helix-db/src/helixc/parser/expression_parse_methods.rs +++ b/helix-db/src/helixc/parser/expression_parse_methods.rs @@ -541,7 +541,7 @@ impl HelixParser { return Err(ParserError::from(format!( "Unknown mathematical function: {}", function_name - ))) + ))); } }; @@ -590,7 +590,9 @@ impl HelixParser { match inner_inner.as_rule() { Rule::math_function_call => Ok(Expression { loc: inner_inner.loc(), - expr: ExpressionType::MathFunctionCall(self.parse_math_function_call(inner_inner)?), + expr: ExpressionType::MathFunctionCall( + self.parse_math_function_call(inner_inner)?, + ), }), Rule::float => inner_inner .as_str() @@ -614,11 +616,15 @@ impl HelixParser { }), Rule::traversal => Ok(Expression { loc: inner_inner.loc(), - expr: ExpressionType::Traversal(Box::new(self.parse_traversal(inner_inner)?)), + expr: ExpressionType::Traversal(Box::new( + self.parse_traversal(inner_inner)?, + )), }), Rule::id_traversal => Ok(Expression { loc: inner_inner.loc(), - expr: ExpressionType::Traversal(Box::new(self.parse_traversal(inner_inner)?)), + expr: ExpressionType::Traversal(Box::new( + self.parse_traversal(inner_inner)?, + )), }), _ => Err(ParserError::from(format!( "Unexpected evaluates_to_number type: {:?}", @@ -668,7 +674,7 @@ impl HelixParser { #[cfg(test)] mod tests { - use crate::helixc::parser::{write_to_temp_file, HelixParser}; + use crate::helixc::parser::{HelixParser, write_to_temp_file}; // ============================================================================ // Literal Expression Tests diff --git a/helix-db/src/helixc/parser/graph_step_parse_methods.rs b/helix-db/src/helixc/parser/graph_step_parse_methods.rs index 0c86d1d91..91ee1f1bc 100644 --- a/helix-db/src/helixc/parser/graph_step_parse_methods.rs +++ b/helix-db/src/helixc/parser/graph_step_parse_methods.rs @@ -2,10 +2,10 @@ use crate::helixc::parser::{ HelixParser, ParserError, Rule, location::HasLoc, types::{ - Aggregate, BooleanOp, BooleanOpType, Closure, Exclude, Expression, ExpressionType, FieldAddition, - FieldValue, FieldValueType, GraphStep, GraphStepType, GroupBy, 
IdType, MMRDistance, Object, OrderBy, - OrderByType, RerankMMR, RerankRRF, ShortestPath, ShortestPathAStar, ShortestPathBFS, - ShortestPathDijkstras, Step, StepType, Update, + Aggregate, BooleanOp, BooleanOpType, Closure, Exclude, Expression, ExpressionType, + FieldAddition, FieldValue, FieldValueType, GraphStep, GraphStepType, GroupBy, IdType, + MMRDistance, Object, OrderBy, OrderByType, RerankMMR, RerankRRF, ShortestPath, + ShortestPathAStar, ShortestPathBFS, ShortestPathDijkstras, Step, StepType, Update, }, utils::{PairTools, PairsTools}, }; @@ -448,13 +448,8 @@ impl HelixParser { Rule::math_expression => { // Parse the math_expression into an Expression let expr = self.parse_math_expression(p)?; - Ok(( - type_arg, - Some(expr), - from, - to, - )) - }, + Ok((type_arg, Some(expr), from, to)) + } Rule::to_from => match p.into_inner().next() { Some(p) => match p.as_rule() { Rule::to => Ok(( @@ -476,9 +471,7 @@ impl HelixParser { _ => Ok((type_arg, weight_expr, from, to)), }, ) { - Ok((type_arg, weight_expr, from, to)) => { - (type_arg, weight_expr, from, to) - } + Ok((type_arg, weight_expr, from, to)) => (type_arg, weight_expr, from, to), Err(e) => return Err(e), }; @@ -489,18 +482,25 @@ impl HelixParser { ExpressionType::Traversal(_trav) => { // For now, keep the traversal and create a Property weight expression // TODO: Extract property name from traversal for simple cases - Some(crate::helixc::parser::types::WeightExpression::Expression(Box::new(expr.clone()))) + Some(crate::helixc::parser::types::WeightExpression::Expression( + Box::new(expr.clone()), + )) } ExpressionType::MathFunctionCall(_) => { - Some(crate::helixc::parser::types::WeightExpression::Expression(Box::new(expr.clone()))) - } - _ => { - Some(crate::helixc::parser::types::WeightExpression::Expression(Box::new(expr.clone()))) + Some(crate::helixc::parser::types::WeightExpression::Expression( + Box::new(expr.clone()), + )) } + _ => Some(crate::helixc::parser::types::WeightExpression::Expression( + Box::new(expr.clone()), + )), }; (None, weight_type) } else { - (None, Some(crate::helixc::parser::types::WeightExpression::Default)) + ( + None, + Some(crate::helixc::parser::types::WeightExpression::Default), + ) }; GraphStep { @@ -579,7 +579,8 @@ impl HelixParser { for inner_pair in pair.clone().into_inner() { match inner_pair.as_rule() { Rule::type_args => { - type_arg = Some(inner_pair.into_inner().next().unwrap().as_str().to_string()); + type_arg = + Some(inner_pair.into_inner().next().unwrap().as_str().to_string()); } Rule::math_expression => { weight_expression = Some(self.parse_expression(inner_pair)?); @@ -590,15 +591,21 @@ impl HelixParser { heuristic_property = Some(literal[1..literal.len() - 1].to_string()); } Rule::to_from => { - if let Some(p) = inner_pair.into_inner().next() { match p.as_rule() { - Rule::to => { - to = Some(p.into_inner().next().unwrap().as_str().to_string()); - } - Rule::from => { - from = Some(p.into_inner().next().unwrap().as_str().to_string()); + if let Some(p) = inner_pair.into_inner().next() { + match p.as_rule() { + Rule::to => { + to = Some( + p.into_inner().next().unwrap().as_str().to_string(), + ); + } + Rule::from => { + from = Some( + p.into_inner().next().unwrap().as_str().to_string(), + ); + } + _ => {} } - _ => {} - } } + } } _ => {} } @@ -608,18 +615,25 @@ impl HelixParser { let (inner_traversal, weight_expr_typed) = if let Some(expr) = weight_expression { let weight_type = match &expr.expr { ExpressionType::Traversal(_trav) => { - 
Some(crate::helixc::parser::types::WeightExpression::Expression(Box::new(expr.clone()))) + Some(crate::helixc::parser::types::WeightExpression::Expression( + Box::new(expr.clone()), + )) } ExpressionType::MathFunctionCall(_) => { - Some(crate::helixc::parser::types::WeightExpression::Expression(Box::new(expr.clone()))) - } - _ => { - Some(crate::helixc::parser::types::WeightExpression::Expression(Box::new(expr.clone()))) + Some(crate::helixc::parser::types::WeightExpression::Expression( + Box::new(expr.clone()), + )) } + _ => Some(crate::helixc::parser::types::WeightExpression::Expression( + Box::new(expr.clone()), + )), }; (None, weight_type) } else { - (None, Some(crate::helixc::parser::types::WeightExpression::Default)) + ( + None, + Some(crate::helixc::parser::types::WeightExpression::Default), + ) }; GraphStep { @@ -717,8 +731,13 @@ impl HelixParser { }); } - let lambda = lambda.ok_or_else(|| ParserError::from("lambda parameter required for RerankMMR"))?; + let lambda = + lambda.ok_or_else(|| ParserError::from("lambda parameter required for RerankMMR"))?; - Ok(RerankMMR { loc, lambda, distance }) + Ok(RerankMMR { + loc, + lambda, + distance, + }) } } diff --git a/helix-db/src/helixc/parser/query_parse_methods.rs b/helix-db/src/helixc/parser/query_parse_methods.rs index 30af87711..8767aba9e 100644 --- a/helix-db/src/helixc/parser/query_parse_methods.rs +++ b/helix-db/src/helixc/parser/query_parse_methods.rs @@ -1,7 +1,6 @@ use crate::helixc::parser::{ - HelixParser, Rule, + HelixParser, ParserError, Rule, location::HasLoc, - ParserError, types::{BuiltInMacro, Parameter, Query, Statement, StatementType}, }; use pest::iterators::Pair; @@ -20,14 +19,14 @@ impl HelixParser { let built_in_macro = match pair.into_inner().next() { Some(pair) => match pair.as_rule() { Rule::mcp_macro => Some(BuiltInMacro::MCP), - Rule::model_macro => { - match pair.into_inner().next() { - Some(model_name) => Some(BuiltInMacro::Model( - model_name.as_str().to_string(), - )), - None => return Err(ParserError::from("Model macro missing model name")), + Rule::model_macro => match pair.into_inner().next() { + Some(model_name) => { + Some(BuiltInMacro::Model(model_name.as_str().to_string())) } - } + None => { + return Err(ParserError::from("Model macro missing model name")); + } + }, _ => None, }, _ => None, @@ -37,17 +36,24 @@ impl HelixParser { } _ => None, }; - let name = pairs.next() + let name = pairs + .next() .ok_or_else(|| ParserError::from("Expected query name"))? - .as_str().to_string(); + .as_str() + .to_string(); let parameters = self.parse_parameters( - pairs.next().ok_or_else(|| ParserError::from("Expected parameters block"))? + pairs + .next() + .ok_or_else(|| ParserError::from("Expected parameters block"))?, )?; - let body = pairs.next() + let body = pairs + .next() .ok_or_else(|| ParserError::from("Expected query body"))?; let statements = self.parse_query_body(body)?; let return_values = self.parse_return_statement( - pairs.next().ok_or_else(|| ParserError::from("Expected return statement"))? 
+ pairs + .next() + .ok_or_else(|| ParserError::from("Expected return statement"))?, )?; Ok(Query { @@ -68,7 +74,8 @@ impl HelixParser { .map(|p: Pair<'_, Rule>| -> Result { let mut inner = p.into_inner(); let name = { - let pair = inner.next() + let pair = inner + .next() .ok_or_else(|| ParserError::from("Expected parameter name"))?; (pair.loc(), pair.as_str().to_string()) }; @@ -136,7 +143,9 @@ impl HelixParser { }), Rule::drop => { - let inner = p.into_inner().next() + let inner = p + .into_inner() + .next() .ok_or_else(|| ParserError::from("Drop statement missing expression"))?; Ok(Statement { loc: inner.loc(), @@ -160,7 +169,7 @@ impl HelixParser { #[cfg(test)] mod tests { use super::*; - use crate::helixc::parser::{write_to_temp_file, HelixParser}; + use crate::helixc::parser::{HelixParser, write_to_temp_file}; // ============================================================================ // Basic Query Parsing Tests diff --git a/helix-db/src/helixc/parser/return_value_parse_methods.rs b/helix-db/src/helixc/parser/return_value_parse_methods.rs index 13cbd4495..b4a51acbb 100644 --- a/helix-db/src/helixc/parser/return_value_parse_methods.rs +++ b/helix-db/src/helixc/parser/return_value_parse_methods.rs @@ -1,9 +1,8 @@ use std::collections::HashMap; use crate::helixc::parser::{ - HelixParser, Rule, + HelixParser, ParserError, Rule, location::HasLoc, - ParserError, types::{Expression, ExpressionType, ReturnType}, }; use pest::iterators::Pair; diff --git a/helix-db/src/helixc/parser/schema_parse_methods.rs b/helix-db/src/helixc/parser/schema_parse_methods.rs index 7b49182ff..1e3816645 100644 --- a/helix-db/src/helixc/parser/schema_parse_methods.rs +++ b/helix-db/src/helixc/parser/schema_parse_methods.rs @@ -524,7 +524,7 @@ impl HelixParser { #[cfg(test)] mod tests { use super::*; - use crate::helixc::parser::{write_to_temp_file, HelixParser}; + use crate::helixc::parser::{HelixParser, write_to_temp_file}; // ============================================================================ // Node Definition Tests @@ -568,8 +568,14 @@ mod tests { let parsed = result.unwrap(); let schema = parsed.schema.get(&1).unwrap(); - assert!(matches!(schema.node_schemas[0].fields[0].prefix, FieldPrefix::Index)); - assert!(matches!(schema.node_schemas[0].fields[1].prefix, FieldPrefix::Empty)); + assert!(matches!( + schema.node_schemas[0].fields[0].prefix, + FieldPrefix::Index + )); + assert!(matches!( + schema.node_schemas[0].fields[1].prefix, + FieldPrefix::Empty + )); } #[test] @@ -641,8 +647,14 @@ mod tests { let parsed = result.unwrap(); let schema = parsed.schema.get(&1).unwrap(); assert_eq!(schema.node_schemas[0].fields.len(), 2); - assert!(matches!(schema.node_schemas[0].fields[0].field_type, FieldType::Array(_))); - assert!(matches!(schema.node_schemas[0].fields[1].field_type, FieldType::Array(_))); + assert!(matches!( + schema.node_schemas[0].fields[0].field_type, + FieldType::Array(_) + )); + assert!(matches!( + schema.node_schemas[0].fields[1].field_type, + FieldType::Array(_) + )); } #[test] @@ -660,7 +672,10 @@ mod tests { let parsed = result.unwrap(); let schema = parsed.schema.get(&1).unwrap(); assert_eq!(schema.node_schemas[0].fields.len(), 1); - assert!(matches!(schema.node_schemas[0].fields[0].field_type, FieldType::Object(_))); + assert!(matches!( + schema.node_schemas[0].fields[0].field_type, + FieldType::Object(_) + )); } #[test] @@ -967,7 +982,10 @@ mod tests { let parsed = result.unwrap(); let schema = parsed.schema.get(&1).unwrap(); - 
assert!(matches!(schema.node_schemas[0].fields[0].field_type, FieldType::Array(_))); + assert!(matches!( + schema.node_schemas[0].fields[0].field_type, + FieldType::Array(_) + )); } #[test] diff --git a/helix-db/src/lib.rs b/helix-db/src/lib.rs index 317a9376d..32e76458b 100644 --- a/helix-db/src/lib.rs +++ b/helix-db/src/lib.rs @@ -8,4 +8,4 @@ pub mod utils; use mimalloc::MiMalloc; #[global_allocator] -static GLOBAL: MiMalloc = MiMalloc; \ No newline at end of file +static GLOBAL: MiMalloc = MiMalloc; diff --git a/helix-db/src/protocol/custom_serde/edge_serde.rs b/helix-db/src/protocol/custom_serde/edge_serde.rs index 024d5b1f2..b222b65a7 100644 --- a/helix-db/src/protocol/custom_serde/edge_serde.rs +++ b/helix-db/src/protocol/custom_serde/edge_serde.rs @@ -2,8 +2,8 @@ use crate::utils::{ items::Edge, properties::{ImmutablePropertiesMap, ImmutablePropertiesMapDeSeed}, }; -use std::fmt; use serde::de::{DeserializeSeed, Visitor}; +use std::fmt; /// Helper DeserializeSeed for Option struct OptionPropertiesMapDeSeed<'arena> { diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs index 40dfeab3e..4e290f440 100644 --- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs +++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs @@ -468,7 +468,7 @@ mod error_handling_tests { #[test] fn test_vector_extreme_version_value() { let arena = Bump::new(); - let id = 012012u128; + let id = 12012u128; let vector = create_arena_vector(&arena, id, "test", 255, false, &[1.0], vec![]); let props_bytes = bincode::serialize(&vector).unwrap(); diff --git a/helix-db/src/protocol/custom_serde/node_serde.rs b/helix-db/src/protocol/custom_serde/node_serde.rs index 9af6bec92..54e5533f1 100644 --- a/helix-db/src/protocol/custom_serde/node_serde.rs +++ b/helix-db/src/protocol/custom_serde/node_serde.rs @@ -1,9 +1,9 @@ -use std::fmt; -use serde::de::{DeserializeSeed, Visitor}; use crate::utils::{ items::Node, properties::{ImmutablePropertiesMap, ImmutablePropertiesMapDeSeed}, }; +use serde::de::{DeserializeSeed, Visitor}; +use std::fmt; /// Helper DeserializeSeed for Option /// This is needed because we can't use next_element::>() with custom DeserializeSeed @@ -84,7 +84,9 @@ impl<'de, 'arena> serde::de::DeserializeSeed<'de> for NodeDeSeed<'arena> { .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; let label = self.arena.alloc_str(label_string); - let version: u8 = seq.next_element()?.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + let version: u8 = seq + .next_element()? 
+ .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; // Bincode serializes Option as ONE field: 0x00 (None) or 0x01+data (Some) // Use our custom DeserializeSeed that handles the Option wrapper diff --git a/helix-db/src/protocol/custom_serde/tests.rs b/helix-db/src/protocol/custom_serde/tests.rs index e04154428..8c18dd9cb 100644 --- a/helix-db/src/protocol/custom_serde/tests.rs +++ b/helix-db/src/protocol/custom_serde/tests.rs @@ -60,11 +60,7 @@ mod node_serialization_tests { } /// Helper to create an old node with properties - fn create_old_node_with_props( - id: u128, - label: &str, - props: Vec<(&str, Value)>, - ) -> OldNode { + fn create_old_node_with_props(id: u128, label: &str, props: Vec<(&str, Value)>) -> OldNode { if props.is_empty() { OldNode { id, @@ -119,8 +115,10 @@ mod node_serialization_tests { println!("\nByte-by-byte comparison:"); for (i, (old_byte, new_byte)) in old_bytes.iter().zip(new_bytes.iter()).enumerate() { if old_byte != new_byte { - println!(" Index {}: old={:02x} ({}), new={:02x} ({})", - i, old_byte, old_byte, new_byte, new_byte); + println!( + " Index {}: old={:02x} ({}), new={:02x} ({})", + i, old_byte, old_byte, new_byte, new_byte + ); } } @@ -137,7 +135,11 @@ mod node_serialization_tests { if let Err(e) = &deserialized { println!("Deserialization error: {:?}", e); } - assert!(deserialized.is_ok(), "Failed to deserialize new format: {:?}", deserialized.err()); + assert!( + deserialized.is_ok(), + "Failed to deserialize new format: {:?}", + deserialized.err() + ); // Test that new format can deserialize old format println!("Attempting to deserialize old_bytes..."); @@ -146,7 +148,11 @@ mod node_serialization_tests { if let Err(e) = &old_deserialized { println!("Deserialization error from old format: {:?}", e); } - assert!(old_deserialized.is_ok(), "Failed to deserialize old format: {:?}", old_deserialized.err()); + assert!( + old_deserialized.is_ok(), + "Failed to deserialize old format: {:?}", + old_deserialized.err() + ); } #[test] @@ -313,7 +319,10 @@ mod node_serialization_tests { ("u16_val", Value::U16(65535)), ("u32_val", Value::U32(4294967295)), ("u64_val", Value::U64(18446744073709551615)), - ("u128_val", Value::U128(340282366920938463463374607431768211455)), + ( + "u128_val", + Value::U128(340282366920938463463374607431768211455), + ), ("f32_val", Value::F32(3.14159)), ("f64_val", Value::F64(2.71828)), ("bool_val", Value::Boolean(true)), @@ -328,7 +337,10 @@ mod node_serialization_tests { let props = deserialized.properties.unwrap(); assert_eq!(props.len(), 13); - assert_eq!(props.get("string_val"), Some(&Value::String("test".to_string()))); + assert_eq!( + props.get("string_val"), + Some(&Value::String("test".to_string())) + ); assert_eq!(props.get("i8_val"), Some(&Value::I8(-42))); assert_eq!(props.get("i16_val"), Some(&Value::I16(1000))); assert_eq!(props.get("i32_val"), Some(&Value::I32(100000))); @@ -336,8 +348,14 @@ mod node_serialization_tests { assert_eq!(props.get("u8_val"), Some(&Value::U8(255))); assert_eq!(props.get("u16_val"), Some(&Value::U16(65535))); assert_eq!(props.get("u32_val"), Some(&Value::U32(4294967295))); - assert_eq!(props.get("u64_val"), Some(&Value::U64(18446744073709551615))); - assert_eq!(props.get("u128_val"), Some(&Value::U128(340282366920938463463374607431768211455))); + assert_eq!( + props.get("u64_val"), + Some(&Value::U64(18446744073709551615)) + ); + assert_eq!( + props.get("u128_val"), + Some(&Value::U128(340282366920938463463374607431768211455)) + ); assert_eq!(props.get("f32_val"), 
Some(&Value::F32(3.14159))); assert_eq!(props.get("f64_val"), Some(&Value::F64(2.71828))); assert_eq!(props.get("bool_val"), Some(&Value::Boolean(true))); @@ -349,14 +367,16 @@ mod node_serialization_tests { let id = 22222u128; let props = vec![ - ("array", Value::Array(vec![ - Value::I32(1), - Value::I32(2), - Value::I32(3), - ])), + ( + "array", + Value::Array(vec![Value::I32(1), Value::I32(2), Value::I32(3)]), + ), ("nested_obj", { let mut map = HashMap::new(); - map.insert("inner_key".to_string(), Value::String("inner_value".to_string())); + map.insert( + "inner_key".to_string(), + Value::String("inner_value".to_string()), + ); Value::Object(map) }), ]; @@ -386,9 +406,17 @@ mod node_serialization_tests { // Check that both have the same keys and values (regardless of order) for (key, old_value) in old_props.iter() { - let new_value = new_props.get(key).unwrap_or_else(|| panic!("Missing key: {}", key)); + let new_value = new_props + .get(key) + .unwrap_or_else(|| panic!("Missing key: {}", key)); // For nested objects, we need to compare recursively since HashMap order may differ - assert!(values_equal(old_value, new_value), "Value mismatch for key {}: {:?} != {:?}", key, old_value, new_value); + assert!( + values_equal(old_value, new_value), + "Value mismatch for key {}: {:?} != {:?}", + key, + old_value, + new_value + ); } } @@ -413,7 +441,9 @@ mod node_serialization_tests { a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y)) } (Value::Object(a), Value::Object(b)) => { - a.len() == b.len() && a.iter().all(|(k, v)| b.get(k).is_some_and(|bv| values_equal(v, bv))) + a.len() == b.len() + && a.iter() + .all(|(k, v)| b.get(k).is_some_and(|bv| values_equal(v, bv))) } (Value::Date(a), Value::Date(b)) => a == b, (Value::Id(a), Value::Id(b)) => a == b, @@ -476,7 +506,10 @@ mod node_serialization_tests { let props = deserialized_node.properties.unwrap(); assert_eq!(props.len(), 2); - assert_eq!(props.get("name"), Some(&Value::String("Charlie".to_string()))); + assert_eq!( + props.get("name"), + Some(&Value::String("Charlie".to_string())) + ); assert_eq!(props.get("count"), Some(&Value::U64(42))); } @@ -602,12 +635,7 @@ mod node_serialization_tests { fn test_node_serialization_utf8_labels() { let arena = Bump::new(); - let utf8_labels = ["Hello", - "世界", - "🚀🌟", - "Привет", - "مرحبا", - "Ñoño"]; + let utf8_labels = ["Hello", "世界", "🚀🌟", "Привет", "مرحبا", "Ñoño"]; for (idx, label) in utf8_labels.iter().enumerate() { let id = idx as u128; @@ -619,7 +647,8 @@ mod node_serialization_tests { assert_eq!( old_bytes, new_bytes, - "UTF-8 label '{}' serialization differs", label + "UTF-8 label '{}' serialization differs", + label ); } } @@ -646,7 +675,10 @@ mod node_serialization_tests { assert_eq!(props.len(), 3); assert_eq!(props.get("名前"), Some(&Value::String("太郎".to_string()))); assert_eq!(props.get("возраст"), Some(&Value::I32(25))); - assert_eq!(props.get("emoji_key_🎉"), Some(&Value::String("party_🎊".to_string()))); + assert_eq!( + props.get("emoji_key_🎉"), + Some(&Value::String("party_🎊".to_string())) + ); } #[test] @@ -675,7 +707,12 @@ mod node_serialization_tests { // Verify all properties are present with correct values for i in 0..50 { let key = format!("key_{}", i); - assert_eq!(props.get(&key), Some(&Value::I32(i)), "Missing or incorrect value for {}", key); + assert_eq!( + props.get(&key), + Some(&Value::I32(i)), + "Missing or incorrect value for {}", + key + ); } } @@ -710,10 +747,7 @@ mod node_serialization_tests { let arena = Bump::new(); let id = 13131u128; - let 
props = vec![ - ("empty_val", Value::Empty), - ("normal_val", Value::I32(42)), - ]; + let props = vec![("empty_val", Value::Empty), ("normal_val", Value::I32(42))]; let new_node = create_arena_node_with_props(&arena, id, "EmptyValue", props); let new_bytes = bincode::serialize(&new_node).unwrap(); @@ -841,7 +875,9 @@ mod edge_serialization_tests { a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y)) } (Value::Object(a), Value::Object(b)) => { - a.len() == b.len() && a.iter().all(|(k, v)| b.get(k).is_some_and(|bv| values_equal(v, bv))) + a.len() == b.len() + && a.iter() + .all(|(k, v)| b.get(k).is_some_and(|bv| values_equal(v, bv))) } (Value::Date(a), Value::Date(b)) => a == b, (Value::Id(a), Value::Id(b)) => a == b, @@ -936,8 +972,14 @@ mod edge_serialization_tests { // Check semantic equality (order may differ) for (key, old_value) in old_props.iter() { - let new_value = new_props.get(key).unwrap_or_else(|| panic!("Missing key: {}", key)); - assert!(values_equal(old_value, new_value), "Value mismatch for key {}", key); + let new_value = new_props + .get(key) + .unwrap_or_else(|| panic!("Missing key: {}", key)); + assert!( + values_equal(old_value, new_value), + "Value mismatch for key {}", + key + ); } } @@ -953,7 +995,8 @@ mod edge_serialization_tests { ("verified", Value::Boolean(true)), ]; - let original = create_arena_edge_with_props(&arena, id, "RELATED_TO", from_node, to_node, props); + let original = + create_arena_edge_with_props(&arena, id, "RELATED_TO", from_node, to_node, props); let bytes = bincode::serialize(&original).unwrap(); let arena2 = Bump::new(); @@ -1001,7 +1044,10 @@ mod edge_serialization_tests { let props = deserialized.properties.unwrap(); assert_eq!(props.len(), 2); assert_eq!(props.get("strength"), Some(&Value::I32(5))); - assert_eq!(props.get("label_text"), Some(&Value::String("connection".to_string()))); + assert_eq!( + props.get("label_text"), + Some(&Value::String("connection".to_string())) + ); } #[test] @@ -1014,18 +1060,25 @@ mod edge_serialization_tests { let props = vec![ ("metadata", { let mut map = HashMap::new(); - map.insert("created_by".to_string(), Value::String("system".to_string())); + map.insert( + "created_by".to_string(), + Value::String("system".to_string()), + ); map.insert("timestamp".to_string(), Value::I64(1234567890)); Value::Object(map) }), - ("tags", Value::Array(vec![ - Value::String("important".to_string()), - Value::String("verified".to_string()), - ])), + ( + "tags", + Value::Array(vec![ + Value::String("important".to_string()), + Value::String("verified".to_string()), + ]), + ), ]; let old_edge = create_old_edge_with_props(id, "HAS_TAG", from_node, to_node, props.clone()); - let new_edge = create_arena_edge_with_props(&arena, id, "HAS_TAG", from_node, to_node, props); + let new_edge = + create_arena_edge_with_props(&arena, id, "HAS_TAG", from_node, to_node, props); let old_bytes = bincode::serialize(&old_edge).unwrap(); let new_bytes = bincode::serialize(&new_edge).unwrap(); @@ -1043,8 +1096,16 @@ mod edge_serialization_tests { // Compare nested values for (key, old_value) in old_props.iter() { - let new_value = new_props.get(key).unwrap_or_else(|| panic!("Missing key: {}", key)); - assert!(values_equal(old_value, new_value), "Value mismatch for key {}: {:?} != {:?}", key, old_value, new_value); + let new_value = new_props + .get(key) + .unwrap_or_else(|| panic!("Missing key: {}", key)); + assert!( + values_equal(old_value, new_value), + "Value mismatch for key {}: {:?} != {:?}", + key, + old_value, + 
new_value + ); } } @@ -1077,7 +1138,12 @@ mod edge_serialization_tests { // Verify all properties are present for i in 0..20 { let key = format!("prop_{}", i); - assert_eq!(props.get(&key), Some(&Value::I32(i)), "Property {} mismatch", key); + assert_eq!( + props.get(&key), + Some(&Value::I32(i)), + "Property {} mismatch", + key + ); } } @@ -1117,8 +1183,10 @@ mod edge_serialization_tests { println!("\nByte-by-byte comparison:"); for (i, (old_byte, new_byte)) in old_bytes.iter().zip(new_bytes.iter()).enumerate() { if old_byte != new_byte { - println!(" Index {}: old={:02x} ({}), new={:02x} ({})", - i, old_byte, old_byte, new_byte, new_byte); + println!( + " Index {}: old={:02x} ({}), new={:02x} ({})", + i, old_byte, old_byte, new_byte, new_byte + ); } } @@ -1141,7 +1209,8 @@ mod edge_serialization_tests { ("emoji", Value::String("🔗".to_string())), ]; - let new_edge = create_arena_edge_with_props(&arena, id, "繋がり", from_node, to_node, props); + let new_edge = + create_arena_edge_with_props(&arena, id, "繋がり", from_node, to_node, props); let bytes = bincode::serialize(&new_edge).unwrap(); let arena2 = Bump::new(); diff --git a/helix-db/src/protocol/format.rs b/helix-db/src/protocol/format.rs index fe3c96123..e23f9798d 100644 --- a/helix-db/src/protocol/format.rs +++ b/helix-db/src/protocol/format.rs @@ -201,7 +201,8 @@ mod tests { fn test_format_deserialize_invalid_json() { let invalid_json = b"not valid json {"; - let result: Result, GraphError> = Format::Json.deserialize(invalid_json); + let result: Result, GraphError> = + Format::Json.deserialize(invalid_json); assert!(result.is_err()); if let Err(GraphError::DecodeError(msg)) = result { diff --git a/helix-db/src/protocol/mod.rs b/helix-db/src/protocol/mod.rs index b1ee482b1..22b127df9 100644 --- a/helix-db/src/protocol/mod.rs +++ b/helix-db/src/protocol/mod.rs @@ -1,9 +1,9 @@ +pub mod custom_serde; pub mod date; pub mod error; pub mod format; pub mod request; pub mod response; -pub mod custom_serde; pub mod value; pub use error::HelixError; diff --git a/helix-db/src/protocol/request.rs b/helix-db/src/protocol/request.rs index c0d84f3ce..d194567ae 100644 --- a/helix-db/src/protocol/request.rs +++ b/helix-db/src/protocol/request.rs @@ -320,7 +320,10 @@ mod tests { out_fmt: Format::Json, }; - assert_ne!(request1.api_key_hash.unwrap(), request2.api_key_hash.unwrap()); + assert_ne!( + request1.api_key_hash.unwrap(), + request2.api_key_hash.unwrap() + ); } #[test] diff --git a/helix-db/src/utils/id.rs b/helix-db/src/utils/id.rs index 3bdfcf578..6f16de264 100644 --- a/helix-db/src/utils/id.rs +++ b/helix-db/src/utils/id.rs @@ -121,12 +121,12 @@ pub fn v6_uuid() -> u128 { uuid::Uuid::now_v6(&[1, 2, 3, 4, 5, 6]).as_u128() } -/// Converts a uuid to a string slice using a buffer created in the arena -/// +/// Converts a uuid to a string slice using a buffer created in the arena +/// /// This is more efficient that using the `to_string` on the created uuid /// as it avoids formatting and potential double buffering -/// -/// NOTE: This could be optimized further by reusing a slice at a set index within the arena +/// +/// NOTE: This could be optimized further by reusing a slice at a set index within the arena #[inline] pub fn uuid_str(id: u128, arena: &bumpalo::Bump) -> &str { let uuid = uuid::Uuid::from_u128(id); @@ -134,13 +134,17 @@ pub fn uuid_str(id: u128, arena: &bumpalo::Bump) -> &str { uuid.as_hyphenated().encode_lower(buffer) } -/// Converts a uuid to a string slice using a buffer -/// +/// Converts a uuid to a string slice using a buffer 
+/// /// This is more efficient that using the `to_string` on the created uuid /// as it avoids formatting and potential double buffering #[inline] pub fn uuid_str_from_buf(id: u128, buffer: &mut [u8]) -> &str { - assert_eq!(buffer.len(), 36, "length of hyphenated buffer should be 36 characters long"); + assert_eq!( + buffer.len(), + 36, + "length of hyphenated buffer should be 36 characters long" + ); let uuid = uuid::Uuid::from_u128(id); uuid.as_hyphenated().encode_lower(buffer) } diff --git a/helix-db/src/utils/items.rs b/helix-db/src/utils/items.rs index 594b7d2b4..29fed1da7 100644 --- a/helix-db/src/utils/items.rs +++ b/helix-db/src/utils/items.rs @@ -47,11 +47,13 @@ impl<'arena> serde::Serialize for Node<'arena> { if serializer.is_human_readable() { // Include id for JSON serialization let mut buffer = [0u8; 36]; - let mut state = serializer.serialize_map(Some(3 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0)))?; + let mut state = serializer.serialize_map(Some( + 3 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0), + ))?; state.serialize_entry("id", uuid_str_from_buf(self.id, &mut buffer))?; - state.serialize_entry("label", self.label)?; + state.serialize_entry("label", self.label)?; state.serialize_entry("version", &self.version)?; - if let Some(properties ) = &self.properties { + if let Some(properties) = &self.properties { for (key, value) in properties.iter() { state.serialize_entry(key, value)?; } @@ -177,7 +179,9 @@ impl<'arena> serde::Serialize for Edge<'arena> { if serializer.is_human_readable() { // Include id for JSON serialization let mut buffer = [0u8; 36]; - let mut state = serializer.serialize_map(Some(5 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0)))?; + let mut state = serializer.serialize_map(Some( + 5 + self.properties.as_ref().map(|p| p.len()).unwrap_or(0), + ))?; state.serialize_entry("id", uuid_str_from_buf(self.id, &mut buffer))?; state.serialize_entry("label", self.label)?; state.serialize_entry("version", &self.version)?; diff --git a/helix-db/src/utils/label_hash.rs b/helix-db/src/utils/label_hash.rs index bacc79820..d93a080d8 100644 --- a/helix-db/src/utils/label_hash.rs +++ b/helix-db/src/utils/label_hash.rs @@ -31,7 +31,10 @@ mod tests { let hash_person = hash_label("person", None); let hash_company = hash_label("company", None); - assert_ne!(hash_person, hash_company, "Different labels should produce different hashes"); + assert_ne!( + hash_person, hash_company, + "Different labels should produce different hashes" + ); } #[test] @@ -42,22 +45,24 @@ mod tests { let hash_seed_42 = hash_label(label, Some(42)); // Same label with no seed vs seed 0 should be same - assert_eq!(hash_no_seed, hash_seed_0, "No seed should be equivalent to seed 0"); + assert_eq!( + hash_no_seed, hash_seed_0, + "No seed should be equivalent to seed 0" + ); // Different seed should produce different hash - assert_ne!(hash_no_seed, hash_seed_42, "Different seeds should produce different hashes"); + assert_ne!( + hash_no_seed, hash_seed_42, + "Different seeds should produce different hashes" + ); } #[test] fn test_hash_label_collision_rate() { // Test collision rate with 10,000 labels - let labels: Vec = (0..10_000) - .map(|i| format!("label_{}", i)) - .collect(); + let labels: Vec = (0..10_000).map(|i| format!("label_{}", i)).collect(); - let hashes: HashSet<[u8; 4]> = labels.iter() - .map(|l| hash_label(l, None)) - .collect(); + let hashes: HashSet<[u8; 4]> = labels.iter().map(|l| hash_label(l, None)).collect(); let collision_rate = 1.0 - (hashes.len() as 
f64 / labels.len() as f64); @@ -86,16 +91,14 @@ mod tests { // Test with UTF-8 characters let labels = vec![ "person", - "人", // Chinese character - "🚀", // Emoji - "Ñoño", // Spanish with tildes - "Привет", // Russian - "مرحبا", // Arabic + "人", // Chinese character + "🚀", // Emoji + "Ñoño", // Spanish with tildes + "Привет", // Russian + "مرحبا", // Arabic ]; - let hashes: Vec<[u8; 4]> = labels.iter() - .map(|l| hash_label(l, None)) - .collect(); + let hashes: Vec<[u8; 4]> = labels.iter().map(|l| hash_label(l, None)).collect(); // All should be different let unique_hashes: HashSet<_> = hashes.iter().collect(); @@ -149,15 +152,9 @@ mod tests { #[test] fn test_hash_label_similar_strings() { // Test labels that differ by only one character - let labels = ["person", - "persons", - "person1", - "person_", - "Person"]; + let labels = ["person", "persons", "person1", "person_", "Person"]; - let hashes: Vec<[u8; 4]> = labels.iter() - .map(|l| hash_label(l, None)) - .collect(); + let hashes: Vec<[u8; 4]> = labels.iter().map(|l| hash_label(l, None)).collect(); // All should be different let unique_hashes: HashSet<_> = hashes.iter().collect(); @@ -177,7 +174,10 @@ mod tests { // Should be big-endian bytes (we can convert back) let value = u32::from_be_bytes(hash); - assert!(value > 0, "Hash value should be non-zero for non-empty string"); + assert!( + value > 0, + "Hash value should be non-zero for non-empty string" + ); } #[test] @@ -216,9 +216,7 @@ mod tests { "created_by", ]; - let hashes: HashSet<[u8; 4]> = common_labels.iter() - .map(|l| hash_label(l, None)) - .collect(); + let hashes: HashSet<[u8; 4]> = common_labels.iter().map(|l| hash_label(l, None)).collect(); // All common labels should hash uniquely assert_eq!( diff --git a/helix-db/src/utils/tqdm.rs b/helix-db/src/utils/tqdm.rs index 8067309d7..19954a990 100644 --- a/helix-db/src/utils/tqdm.rs +++ b/helix-db/src/utils/tqdm.rs @@ -1,6 +1,6 @@ use std::{ - io::{stdout, Write}, fmt, + io::{Write, stdout}, }; pub enum ProgChar { @@ -12,7 +12,8 @@ impl fmt::Display for ProgChar { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let c = match self { ProgChar::Block => '█', - ProgChar::Hash => '#', }; + ProgChar::Hash => '#', + }; write!(f, "{c}") } } diff --git a/helix-macros/src/lib.rs b/helix-macros/src/lib.rs index 095ee3ad7..3342cc4ad 100644 --- a/helix-macros/src/lib.rs +++ b/helix-macros/src/lib.rs @@ -5,7 +5,10 @@ extern crate syn; use proc_macro::TokenStream; use quote::quote; use syn::{ - parse::{Parse, ParseStream}, parse_macro_input, Data, DeriveInput, Expr, FnArg, Ident, ItemFn, ItemStruct, ItemTrait, LitInt, Pat, Stmt, Token, TraitItem + Data, DeriveInput, Expr, FnArg, Ident, ItemFn, ItemStruct, ItemTrait, LitInt, Pat, Stmt, Token, + TraitItem, + parse::{Parse, ParseStream}, + parse_macro_input, }; #[proc_macro_attribute] @@ -98,7 +101,6 @@ pub fn get_handler(_attr: TokenStream, item: TokenStream) -> TokenStream { expanded.into() } - #[proc_macro_attribute] pub fn tool_calls(_attr: TokenStream, input: TokenStream) -> TokenStream { let input_trait = parse_macro_input!(input as ItemTrait); @@ -372,20 +374,17 @@ pub fn traversable_derive(input: TokenStream) -> TokenStream { // Verify that the struct has an 'id' field let has_id_field = match &input.data { - Data::Struct(data) => { - data.fields.iter().any(|field| { - field.ident.as_ref().map(|i| i == "id").unwrap_or(false) - }) - } + Data::Struct(data) => data + .fields + .iter() + .any(|field| field.ident.as_ref().map(|i| i == "id").unwrap_or(false)), _ => false, }; if 
!has_id_field {
-        return TokenStream::from(
-            quote! {
-                compile_error!("Traversable can only be derived for structs with an 'id: &'a str' field");
-            }
-        );
+        return TokenStream::from(quote! {
+            compile_error!("Traversable can only be derived for structs with an 'id: &'a str' field");
+        });
     }
 
     // Extract lifetime parameter if present
diff --git a/hql-tests/src/main.rs b/hql-tests/src/main.rs
index 9e26618cb..adafa6ab1 100644
--- a/hql-tests/src/main.rs
+++ b/hql-tests/src/main.rs
@@ -66,6 +66,7 @@ async fn check_issue_exists(github_config: &GitHubConfig, error_hash: &str) -> R
 }
 
 #[allow(unused)]
+#[allow(clippy::too_many_arguments)]
 async fn create_github_issue(
     github_config: &GitHubConfig,
     error_type: &str,
@@ -454,7 +455,9 @@ async fn main() -> Result<()> {
             );
         }
 
-        println!("[SUCCESS] Finished processing batch {current_batch}/{total_batches} successfully");
+        println!(
+            "[SUCCESS] Finished processing batch {current_batch}/{total_batches} successfully"
+        );
     } else {
         // Process all test directories in parallel (default behavior)
         println!(
@@ -525,7 +528,6 @@ async fn process_test_directory(
         return Ok(());
     }
 
-
     // Find the query file - could be queries.hx or file*.hx
     let mut query_file_path = None;
     let schema_hx_path = folder_path.join("schema.hx");
@@ -643,8 +645,9 @@ async fn process_test_directory(
         let stderr = String::from_utf8_lossy(&output.stderr);
         let stdout = String::from_utf8_lossy(&output.stdout);
         // For helix compilation, we'll show the raw output since it's not cargo format
-        let error_message =
-            format!("[FAILED] HELIX COMPILE FAILED for {test_name}\nStderr: {stderr}\nStdout: {stdout}");
+        let error_message = format!(
+            "[FAILED] HELIX COMPILE FAILED for {test_name}\nStderr: {stderr}\nStdout: {stdout}"
+        );
 
         // Create GitHub issue if configuration is available
         if let Some(config) = github_config {
diff --git a/metrics/src/events.rs b/metrics/src/events.rs
index 27cd145d8..0d47203f1 100644
--- a/metrics/src/events.rs
+++ b/metrics/src/events.rs
@@ -248,4 +248,4 @@ pub struct InvalidApiKeyEvent {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub cluster_id: Option,
     pub time_taken_usec: u32,
-}
\ No newline at end of file
+}
diff --git a/metrics/src/lib.rs b/metrics/src/lib.rs
index 48dadd54f..067bd1a08 100644
--- a/metrics/src/lib.rs
+++ b/metrics/src/lib.rs
@@ -602,7 +602,6 @@ mod tests {
 
         // Channel should have fewer or equal batches
         let _final_count = METRICS_STATE.events_rx.len();
-
     }
 }
 
@@ -691,5 +690,4 @@ mod tests {
         assert!(json_str.starts_with('['));
         assert!(json_str.ends_with(']'));
     }
-
 }

From 6f596d42e66f99d5eb59b182fd19704ee59e3ec2 Mon Sep 17 00:00:00 2001
From: Diego Reis
Date: Sat, 22 Nov 2025 13:53:38 -0300
Subject: [PATCH 40/48] Do not build madvise dependency for windows

---
 helix-db/Cargo.toml                           |  4 +++-
 .../src/helix_engine/vector_core/reader.rs    | 16 ++++++++------
 .../src/helixc/generator/math_functions.rs    | 22 +++++++++----------
 .../src/helixc/generator/traversal_steps.rs   |  8 +++----
 hql-tests/tests/basic_search_v/queries.hx     |  4 ++--
 hql-tests/tests/benchmarks/queries.hx         | 10 ++++-----
 .../tests/brute_force_search_v/queries.hx     |  9 +-------
 hql-tests/tests/cognee/queries.hx             |  6 ++---
 hql-tests/tests/companies_graph/queries.hx    | 16 +++++++-------
 hql-tests/tests/companies_graph_v2/queries.hx | 16 +++++++-------
 .../queries.hx                                |  6 ++---
 hql-tests/tests/date_comparisons/queries.hx   |  6 ++---
 .../tests/dijkstra_custom_weights/schema.hx   |  4 ++--
 .../tests/edge_from_node_to_vec/queries.hx    |  5 ++---
 hql-tests/tests/graphiti/queries.hx           |  4 ++--
 hql-tests/tests/graphiti/schema.hx            |  8 +++----
 hql-tests/tests/knowledge_graphs/queries.hx   | 16 +++++++-------
 hql-tests/tests/knowledge_graphs/schema.hx    |  2 +-
 hql-tests/tests/model_macro/schema.hx         |  4 ++--
 .../tests/multi_type_index_test/queries.hx    | 12 +++++-----
 hql-tests/tests/nested_for_loops/queries.hx   |  8 +++----
 hql-tests/tests/putts_professor/queries.hx    | 12 +++++-----
 hql-tests/tests/rerankers/queries.hx          | 16 +++++++-------
 .../search_v_as_assignment_and_expr/file8.hx  |  2 +-
 hql-tests/tests/series/queries.hx             |  6 ++---
 .../tests/update_drop_then_add/file52.hx      |  4 ++--
 26 files changed, 111 insertions(+), 115 deletions(-)

diff --git a/helix-db/Cargo.toml b/helix-db/Cargo.toml
index 37907ae11..37f019022 100644
--- a/helix-db/Cargo.toml
+++ b/helix-db/Cargo.toml
@@ -61,10 +61,12 @@ tinyvec = "1.10.0"
 papaya = "0.2.3"
 hashbrown = "0.16.0"
 min-max-heap = "1.3.0"
-madvise = "0.1.0"
 page_size = "0.6.0"
 rustc-hash = "2.1.1"
 
+[target.'cfg(not(windows))'.dependencies]
+madvise = "0.1.0"
+
 [dev-dependencies]
 rand = "0.9.0"
 lazy_static = "1.4.0"
diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs
index 85466fccb..3cb736576 100644
--- a/helix-db/src/helix_engine/vector_core/reader.rs
+++ b/helix-db/src/helix_engine/vector_core/reader.rs
@@ -6,11 +6,9 @@ use std::num::NonZeroUsize;
 use bumpalo::collections::CollectIn;
 use hashbrown::HashMap;
 use heed3::RoTxn;
-use heed3::types::Bytes;
 use heed3::types::DecodeIgnore;
 use min_max_heap::MinMaxHeap;
 use roaring::RoaringBitmap;
-use tracing::warn;
 
 use crate::helix_engine::vector_core::VectorCoreResult;
 use crate::helix_engine::vector_core::VectorError;
@@ -19,16 +17,20 @@ use crate::helix_engine::vector_core::distance::DistanceValue;
 use crate::helix_engine::vector_core::hnsw::ScoredLink;
 use crate::helix_engine::vector_core::item_iter::ItemIter;
 use crate::helix_engine::vector_core::key::{Key, KeyCodec, Prefix, PrefixCodec};
-#[cfg(not(windows))]
-use crate::helix_engine::vector_core::metadata::Metadata;
-use crate::helix_engine::vector_core::metadata::MetadataCodec;
+use crate::helix_engine::vector_core::metadata::{Metadata, MetadataCodec};
 use crate::helix_engine::vector_core::node::Node;
 use crate::helix_engine::vector_core::node::{Item, Links};
 use crate::helix_engine::vector_core::ordered_float::OrderedFloat;
-use crate::helix_engine::vector_core::unaligned_vector::{UnalignedVector, VectorCodec};
+use crate::helix_engine::vector_core::unaligned_vector::UnalignedVector;
 use crate::helix_engine::vector_core::version::{Version, VersionCodec};
 use crate::helix_engine::vector_core::{CoreDatabase, ItemId};
 
+#[cfg(not(windows))]
+use {
+    crate::helix_engine::vector_core::unaligned_vector::VectorCodec, heed3::types::Bytes,
+    tracing::warn,
+};
+
 /// A good default value for the `ef` parameter.
 const DEFAULT_EF_SEARCH: usize = 100;
 
@@ -320,7 +322,7 @@ impl Reader {
         _database: &CoreDatabase,
         _index: u16,
         _metadata: &Metadata,
-    ) -> Result<()> {
+    ) -> VectorCoreResult<()> {
         // madvise crate does not support windows.
         Ok(())
     }
diff --git a/helix-db/src/helixc/generator/math_functions.rs b/helix-db/src/helixc/generator/math_functions.rs
index 3aade7204..4aa72d7c5 100644
--- a/helix-db/src/helixc/generator/math_functions.rs
+++ b/helix-db/src/helixc/generator/math_functions.rs
@@ -68,28 +68,28 @@ impl Display for PropertyAccess {
             PropertyContext::Edge => {
                 write!(
                     f,
-                    "(edge.get_property({}).ok_or(GraphError::Default)?.as_f32())",
+                    "(edge.get_property({}).ok_or(GraphError::Default)?.as_f64())",
                     self.property
                 )
             }
             PropertyContext::SourceNode => {
                 write!(
                     f,
-                    "(src_node.get_property({}).ok_or(GraphError::Default)?.as_f32())",
+                    "(src_node.get_property({}).ok_or(GraphError::Default)?.as_f64())",
                     self.property
                 )
             }
             PropertyContext::TargetNode => {
                 write!(
                     f,
-                    "(dst_node.get_property({}).ok_or(GraphError::Default)?.as_f32())",
+                    "(dst_node.get_property({}).ok_or(GraphError::Default)?.as_f64())",
                     self.property
                 )
             }
             PropertyContext::Current => {
                 write!(
                     f,
-                    "(v.get_property({}).ok_or(GraphError::Default)?.as_f32())",
+                    "(v.get_property({}).ok_or(GraphError::Default)?.as_f64())",
                     self.property
                 )
             }
@@ -469,7 +469,7 @@ mod tests {
         };
         assert_eq!(
             edge_prop.to_string(),
-            "(edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32())"
+            "(edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64())"
         );
 
         // Test SourceNode context
@@ -479,7 +479,7 @@ mod tests {
         };
         assert_eq!(
             src_prop.to_string(),
-            "(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f32())"
+            "(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f64())"
         );
 
         // Test TargetNode context
@@ -489,14 +489,14 @@ mod tests {
         };
         assert_eq!(
             dst_prop.to_string(),
-            "(dst_node.get_property(\"popularity\").ok_or(GraphError::Default)?.as_f32())"
+            "(dst_node.get_property(\"popularity\").ok_or(GraphError::Default)?.as_f64())"
         );
     }
 
     #[test]
     fn test_complex_weight_expression() {
         // Test: MUL(_::{distance}, POW(0.95, DIV(_::{days}, 30)))
-        // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f32()) * (0.95_f32).powf(((edge.get_property("days").ok_or(GraphError::Default)?.as_f32()) / 30_f32)))
+        // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f64()) * (0.95_f32).powf(((edge.get_property("days").ok_or(GraphError::Default)?.as_f64()) / 30_f32)))
         let expr = MathFunctionCallGen {
             function: MathFunction::Mul,
             args: vec![
@@ -525,14 +525,14 @@ mod tests {
 
         assert_eq!(
             expr.to_string(),
-            "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32()) * (0.95_f32).powf(((edge.get_property(\"days\").ok_or(GraphError::Default)?.as_f32()) / 30_f32)))"
+            "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64()) * (0.95_f32).powf(((edge.get_property(\"days\").ok_or(GraphError::Default)?.as_f64()) / 30_f32)))"
        );
     }
 
     #[test]
     fn test_multi_context_expression() {
         // Test: MUL(_::{distance}, _::From::{traffic_factor})
-        // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f32()) * (src_node.get_property("traffic_factor").ok_or(GraphError::Default)?.as_f32()))
+        // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f64()) * (src_node.get_property("traffic_factor").ok_or(GraphError::Default)?.as_f64()))
         let expr = MathFunctionCallGen {
             function: MathFunction::Mul,
             args: vec![
@@ -549,7 +549,7 @@ mod tests {
 
         assert_eq!(
             expr.to_string(),
-            "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32()) * 
(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f32()))" + "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64()) * (src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f64()))" ); } } diff --git a/helix-db/src/helixc/generator/traversal_steps.rs b/helix-db/src/helixc/generator/traversal_steps.rs index 356c3f513..92a55c7c0 100644 --- a/helix-db/src/helixc/generator/traversal_steps.rs +++ b/helix-db/src/helixc/generator/traversal_steps.rs @@ -724,14 +724,14 @@ impl Display for ShortestPathDijkstras { WeightCalculation::Property(prop) => { write!( f, - "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f32()?) }}", + "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f64()?) }}", prop )?; } WeightCalculation::Expression(expr) => { write!( f, - "|edge, src_node, dst_node| -> Result {{ Ok({}) }}", + "|edge, src_node, dst_node| -> Result {{ Ok({}) }}", expr )?; } @@ -786,14 +786,14 @@ impl Display for ShortestPathAStar { WeightCalculation::Property(prop) => { write!( f, - "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f32()?) }}, ", + "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f64()?) }}, ", prop )?; } WeightCalculation::Expression(expr) => { write!( f, - "|edge, src_node, dst_node| -> Result {{ Ok({}) }}, ", + "|edge, src_node, dst_node| -> Result {{ Ok({}) }}, ", expr )?; } diff --git a/hql-tests/tests/basic_search_v/queries.hx b/hql-tests/tests/basic_search_v/queries.hx index c2e853873..65e3dac6d 100644 --- a/hql-tests/tests/basic_search_v/queries.hx +++ b/hql-tests/tests/basic_search_v/queries.hx @@ -13,7 +13,7 @@ E::EdgeUser { } -QUERY user(vec: [F64]) => +QUERY user(vec: [F32]) => vecs <- SearchV(vec, 10) // pre_filter <- SearchV(vec, 10)::PREFILTER(_::{content}::EQ("hello")) RETURN "hello" @@ -32,4 +32,4 @@ V::Document { QUERY SearchText(query: String, limit: I64) => // Search for documents that are similar to the query results <- SearchV(Embed(query), limit) - RETURN results \ No newline at end of file + RETURN results diff --git a/hql-tests/tests/benchmarks/queries.hx b/hql-tests/tests/benchmarks/queries.hx index b0188dc5b..2ecb036ff 100644 --- a/hql-tests/tests/benchmarks/queries.hx +++ b/hql-tests/tests/benchmarks/queries.hx @@ -33,8 +33,8 @@ QUERY InsertUser(country: U8) => // Query 2: Insert an Item vector node // Creates a new Item record with embedding and category -// The embedding parameter is explicit as an array of F64 values -QUERY InsertItem(embedding: [F64], category: U16) => +// The embedding parameter is explicit as an array of F32 values +QUERY InsertItem(embedding: [F32], category: U16) => item <- AddV(embedding, { category: category }) @@ -87,10 +87,10 @@ QUERY OneHopFilter(user_id: ID, category: U16) => items <- N(user_id)::Out::WHERE(_::{category}::EQ(category)) RETURN items::{id, category} -QUERY Vector(vector: [F64], top_k: I64) => +QUERY Vector(vector: [F32], top_k: I64) => items <- SearchV(vector, top_k) RETURN items::{id, score, category} -QUERY VectorHopFilter(vector: [F64], top_k: I64, country: U8) => +QUERY VectorHopFilter(vector: [F32], top_k: I64, country: U8) => items <- SearchV(vector, top_k)::WHERE(EXISTS(_::In::WHERE(_::{country}::EQ(country)))) - RETURN items::{id, category} \ No newline at end of file + RETURN items::{id, category} diff --git a/hql-tests/tests/brute_force_search_v/queries.hx b/hql-tests/tests/brute_force_search_v/queries.hx index a31462abc..fa192444c 100644 --- 
a/hql-tests/tests/brute_force_search_v/queries.hx +++ b/hql-tests/tests/brute_force_search_v/queries.hx @@ -17,13 +17,6 @@ E::Friend { To: User, } -QUERY search_vector(query: [f64], k: I64) => +QUERY search_vector(query: [f32], k: I64) => result <- N::Out::SearchV(query, k) RETURN result - - - - - - - diff --git a/hql-tests/tests/cognee/queries.hx b/hql-tests/tests/cognee/queries.hx index 026e4aeb2..518630eba 100644 --- a/hql-tests/tests/cognee/queries.hx +++ b/hql-tests/tests/cognee/queries.hx @@ -8,7 +8,7 @@ QUERY CogneeHasCollection (collection_name: String) => RETURN {collection: collection} // Add multiple vectors to a collection with a given data points -QUERY CogneeCreateDataPoints (collection_name: String, data_points: [{vector: [F64], dp_id: String, payload: String, content: String}]) => +QUERY CogneeCreateDataPoints (collection_name: String, data_points: [{vector: [F32], dp_id: String, payload: String, content: String}]) => FOR {vector, dp_id, payload, content} IN data_points { AddV(vector, {collection_name: collection_name, data_point_id: dp_id, payload: payload, content: content}) } @@ -20,7 +20,7 @@ QUERY CogneeRetrieve (collection_name: String, dp_ids: [String]) => RETURN {documents: documents} // Perform a search in the specified collection using a vector. -QUERY CogneeSearch (collection_name: String, vector: [F64], limit: I64) => +QUERY CogneeSearch (collection_name: String, vector: [F32], limit: I64) => result <- SearchV(vector, limit)::WHERE(_::{collection_name}::EQ(collection_name)) RETURN {result: result} @@ -133,7 +133,7 @@ QUERY CogneeDeleteGraph () => // Get the target node and its entire neighborhood QUERY CogneeGetConnections (node_id: String) => main_node <- N({node_id: node_id}) - + in_nodes <- main_node::In in_edges <- main_node::InE diff --git a/hql-tests/tests/companies_graph/queries.hx b/hql-tests/tests/companies_graph/queries.hx index 8d251432c..a8b7cbdc7 100644 --- a/hql-tests/tests/companies_graph/queries.hx +++ b/hql-tests/tests/companies_graph/queries.hx @@ -9,7 +9,7 @@ QUERY GetCompany(company_number: String) => QUERY AddCompany(company_number: String, total_filings: I32) => company <- AddN({ - company_number: company_number, + company_number: company_number, total_filings: total_filings, ingested_filings: 0 }) @@ -29,7 +29,7 @@ QUERY DeleteCompany(company_number: String) => // ------------------------------ EDGE OPERATIONS -------------------------- -QUERY GetDocumentEdges(company_number: String) => +QUERY GetDocumentEdges(company_number: String) => c <- N({company_number: company_number}) edges <- c::OutE count <- c::Out::COUNT @@ -41,9 +41,9 @@ QUERY GetDocumentEdges(company_number: String) => // ─── filing / embedding helpers ─────────────────────────────── QUERY AddEmbeddingsToCompany( - company_number: String, + company_number: String, embeddings_data: [{ - vector: [F64], + vector: [F32], text: String, chunk_id: String, page_number: I32, @@ -90,18 +90,18 @@ QUERY GetAllCompanyEmbeddings(company_number: String) => // return vector data RETURN embeddings -QUERY CompanyEmbeddingSearch(company_number: String, query: [F64], k: I32) => +QUERY CompanyEmbeddingSearch(company_number: String, query: [F32], k: I32) => c <- N({company_number: company_number})::OutE::ToV embedding_search <- c::SearchV(query, k) RETURN embedding_search // ---------------------- FOR TESTING --------------------------------- // tmp function for testing helix -QUERY AddVector(vector: [F64], text: String, chunk_id: String, page_number: I32, reference: String) => +QUERY 
AddVector(vector: [F32], text: String, chunk_id: String, page_number: I32, reference: String) => embedding <- AddV(vector, {text: text, chunk_id: chunk_id, page_number: page_number, reference: reference}) RETURN embedding // tmp function for testing helix -QUERY SearchVector(query: [F64], k: I32) => +QUERY SearchVector(query: [F32], k: I32) => embedding_search <- SearchV(query, k) - RETURN embedding_search \ No newline at end of file + RETURN embedding_search diff --git a/hql-tests/tests/companies_graph_v2/queries.hx b/hql-tests/tests/companies_graph_v2/queries.hx index b5c17c1f3..5e11f557c 100644 --- a/hql-tests/tests/companies_graph_v2/queries.hx +++ b/hql-tests/tests/companies_graph_v2/queries.hx @@ -10,7 +10,7 @@ QUERY GetCompany(company_number: String) => QUERY CreateCompany(company_name: String, company_number: String, total_docs: I32) => company <- AddN({ company_name: company_name, - company_number: company_number, + company_number: company_number, total_docs: total_docs, ingested_docs: 0 }) @@ -30,7 +30,7 @@ QUERY DeleteCompany(company_number: String) => // ------------------------------ EDGE OPERATIONS -------------------------- -QUERY GetDocumentEdges(company_number: String) => +QUERY GetDocumentEdges(company_number: String) => c <- N({company_number: company_number}) edges <- c::OutE RETURN edges @@ -39,9 +39,9 @@ QUERY GetDocumentEdges(company_number: String) => // ─── filing / embedding helpers ─────────────────────────────── QUERY AddEmbeddingsToCompany( - company_number: String, + company_number: String, embeddings_data: [{ - vector: [F64], + vector: [F32], text: String, chunk_id: String, page_number: I32, @@ -85,19 +85,19 @@ QUERY GetAllCompanyEmbeddings(company_number: String) => embeddings <- c::Out RETURN embeddings -QUERY CompanyEmbeddingSearch(company_number: String, query: [F64], k: I32) => +QUERY CompanyEmbeddingSearch(company_number: String, query: [F32], k: I32) => c <- N({company_number: company_number})::OutE::ToV embedding_search <- c::SearchV(query, k) RETURN embedding_search // ---------------------- FOR TESTING --------------------------------- // tmp function for testing helix -QUERY AddVector(vector: [F64], text: String, chunk_id: String, page_number: I32, reference: String) => +QUERY AddVector(vector: [F32], text: String, chunk_id: String, page_number: I32, reference: String) => embedding <- AddV(vector, {text: text, chunk_id: chunk_id, page_number: page_number, reference: reference}) RETURN embedding // tmp function for testing helix -QUERY SearchVector(query: [F64], k: I32) => +QUERY SearchVector(query: [F32], k: I32) => embedding_search <- SearchV(query, k) RETURN embedding_search @@ -122,4 +122,4 @@ QUERY GetVectorsBySourceLinkAndPageRange(company_number: String, source_link: St _::{source_link}::EQ(source_link) ) ) - RETURN vectors \ No newline at end of file + RETURN vectors diff --git a/hql-tests/tests/complete_vector_addition_and_search/queries.hx b/hql-tests/tests/complete_vector_addition_and_search/queries.hx index 8d24a5d28..816b92a20 100644 --- a/hql-tests/tests/complete_vector_addition_and_search/queries.hx +++ b/hql-tests/tests/complete_vector_addition_and_search/queries.hx @@ -1,16 +1,16 @@ -QUERY addEmbedding(vec: [F64]) => +QUERY addEmbedding(vec: [F32]) => doc <- AddN({content: "Hello, content!", number: 1}) embedding <- AddV(vec, {chunk: "Hello, chunk!", chunk_id: 1, number: 1, reference: "Hello, reference!"}) AddE::From(doc)::To(embedding) RETURN embedding -QUERY getAllEmbedding() => +QUERY getAllEmbedding() => c <- N({number: 1}) 
embeddings <- c::Out RETURN embeddings -QUERY searchEmbedding(query: [F64]) => +QUERY searchEmbedding(query: [F32]) => c <- N({number: 1}) embedding_search <- SearchV(query, 10) RETURN embedding_search::{ diff --git a/hql-tests/tests/date_comparisons/queries.hx b/hql-tests/tests/date_comparisons/queries.hx index af5a67933..e42be9ea2 100644 --- a/hql-tests/tests/date_comparisons/queries.hx +++ b/hql-tests/tests/date_comparisons/queries.hx @@ -1,8 +1,8 @@ -QUERY SearchRecentDocuments (vector: [F64], limit: I64, cutoff_date: Date) => +QUERY SearchRecentDocuments (vector: [F32], limit: I64, cutoff_date: Date) => documents <- SearchV(vector, limit)::WHERE(_::{created_at}::GTE(cutoff_date)) RETURN documents -QUERY InsertVector (vector: [F64], content: String, created_at: Date) => +QUERY InsertVector (vector: [F32], content: String, created_at: Date) => document <- AddV(vector, { content: content, created_at: created_at }) doc <- document::{content, created_at} - RETURN document \ No newline at end of file + RETURN document diff --git a/hql-tests/tests/dijkstra_custom_weights/schema.hx b/hql-tests/tests/dijkstra_custom_weights/schema.hx index 28ec1eb83..878cb49c1 100644 --- a/hql-tests/tests/dijkstra_custom_weights/schema.hx +++ b/hql-tests/tests/dijkstra_custom_weights/schema.hx @@ -1,7 +1,7 @@ N::Location { name: String, - traffic_factor: F64, - popularity: F64 + traffic_factor: F32, + popularity: F32 } E::Route { diff --git a/hql-tests/tests/edge_from_node_to_vec/queries.hx b/hql-tests/tests/edge_from_node_to_vec/queries.hx index 94dac69b6..dccde6e98 100644 --- a/hql-tests/tests/edge_from_node_to_vec/queries.hx +++ b/hql-tests/tests/edge_from_node_to_vec/queries.hx @@ -14,7 +14,7 @@ E::EmbeddingOf { } } -QUERY add(vec: [F64]) => +QUERY add(vec: [F32]) => user <- AddN({ name: "John Doe" }) @@ -24,11 +24,10 @@ QUERY add(vec: [F64]) => AddE({category: "test"})::From(user)::To(embedding) RETURN user -QUERY to_v(query: [F64], k: I32, data: String) => +QUERY to_v(query: [F32], k: I32, data: String) => user <- N({name: "John Doe"}) edges <- user::OutE filtered <- edges::WHERE(_::{category}::EQ(data)) vectors <- filtered::ToV searched <- vectors::SearchV(query, k) RETURN user, edges, filtered, vectors, searched - diff --git a/hql-tests/tests/graphiti/queries.hx b/hql-tests/tests/graphiti/queries.hx index 7fb8513d8..378a27e86 100644 --- a/hql-tests/tests/graphiti/queries.hx +++ b/hql-tests/tests/graphiti/queries.hx @@ -2,13 +2,13 @@ // Entity // ######################################################### -QUERY createEntity (name: String, name_embedding: [F64], group_id: String, summary: String, created_at: Date, labels: [String], attributes: String) => +QUERY createEntity (name: String, name_embedding: [F32], group_id: String, summary: String, created_at: Date, labels: [String], attributes: String) => entity <- AddN({name: name, group_id: group_id, summary: summary, created_at: created_at, labels: labels, attributes: attributes}) embedding <- AddV(name_embedding, {name_embedding: name_embedding}) edge <- AddE({group_id: group_id})::From(entity)::To(embedding) RETURN entity -QUERY updateEntity (entity_id: ID, name: String, name_embedding: [F64], group_id: String, summary: String, created_at: Date, labels: [String], attributes: String) => +QUERY updateEntity (entity_id: ID, name: String, name_embedding: [F32], group_id: String, summary: String, created_at: Date, labels: [String], attributes: String) => entity <- N(entity_id)::UPDATE({name: name, group_id: group_id, summary: summary, created_at: 
created_at, labels: labels, attributes: attributes}) DROP N(entity_id)::Out embedding <- AddV(name_embedding, {name_embedding: name_embedding}) diff --git a/hql-tests/tests/graphiti/schema.hx b/hql-tests/tests/graphiti/schema.hx index 3cfb1c54f..d523ef9ab 100644 --- a/hql-tests/tests/graphiti/schema.hx +++ b/hql-tests/tests/graphiti/schema.hx @@ -20,7 +20,7 @@ E::Entity_to_Embedding { } V::Entity_Embedding { - name_embedding: [F64], + name_embedding: [F32], } E::Entity_Fact { @@ -60,7 +60,7 @@ E::Fact_to_Embedding { } V::Fact_Embedding { - fact: [F64], + fact: [F32], } E::Fact_Entity { @@ -118,7 +118,7 @@ E::Community_to_Embedding { } V::Community_Embedding { - name_embedding: [F64], + name_embedding: [F32], } E::Community_Entity { @@ -146,4 +146,4 @@ E::Community_Fact { group_id: String, created_at: Date DEFAULT NOW } -} \ No newline at end of file +} diff --git a/hql-tests/tests/knowledge_graphs/queries.hx b/hql-tests/tests/knowledge_graphs/queries.hx index 3ec2d9c55..471eea31d 100644 --- a/hql-tests/tests/knowledge_graphs/queries.hx +++ b/hql-tests/tests/knowledge_graphs/queries.hx @@ -138,7 +138,7 @@ QUERY insert_event_Cluster1 ( uuid: String, chunk_uuid: String, statement: String, - embedding: [F64], + embedding: [F32], triplets: [String], statement_type: String, temporal_type: String, @@ -189,7 +189,7 @@ QUERY update_event_Cluster1 ( uuid: String, chunk_uuid: String, statement: String, - embedding: [F64], + embedding: [F32], triplets: [String], statement_type: String, temporal_type: String, @@ -414,7 +414,7 @@ QUERY remove_entity_Cluster1 ( // ######################################################### QUERY vector_search_events_Cluster1 ( - query_embedding: [F64], + query_embedding: [F32], k: I32 ) => matching_embeddings <- SearchV(query_embedding, k) @@ -440,7 +440,7 @@ QUERY get_stories_mentioning_entity_as_subject_Cluster1 ( stories <- chunks::In RETURN stories, chunks, events, triplets -// Find stories that mention a specific entity (as object) +// Find stories that mention a specific entity (as object) QUERY get_stories_mentioning_entity_as_object_Cluster1 ( entity_uuid: String ) => @@ -613,7 +613,7 @@ QUERY get_sub_comments_by_parent_uuid_Cluster2 ( // Story Embedding operations QUERY add_story_embedding_Cluster2 ( story_uuid: String, - embedding: [F64], + embedding: [F32], content: String ) => story <- N({uuid: story_uuid}) @@ -624,7 +624,7 @@ QUERY add_story_embedding_Cluster2 ( // Comment Embedding operations QUERY add_comment_embedding_Cluster2 ( comment_uuid: String, - embedding: [F64], + embedding: [F32], content: String ) => comment <- N({uuid: comment_uuid}) @@ -633,7 +633,7 @@ QUERY add_comment_embedding_Cluster2 ( RETURN comment QUERY search_similar_stories_Cluster2 ( - query_embedding: [F64], + query_embedding: [F32], k: I64 ) => matching_embeddings <- SearchV(query_embedding, k) @@ -700,4 +700,4 @@ QUERY drop_all_Cluster2 () => DROP N DROP N::Out DROP N - RETURN "Success" \ No newline at end of file + RETURN "Success" diff --git a/hql-tests/tests/knowledge_graphs/schema.hx b/hql-tests/tests/knowledge_graphs/schema.hx index b4e5836f4..ba523aee7 100644 --- a/hql-tests/tests/knowledge_graphs/schema.hx +++ b/hql-tests/tests/knowledge_graphs/schema.hx @@ -77,7 +77,7 @@ E::Event_to_Embedding_Cluster1 { } V::EventEmbedding_Cluster1 { - embedding: [F64] + embedding: [F32] } diff --git a/hql-tests/tests/model_macro/schema.hx b/hql-tests/tests/model_macro/schema.hx index f20b9d604..3d97d6ae8 100644 --- a/hql-tests/tests/model_macro/schema.hx +++ 
b/hql-tests/tests/model_macro/schema.hx @@ -1,6 +1,6 @@ schema::1 { V::ClinicalNote { - vector: [F64], + vector: [F32], text: String, } -} \ No newline at end of file +} diff --git a/hql-tests/tests/multi_type_index_test/queries.hx b/hql-tests/tests/multi_type_index_test/queries.hx index 35b1e42ee..314bd8fe9 100644 --- a/hql-tests/tests/multi_type_index_test/queries.hx +++ b/hql-tests/tests/multi_type_index_test/queries.hx @@ -9,11 +9,11 @@ QUERY testString(value: String) => QUERY testI8(value: I8) => node <- N({i8_field: value}) RETURN node - + QUERY testI32(value: I32) => node <- N({i32_field: value}) RETURN node - + QUERY testI64(value: I64) => node <- N({i64_field: value}) RETURN node @@ -22,11 +22,11 @@ QUERY testI64(value: I64) => QUERY testU8(value: U8) => node <- N({u8_field: value}) RETURN node - + QUERY testU32(value: U32) => node <- N({u32_field: value}) RETURN node - + QUERY testU64(value: U64) => node <- N({u64_field: value}) RETURN node @@ -35,7 +35,7 @@ QUERY testU64(value: U64) => QUERY testF32(value: F32) => node <- N({f32_field: value}) RETURN node - + QUERY testF64(value: F64) => node <- N({f64_field: value}) RETURN node @@ -62,4 +62,4 @@ QUERY testMultipleConditions(name: String, age: U32, active: Boolean) => nodes_by_name <- N({str_field: name}) nodes_by_age <- N({u32_field: age}) nodes_by_active <- N({bool_field: active}) - RETURN nodes_by_name, nodes_by_age, nodes_by_active \ No newline at end of file + RETURN nodes_by_name, nodes_by_age, nodes_by_active diff --git a/hql-tests/tests/nested_for_loops/queries.hx b/hql-tests/tests/nested_for_loops/queries.hx index 3e0e29013..9234c4068 100644 --- a/hql-tests/tests/nested_for_loops/queries.hx +++ b/hql-tests/tests/nested_for_loops/queries.hx @@ -1,4 +1,4 @@ -QUERY loaddocs_rag(chapters: [{ id: I64, subchapters: [{ title: String, content: String, chunks: [{chunk: String, vector: [F64]}]}] }]) => +QUERY loaddocs_rag(chapters: [{ id: I64, subchapters: [{ title: String, content: String, chunks: [{chunk: String, vector: [F32]}]}] }]) => FOR {id, subchapters} IN chapters { chapter_node <- AddN({ chapter_index: id }) FOR {title, content, chunks} IN subchapters { @@ -12,11 +12,11 @@ QUERY loaddocs_rag(chapters: [{ id: I64, subchapters: [{ title: String, content: } RETURN "Success" -QUERY searchdocs_rag(query: [F64], k: I32) => +QUERY searchdocs_rag(query: [F32], k: I32) => vecs <- SearchV(query, k) subchapters <- vecs::In RETURN subchapters::{title, content} -QUERY edge_node(id: ID) => +QUERY edge_node(id: ID) => e <- N::OutE - RETURN e \ No newline at end of file + RETURN e diff --git a/hql-tests/tests/putts_professor/queries.hx b/hql-tests/tests/putts_professor/queries.hx index aae93f805..856577226 100644 --- a/hql-tests/tests/putts_professor/queries.hx +++ b/hql-tests/tests/putts_professor/queries.hx @@ -45,7 +45,7 @@ QUERY link_professor_to_lab(professor_id: ID, lab_id: ID) => lab <- N(lab_id) edge <- AddE::From(professor)::To(lab) RETURN edge - + // Link Professor to Research Area QUERY link_professor_to_research_area(professor_id: ID, research_area_id: ID) => professor <- N(professor_id) @@ -54,7 +54,7 @@ QUERY link_professor_to_research_area(professor_id: ID, research_area_id: ID) => RETURN edge // Search Similar Professors based on Research Area + Description Embedding -QUERY search_similar_professors_by_research_area_and_description(query_vector: [F64], k: I64) => +QUERY search_similar_professors_by_research_area_and_description(query_vector: [F32], k: I64) => vecs <- SearchV(query_vector, k) professors <- vecs::In 
RETURN professors @@ -64,14 +64,14 @@ QUERY get_professor_research_areas_with_descriptions(professor_id: ID) => research_areas <- N(professor_id)::Out::{areas_and_descriptions: areas_and_descriptions} RETURN research_areas -QUERY create_research_area_embedding(professor_id: ID, areas_and_descriptions: String, vector: [F64]) => +QUERY create_research_area_embedding(professor_id: ID, areas_and_descriptions: String, vector: [F32]) => professor <- N(professor_id) research_area <- AddV(vector, { areas_and_descriptions: areas_and_descriptions }) edge <- AddE::From(professor)::To(research_area) RETURN research_area -// GET Queries // +// GET Queries // QUERY get_professors_by_university_name(university_name: String) => professors <- N::Out::WHERE(_::{name}::EQ(university_name)) @@ -80,7 +80,7 @@ QUERY get_professors_by_university_name(university_name: String) => QUERY get_professor_by_research_area_name(research_area_name: String) => professors <- N::Out::WHERE(_::{research_area}::EQ(research_area_name)) RETURN professors - + QUERY get_professors_by_department_name(department_name: String) => professors <- N::Out::WHERE(_::{name}::EQ(department_name)) - RETURN professors \ No newline at end of file + RETURN professors diff --git a/hql-tests/tests/rerankers/queries.hx b/hql-tests/tests/rerankers/queries.hx index a3caa8e73..152fb10be 100644 --- a/hql-tests/tests/rerankers/queries.hx +++ b/hql-tests/tests/rerankers/queries.hx @@ -9,42 +9,42 @@ N::Article { } // Test 1: RerankRRF with default k -QUERY testRRFDefault(query_vec: [F64]) => +QUERY testRRFDefault(query_vec: [F32]) => results <- SearchV(query_vec, 100) ::RerankRRF ::RANGE(0, 10) RETURN results // Test 2: RerankRRF with custom k parameter -QUERY testRRFCustomK(query_vec: [F64], k_val: F64) => +QUERY testRRFCustomK(query_vec: [F32], k_val: F64) => results <- SearchV(query_vec, 100) ::RerankRRF(k: k_val) ::RANGE(0, 10) RETURN results // Test 3: RerankMMR with default distance (cosine) -QUERY testMMRDefault(query_vec: [F64]) => +QUERY testMMRDefault(query_vec: [F32]) => results <- SearchV(query_vec, 100) ::RerankMMR(lambda: 0.7) ::RANGE(0, 10) RETURN results // Test 4: RerankMMR with euclidean distance -QUERY testMMREuclidean(query_vec: [F64]) => +QUERY testMMREuclidean(query_vec: [F32]) => results <- SearchV(query_vec, 100) ::RerankMMR(lambda: 0.5, distance: "euclidean") ::RANGE(0, 10) RETURN results // Test 5: RerankMMR with dot product distance -QUERY testMMRDotProduct(query_vec: [F64]) => +QUERY testMMRDotProduct(query_vec: [F32]) => results <- SearchV(query_vec, 100) ::RerankMMR(lambda: 0.6, distance: "dotproduct") ::RANGE(0, 10) RETURN results // Test 6: Chained rerankers (RRF then MMR) -QUERY testChainedRerankers(query_vec: [F64]) => +QUERY testChainedRerankers(query_vec: [F32]) => results <- SearchV(query_vec, 100) ::RerankRRF(k: 60) ::RerankMMR(lambda: 0.7) @@ -52,14 +52,14 @@ QUERY testChainedRerankers(query_vec: [F64]) => RETURN results // Test 7: MMR with variable lambda -QUERY testMMRVariableLambda(query_vec: [F64], lambda_val: F64) => +QUERY testMMRVariableLambda(query_vec: [F32], lambda_val: F64) => results <- SearchV(query_vec, 100) ::RerankMMR(lambda: lambda_val) ::RANGE(0, 10) RETURN results // Test 8: Multiple chained MMR rerankers -QUERY testMultipleMMR(query_vec: [F64]) => +QUERY testMultipleMMR(query_vec: [F32]) => results <- SearchV(query_vec, 100) ::RerankMMR(lambda: 0.9) ::RerankMMR(lambda: 0.5) diff --git a/hql-tests/tests/search_v_as_assignment_and_expr/file8.hx 
b/hql-tests/tests/search_v_as_assignment_and_expr/file8.hx
index 32321d2e4..be7c42cf2 100644
--- a/hql-tests/tests/search_v_as_assignment_and_expr/file8.hx
+++ b/hql-tests/tests/search_v_as_assignment_and_expr/file8.hx
@@ -13,7 +13,7 @@ E::EdgeFile8 {
 }

-QUERY file8(vec: [F64]) =>
+QUERY file8(vec: [F32]) =>
 new_vec <- AddV(vec)
 AddV(vec)
 RETURN new_vec
diff --git a/hql-tests/tests/series/queries.hx b/hql-tests/tests/series/queries.hx
index ba8e77573..0319e3c40 100644
--- a/hql-tests/tests/series/queries.hx
+++ b/hql-tests/tests/series/queries.hx
@@ -109,7 +109,7 @@ QUERY addWarmConnect (user_id: ID, warm_connect_id: ID) =>
 metadata_to_warm_connect <- AddE()::From(metadata)::To(warm_connect)
 RETURN warm_connect

-QUERY createUserBio (user_id: ID, bio: [F64]) =>
+QUERY createUserBio (user_id: ID, bio: [F32]) =>
 user_bio <- AddV(bio)
 user <- N(user_id)
 user_user_bio <- AddE()::From(user)::To(user_bio)
@@ -156,7 +156,7 @@ QUERY getUsersByReferrer(referrer: String) =>
 RETURN users

 #[mcp]
-QUERY searchUsersByBio(bio_vector: [F64], k: I64) =>
+QUERY searchUsersByBio(bio_vector: [F32], k: I64) =>
 similar_bios <- SearchV(bio_vector, k)
 users <- similar_bios::In
 RETURN users
@@ -402,4 +402,4 @@ QUERY deleteUser (user_id: ID) =>
 DROP N(user_id)::Out
 DROP N(user_id)::OutE
 DROP N(user_id)
- RETURN "success"
\ No newline at end of file
+ RETURN "success"
diff --git a/hql-tests/tests/update_drop_then_add/file52.hx b/hql-tests/tests/update_drop_then_add/file52.hx
index c441fc899..5ded7ea8d 100644
--- a/hql-tests/tests/update_drop_then_add/file52.hx
+++ b/hql-tests/tests/update_drop_then_add/file52.hx
@@ -1,7 +1,7 @@
-QUERY updateEntity (entity_id: ID, name: String, name_embedding: [F64], group_id: String, summary: String, created_at: Date, labels: [String], attributes: String) =>
+QUERY updateEntity (entity_id: ID, name: String, name_embedding: [F32], group_id: String, summary: String, created_at: Date, labels: [String], attributes: String) =>
 entity <- N(entity_id)::UPDATE({name: name, group_id: group_id, summary: summary, created_at: created_at, labels: labels, attributes: attributes})
 DROP N(entity_id)::Out
 DROP N(entity_id)::OutE
 embedding <- AddV(name_embedding, {name_embedding: name_embedding})
 edge <- AddE({group_id: group_id})::From(entity)::To(embedding)
- RETURN entity
\ No newline at end of file
+ RETURN entity

From 3e9c0e061cdbaf9a3eacfee3377803127f862b0b Mon Sep 17 00:00:00 2001
From: Diego Reis
Date: Sun, 23 Nov 2025 23:37:16 -0300
Subject: [PATCH 41/48] Do not normalize cosine distance
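
With this change the value returned for two vectors is the raw 1 - cos rather
than (1 - cos) / 2 normalized into [0, 1], so distances now range over [0, 2]:
cos = 1 -> 0.0, cos = 0 -> 1.0, cos = -1 -> 2.0. (The comment block in the hunk
below still documents the old normalized mapping.) A quick sanity check of the
new mapping, as a sketch on plain f32 slices rather than the UnalignedVector
plumbing the real Cosine::distance goes through:

    // Hypothetical helper mirroring the un-normalized formula; not the real impl.
    fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let denom = norm(a) * norm(b);
        if denom > 0.0 { 1.0 - dot / denom } else { 0.0 } // degenerate input
    }

    fn main() {
        assert_eq!(cosine_distance(&[1.0, 0.0], &[1.0, 0.0]), 0.0); // cos = 1
        assert_eq!(cosine_distance(&[1.0, 0.0], &[0.0, 1.0]), 1.0); // cos = 0
        assert_eq!(cosine_distance(&[1.0, 0.0], &[-1.0, 0.0]), 2.0); // cos = -1
    }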
---
 helix-db/src/helix_engine/vector_core/distance/cosine.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/helix-db/src/helix_engine/vector_core/distance/cosine.rs b/helix-db/src/helix_engine/vector_core/distance/cosine.rs
index cdb1fa51f..76d2ca1da 100644
--- a/helix-db/src/helix_engine/vector_core/distance/cosine.rs
+++ b/helix-db/src/helix_engine/vector_core/distance/cosine.rs
@@ -53,7 +53,7 @@ impl Distance for Cosine {
 // cos = 0. -> 0.5
 // cos = -1. -> 1.0
 // cos = 1. -> 0.0
- (1.0 - cos) / 2.0
+ 1.0 - cos
 } else {
 0.0
 }

From 7ce0613f71b2873e29255f37df82eb83045853ad Mon Sep 17 00:00:00 2001
From: Diego Reis
Date: Sun, 23 Nov 2025 23:37:33 -0300
Subject: [PATCH 42/48] Fix some broken tests

---
 .../tests/traversal_tests/vector_traversal_tests.rs | 2 +-
 helix-db/src/protocol/custom_serde/error_handling_tests.rs | 6 +++---
 helix-db/src/protocol/custom_serde/integration_tests.rs | 4 ++--
 hql-tests/tests/graphiti/queries.hx | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs
index 2101c3abd..dd42e380e 100644
--- a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs
+++ b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs
@@ -263,7 +263,7 @@ fn test_v_from_type_without_vector_data() {
 assert_eq!(results.len(), 1);
 assert_eq!(results[0].id(), vector_id);

- // Verify it's a VectorWithoutData
+ // Verify it's a Vector with no data
 match &results[0] {
 crate::helix_engine::traversal_core::traversal_value::TraversalValue::Vector(v) => {
 assert_eq!(v.id, vector_id);
diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs
index 4e290f440..93b11dc5a 100644
--- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs
+++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs
@@ -224,7 +224,7 @@ mod error_handling_tests {
 }

 #[test]
- #[should_panic(expected = "raw_vector_data.len() == 0")]
+ #[should_panic]
 fn test_vector_cast_empty_raw_data_panics() {
 let arena = Bump::new();
 let empty_data: &[u8] = &[];
@@ -260,10 +260,10 @@ mod error_handling_tests {
 }

 #[test]
- #[should_panic(expected = "is not a multiple of size_of::<f64>()")]
+ #[should_panic(expected = "is not a multiple of size_of::<f32>()")]
 fn test_vector_misaligned_data_bytes_panics() {
 let arena = Bump::new();
- // 7 bytes is not a multiple of 8 (size of f64)
+ // 7 bytes is not a multiple of 4 (size of f32)
 let misaligned: &[u8] = &[0, 1, 2, 3, 4, 5, 6];
 HVector::raw_vector_data_to_vec(misaligned, &arena);
 }
diff --git a/helix-db/src/protocol/custom_serde/integration_tests.rs b/helix-db/src/protocol/custom_serde/integration_tests.rs
index 33a6ae312..58b82b4fd 100644
--- a/helix-db/src/protocol/custom_serde/integration_tests.rs
+++ b/helix-db/src/protocol/custom_serde/integration_tests.rs
@@ -635,7 +635,7 @@ mod integration_tests {
 let vector = create_simple_vector(&arena, id, "test", &data);
 let data_bytes = vector.vector_data_to_bytes().unwrap();

- // Should be exactly 128 * 8 bytes (128 f64 values)
- assert_eq!(data_bytes.len(), 128 * 8);
+ // Should be exactly 128 * 4 bytes (128 f32 values)
+ assert_eq!(data_bytes.len(), 128 * 4);
 }
 }
diff --git a/hql-tests/tests/graphiti/queries.hx b/hql-tests/tests/graphiti/queries.hx
index 378a27e86..bee842f21 100644
--- a/hql-tests/tests/graphiti/queries.hx
+++ b/hql-tests/tests/graphiti/queries.hx
@@ -113,13 +113,13 @@ QUERY deleteEpisodeEdge (episodeEdge_id: ID) =>
 // Community
 // #########################################################

-QUERY createCommunity (name: String, group_id: String, summary: String, created_at: Date, labels: [String], name_embedding: [F64]) =>
+QUERY createCommunity (name: String, group_id: String, summary: String, created_at: Date, labels: [String], name_embedding: [F32]) =>
 community <- AddN({name: name, group_id: group_id, summary: summary,
created_at: created_at, labels: labels})
 embedding <- AddV(name_embedding, {name_embedding: name_embedding})
 edge <- AddE({group_id: group_id})::From(community)::To(embedding)
 RETURN community

-QUERY updateCommunity (community_id: ID, name: String, group_id: String, summary: String, created_at: Date, labels: [String], name_embedding: [F64]) =>
+QUERY updateCommunity (community_id: ID, name: String, group_id: String, summary: String, created_at: Date, labels: [String], name_embedding: [F32]) =>
 community <- N(community_id)::UPDATE({name: name, group_id: group_id, summary: summary, created_at: created_at, labels: labels})
 DROP N(community_id)::Out
 embedding <- AddV(name_embedding, {name_embedding: name_embedding})
 edge <- AddE({group_id: group_id})::From(community)::To(embedding)

From 14591c85f0dfa566aed6b476c8671a779b900e6e Mon Sep 17 00:00:00 2001
From: Diego Reis
Date: Mon, 24 Nov 2025 09:40:10 -0300
Subject: [PATCH 43/48] Change in-memory local-to-global id hashmap to a persistent database
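
The HNSW layer addresses items by an internal u32 id, while the rest of the
engine uses global u128 vector ids. The reverse mapping previously lived only
in an in-memory RwLock<HashMap<u32, u128>>; it now gets its own LMDB table
(DB_ID_MAP) next to the vector data, so it persists with the environment. A
minimal sketch of the pattern, using the heed3 types as they appear in the
diff below (the helper names here are illustrative only):

    use byteorder::BE;
    use heed3::types::{U32, U128};
    use heed3::{Database, Env, Result, RoTxn, RwTxn};

    // One table mapping internal HNSW item id -> global vector id.
    fn open_id_map(env: &Env, txn: &mut RwTxn) -> Result<Database<U32<BE>, U128<BE>>> {
        env.database_options()
            .types::<U32<BE>, U128<BE>>()
            .name("id_map")
            .create(txn)
    }

    // Writes go through the write txn and survive process restarts.
    fn remember(db: &Database<U32<BE>, U128<BE>>, txn: &mut RwTxn, idx: u32, id: u128) -> Result<()> {
        db.put(txn, &idx, &id)
    }

    // Reads return Option, so a missing entry becomes an error at the call site.
    fn resolve(db: &Database<U32<BE>, U128<BE>>, txn: &RoTxn, idx: u32) -> Result<Option<u128>> {
        db.get(txn, &idx)
    }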
---
 helix-db/src/helix_engine/bm25/bm25.rs | 3 +-
 .../traversal_core/ops/vectors/search.rs | 22 ++--
 helix-db/src/helix_engine/vector_core/mod.rs | 105 +++++++++++++-----
 .../src/helix_engine/vector_core/reader.rs | 8 --
 4 files changed, 95 insertions(+), 43 deletions(-)

diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs
index 88ca28474..7ca41bc6c 100644
--- a/helix-db/src/helix_engine/bm25/bm25.rs
+++ b/helix-db/src/helix_engine/bm25/bm25.rs
@@ -426,8 +426,7 @@ impl HybridSearch for HelixGraphStorage {
 false,
 &arena,
 )?;
- let scores =
- results.into_global_id(&self.vectors.local_to_global_id.read().unwrap());
+ let scores = self.vectors.into_global_id(&txn, &results)?;
 Ok(Some(scores))
 });

diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs
index c81be45ad..386b43f27 100644
--- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs
+++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs
@@ -61,18 +61,24 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE
 let iter = match vectors {
 Ok(vectors) => {
- let hvectors = self.storage.vectors.nns_to_hvectors(
+ match self.storage.vectors.nns_to_hvectors(
 self.txn,
 vectors.into_nns(),
 false,
 self.arena,
- );
-
- hvectors
- .into_iter()
- .map(|vector| Ok::<TraversalValue, GraphError>(TraversalValue::Vector(vector)))
- .collect::<Vec<_>>()
- .into_iter()
+ ) {
+ Ok(hvectors) => hvectors
+ .into_iter()
+ .map(|vector| {
+ Ok::<TraversalValue, GraphError>(TraversalValue::Vector(vector))
+ })
+ .collect::<Vec<_>>()
+ .into_iter(),
+ Err(err) => {
+ let error = GraphError::VectorError(format!("{err}"));
+ once(Err(error)).collect::<Vec<_>>().into_iter()
+ }
+ }
 }
 Err(VectorError::VectorNotFound(id)) => {
 let error = GraphError::VectorError(format!("vector not found for id {id}"));
diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs
index d23806910..6ea05c679 100644
--- a/helix-db/src/helix_engine/vector_core/mod.rs
+++ b/helix-db/src/helix_engine/vector_core/mod.rs
@@ -11,7 +11,7 @@ use byteorder::BE;
 use hashbrown::HashMap;
 use heed3::{
 Database, Env, Error as LmdbError, RoTxn, RwTxn,
- types::{Bytes, U128},
+ types::{Bytes, U32, U128},
 };
 use rand::{SeedableRng, rngs::StdRng};
 use serde::{Deserialize, Serialize, Serializer, ser::SerializeMap};
@@ -55,6 +55,7 @@ pub mod writer;

 const DB_VECTORS: &str = "vectors"; // for vector data (v:)
 const DB_VECTOR_DATA: &str = "vector_data"; // for vector's properties
+const DB_ID_MAP: &str = "id_map"; // for map ids

 pub type ItemId = u32;

@@ -336,7 +337,7 @@ pub struct VectorCore {
 /// Maps global id (u128) to internal id (u32) and label
 pub global_to_local_id: RwLock<HashMap<u128, (u32, String)>>,
- pub local_to_global_id: RwLock<HashMap<u32, u128>>,
+ pub local_to_global_id: Database<U32<BE>, U128<BE>>,
 curr_id: AtomicU32,
 }

@@ -349,15 +350,21 @@ impl VectorCore {
 .name(DB_VECTOR_DATA)
 .create(txn)?;

+ let local_to_global_id = env
+ .database_options()
+ .types::<U32<BE>, U128<BE>>()
+ .name(DB_ID_MAP)
+ .create(txn)?;
+
 Ok(Self {
 hsnw: vectors_db,
 vector_properties_db,
 config,
+ local_to_global_id,
 label_to_index: RwLock::new(HashMap::new()),
 curr_index: AtomicU16::new(0),
 global_to_local_id: RwLock::new(HashMap::new()),
 curr_id: AtomicU32::new(0),
- local_to_global_id: RwLock::new(HashMap::new()),
 })
 }

@@ -434,15 +441,12 @@ impl VectorCore {
 bump_vec.extend_from_slice(data);
 let hvector = HVector::from_vec(label, bump_vec);
-
 let idx = self.curr_id.fetch_add(1, atomic::Ordering::SeqCst);
 self.global_to_local_id
 .write()
 .unwrap()
 .insert(hvector.id, (idx, label.to_string()));
- self.local_to_global_id
- .write()
- .unwrap()
- .insert(idx, hvector.id);
+ self.local_to_global_id.put(txn, &idx, &hvector.id)?;
+
 self.label_to_index
 .read()
 .unwrap()
@@ -471,7 +475,11 @@
 .expect("if index exist label should also exist");
 let writer = Writer::new(self.hsnw, index.id, index.dimension);
 writer.del_item(txn, idx)?;
- index.num_vectors.fetch_add(1, atomic::Ordering::SeqCst);
+
+ // TODO: do we actually need to delete here?
+ self.local_to_global_id.delete(txn, &idx)?;
+
+ index.num_vectors.fetch_sub(1, atomic::Ordering::SeqCst);
 Ok(())
 }
 None => Err(VectorError::VectorNotFound(format!(
@@ -487,7 +495,7 @@
 nns: bumpalo::collections::Vec<'arena, (ItemId, f32)>,
 with_data: bool,
 arena: &'arena bumpalo::Bump,
- ) -> bumpalo::collections::Vec<'arena, HVector<'arena>>
+ ) -> VectorCoreResult<bumpalo::collections::Vec<'arena, HVector<'arena>>>
 where
 'txn: 'arena,
 {
@@ -496,22 +504,27 @@
 arena,
 );

- let local_to_global_id = self.local_to_global_id.read().unwrap();
 let label_to_index = self.label_to_index.read().unwrap();
 let global_to_local_id = self.global_to_local_id.read().unwrap();

 let (item_id, _) = nns.first().unwrap();
- let global_id = local_to_global_id.get(item_id).unwrap();
- let (_, label) = global_to_local_id.get(global_id).unwrap();
+ let global_id = self
+ .local_to_global_id
+ .get(txn, &item_id)?
+ .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?; results.push(HVector { - id: *global_id, + id: global_id, distance: Some(distance), label, deleted: false, @@ -538,7 +554,7 @@ impl VectorCore { } } - results + Ok(results) } pub fn get_full_vector<'arena>( @@ -550,7 +566,10 @@ impl VectorCore { let label_to_index = self.label_to_index.read().unwrap(); let global_to_local_id = self.global_to_local_id.read().unwrap(); - let (item_id, label) = global_to_local_id.get(&id).unwrap(); + let (item_id, label) = global_to_local_id + .get(&id) + .ok_or_else(|| VectorError::VectorNotFound(format!("Vector {id} not found")))?; + let index = label_to_index.get(label).unwrap(); let label = arena.alloc_str(label); @@ -607,7 +626,6 @@ impl VectorCore { arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { let mut result = bumpalo::collections::Vec::new_in(arena); - let local_to_global_id = self.local_to_global_id.read().unwrap(); let label_to_index = self.label_to_index.read().unwrap(); let index = label_to_index.get(label).unwrap(); @@ -616,7 +634,11 @@ impl VectorCore { if get_vector_data { while let Some((key, item)) = iter.next().transpose()? { - let &id = local_to_global_id.get(&key.item).unwrap(); + let id = self + .local_to_global_id + .get(txn, &key.item)? + .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?; + result.push(HVector { id, label, @@ -630,7 +652,11 @@ impl VectorCore { } } else { while let Some(key) = iter.next_id().transpose()? { - let &id = local_to_global_id.get(&key.item).unwrap(); + let id = self + .local_to_global_id + .get(txn, &key.item)? + .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?; + result.push(HVector { id, label, @@ -654,7 +680,6 @@ impl VectorCore { arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { let label_to_index = self.label_to_index.read().unwrap(); - let local_to_global_id = self.local_to_global_id.read().unwrap(); let mut result = bumpalo::collections::Vec::new_in(arena); for (label, index) in label_to_index.iter() { @@ -663,7 +688,13 @@ impl VectorCore { if get_vector_data { while let Some((key, item)) = iter.next().transpose()? { - let &id = local_to_global_id.get(&key.item).unwrap(); + let id = self + .local_to_global_id + .get(txn, &key.item)? + .ok_or_else(|| { + VectorError::VectorNotFound("Vector not found".to_string()) + })?; + result.push(HVector { id, label: arena.alloc_str(label), @@ -677,7 +708,13 @@ impl VectorCore { } } else { while let Some(key) = iter.next_id().transpose()? { - let &id = local_to_global_id.get(&key.item).unwrap(); + let id = self + .local_to_global_id + .get(txn, &key.item)? + .ok_or_else(|| { + VectorError::VectorNotFound("Vector not found".to_string()) + })?; + result.push(HVector { id, label: arena.alloc_str(label), @@ -694,4 +731,22 @@ impl VectorCore { Ok(result) } + + pub fn into_global_id( + &self, + txn: &RoTxn, + searched: &Searched, + ) -> VectorCoreResult> { + let mut result = Vec::new(); + for &(id, distance) in searched.nns.iter() { + result.push(( + self.local_to_global_id + .get(txn, &id)? 
+ .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?, + distance, + )) + } + + Ok(result) + } } diff --git a/helix-db/src/helix_engine/vector_core/reader.rs b/helix-db/src/helix_engine/vector_core/reader.rs index 3cb736576..9da913d5e 100644 --- a/helix-db/src/helix_engine/vector_core/reader.rs +++ b/helix-db/src/helix_engine/vector_core/reader.rs @@ -4,7 +4,6 @@ use std::marker; use std::num::NonZeroUsize; use bumpalo::collections::CollectIn; -use hashbrown::HashMap; use heed3::RoTxn; use heed3::types::DecodeIgnore; use min_max_heap::MinMaxHeap; @@ -61,13 +60,6 @@ impl<'arena> Searched<'arena> { pub fn into_nns(self) -> bumpalo::collections::Vec<'arena, (ItemId, f32)> { self.nns } - - pub fn into_global_id(&self, map: &HashMap) -> Vec<(u128, f32)> { - self.nns - .iter() - .map(|(item_id, score)| (*map.get(item_id).unwrap(), *score)) - .collect() - } } /// Options used to make a query against an hannoy [`Reader`]. From dcfae9fc727d9c28a45efe6d13d627de39817979 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Mon, 24 Nov 2025 10:16:23 -0300 Subject: [PATCH 44/48] Make creating a HVector from raw bytes a failable operation --- helix-db/src/helix_engine/vector_core/mod.rs | 11 +++++++---- .../src/protocol/custom_serde/error_handling_tests.rs | 6 +++--- .../src/protocol/custom_serde/property_based_tests.rs | 2 +- helix-db/src/protocol/custom_serde/vector_serde.rs | 3 ++- .../src/protocol/custom_serde/vector_serde_tests.rs | 4 ++-- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index 6ea05c679..d72a90bdf 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -230,10 +230,13 @@ impl<'arena> HVector<'arena> { pub fn raw_vector_data_to_vec<'txn>( raw_vector_data: &'txn [u8], arena: &'arena bumpalo::Bump, - ) -> bumpalo::collections::Vec<'arena, f32> { + ) -> VectorCoreResult> { let mut bump_vec = bumpalo::collections::Vec::<'arena, f32>::new_in(arena); - bump_vec.extend_from_slice(bytemuck::cast_slice(raw_vector_data)); - bump_vec + bump_vec.extend_from_slice(bytemuck::try_cast_slice(raw_vector_data).map_err(|err| { + VectorError::ConversionError(format!("Error casting raw bytes to &[f32]: {}", err)) + })?); + + Ok(bump_vec) } pub fn from_raw_vector_data<'txn>( @@ -510,7 +513,7 @@ impl VectorCore { let (item_id, _) = nns.first().unwrap(); let global_id = self .local_to_global_id - .get(txn, &item_id)? + .get(txn, item_id)? 
.ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?; let (_, label) = global_to_local_id.get(&global_id).unwrap(); let index = label_to_index.get(label).unwrap(); diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs index 93b11dc5a..5d761d603 100644 --- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs +++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs @@ -228,7 +228,7 @@ mod error_handling_tests { fn test_vector_cast_empty_raw_data_panics() { let arena = Bump::new(); let empty_data: &[u8] = &[]; - HVector::raw_vector_data_to_vec(empty_data, &arena); + HVector::raw_vector_data_to_vec(empty_data, &arena).unwrap(); } #[test] @@ -260,12 +260,12 @@ mod error_handling_tests { } #[test] - #[should_panic(expected = "is not a multiple of size_of::()")] + #[should_panic] fn test_vector_misaligned_data_bytes_panics() { let arena = Bump::new(); // 7 bytes is not a multiple of 4 (size of f32) let misaligned: &[u8] = &[0, 1, 2, 3, 4, 5, 6]; - HVector::raw_vector_data_to_vec(misaligned, &arena); + HVector::raw_vector_data_to_vec(misaligned, &arena).unwrap(); } #[test] diff --git a/helix-db/src/protocol/custom_serde/property_based_tests.rs b/helix-db/src/protocol/custom_serde/property_based_tests.rs index f88495241..0304a03ae 100644 --- a/helix-db/src/protocol/custom_serde/property_based_tests.rs +++ b/helix-db/src/protocol/custom_serde/property_based_tests.rs @@ -344,7 +344,7 @@ mod property_based_tests { // Convert to bytes and back let bytes = create_vector_bytes(&data); - let restored = HVector::raw_vector_data_to_vec( &bytes,&arena); + let restored = HVector::raw_vector_data_to_vec( &bytes,&arena).unwrap(); prop_assert_eq!(restored.len(), data.len()); diff --git a/helix-db/src/protocol/custom_serde/vector_serde.rs b/helix-db/src/protocol/custom_serde/vector_serde.rs index 4e7fb214b..2f184d622 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde.rs @@ -94,7 +94,8 @@ impl<'de, 'txn, 'arena> serde::de::DeserializeSeed<'de> for VectorDeSeed<'txn, ' .next_element_seed(OptionPropertiesMapDeSeed { arena: self.arena })? 
.ok_or_else(|| serde::de::Error::custom("Expected properties field"))?; - let data = HVector::raw_vector_data_to_vec(self.raw_vector_data, self.arena); + let data = HVector::raw_vector_data_to_vec(self.raw_vector_data, self.arena) + .map_err(serde::de::Error::custom)?; Ok(HVector { id: self.id, diff --git a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs index 5d0ad19d1..706c67fe9 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde_tests.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde_tests.rs @@ -175,7 +175,7 @@ mod vector_serialization_tests { let original_data: Vec = (0..128).map(|i| i as f32).collect(); let raw_bytes = create_vector_bytes(&original_data); - let casted_data = HVector::raw_vector_data_to_vec(&raw_bytes, &arena); + let casted_data = HVector::raw_vector_data_to_vec(&raw_bytes, &arena).unwrap(); assert_eq!(casted_data.len(), 128); for (i, &val) in casted_data.iter().enumerate() { @@ -189,7 +189,7 @@ mod vector_serialization_tests { let original_data = vec![3.14159, 2.71828, 1.41421, 1.73205]; let raw_bytes = create_vector_bytes(&original_data); - let casted_data = HVector::raw_vector_data_to_vec(&raw_bytes, &arena); + let casted_data = HVector::raw_vector_data_to_vec(&raw_bytes, &arena).unwrap(); assert_eq!(casted_data.len(), original_data.len()); for (orig, casted) in original_data.iter().zip(casted_data.iter()) { From a96817094239db00251c73ca23c9bdc52ed54002 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Mon, 24 Nov 2025 14:27:43 -0300 Subject: [PATCH 45/48] Add support for property filter queries --- helix-db/src/helix_engine/bm25/bm25.rs | 11 +-- .../hnsw_concurrent_tests.rs | 25 ++---- helix-db/src/helix_engine/tests/hnsw_tests.rs | 8 +- .../tests/traversal_tests/drop_tests.rs | 4 - .../traversal_tests/edge_traversal_tests.rs | 4 - .../traversal_core/ops/vectors/search.rs | 29 ++++--- helix-db/src/helix_engine/vector_core/mod.rs | 84 ++++++++++++++----- .../src/protocol/custom_serde/vector_serde.rs | 4 +- hql-tests/tests/rerankers/queries.hx | 4 +- 9 files changed, 100 insertions(+), 73 deletions(-) diff --git a/helix-db/src/helix_engine/bm25/bm25.rs b/helix-db/src/helix_engine/bm25/bm25.rs index 7ca41bc6c..46b117cfb 100644 --- a/helix-db/src/helix_engine/bm25/bm25.rs +++ b/helix-db/src/helix_engine/bm25/bm25.rs @@ -418,14 +418,9 @@ impl HybridSearch for HelixGraphStorage { task::spawn_blocking(move || -> Result>, GraphError> { let txn = graph_env_vector.read_txn()?; let arena = Bump::new(); // MOVE - let results = self.vectors.search( - &txn, - query_vector_owned, - limit * 2, - "vector", - false, - &arena, - )?; + let results = + self.vectors + .search(&txn, query_vector_owned, limit * 2, "vector", &arena)?; let scores = self.vectors.into_global_id(&txn, &results)?; Ok(Some(scores)) }); diff --git a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs index 9a05cb434..07c6a6175 100644 --- a/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs +++ b/helix-db/src/helix_engine/tests/concurrency_tests/hnsw_concurrent_tests.rs @@ -13,15 +13,13 @@ /// - Delete during search might return inconsistent results /// - LMDB transaction model provides MVCC but needs validation use bumpalo::Bump; -use heed3::{Env, EnvOpenOptions, RoTxn, RwTxn}; +use heed3::{Env, EnvOpenOptions, RwTxn}; use rand::Rng; use std::sync::{Arc, Barrier}; use std::thread; use tempfile::TempDir; -use 
crate::helix_engine::vector_core::{HNSWConfig, HVector, VectorCore}; - -type Filter = fn(&HVector, &RoTxn) -> bool; +use crate::helix_engine::vector_core::{HNSWConfig, VectorCore}; /// Setup test environment with larger map size for concurrent access /// @@ -139,7 +137,7 @@ fn test_concurrent_inserts_single_label() { // Additional consistency check: Verify we can perform searches (entry point exists implicitly) let arena = Bump::new(); let query = [0.5; 128]; - let search_result = index.search(&rtxn, query.to_vec(), 10, "concurrent_test", false, &arena); + let search_result = index.search(&rtxn, query.to_vec(), 10, "concurrent_test", &arena); assert!( search_result.is_ok(), "Should be able to search after concurrent inserts (entry point exists)" @@ -203,7 +201,7 @@ fn test_concurrent_searches_during_inserts() { let rtxn = env.read_txn().unwrap(); let arena = Bump::new(); - match index.search(&rtxn, query.to_vec(), 10, "search_test", false, &arena) { + match index.search(&rtxn, query.to_vec(), 10, "search_test", &arena) { Ok(results) => { total_searches += 1; total_results += results.nns.len(); @@ -278,7 +276,7 @@ fn test_concurrent_searches_during_inserts() { // Verify we can still search successfully let arena = Bump::new(); let results = index - .search(&rtxn, query.to_vec(), 10, "search_test", false, &arena) + .search(&rtxn, query.to_vec(), 10, "search_test", &arena) .unwrap(); assert!( !results.nns.is_empty(), @@ -351,7 +349,7 @@ fn test_concurrent_inserts_multiple_labels() { // Verify we can search for each label (entry point exists implicitly) let query = [0.5; 64]; - let search_result = index.search(&rtxn, query.to_vec(), 5, &label, false, &arena); + let search_result = index.search(&rtxn, query.to_vec(), 5, &label, &arena); assert!( search_result.is_ok(), "Should be able to search label {}", @@ -431,7 +429,7 @@ fn test_entry_point_consistency() { // If we can successfully search, entry point must be valid let query = [0.5; 32]; - let search_result = index.search(&rtxn, query.to_vec(), 10, "entry_test", false, &arena); + let search_result = index.search(&rtxn, query.to_vec(), 10, "entry_test", &arena); assert!( search_result.is_ok(), "Entry point should exist and be valid" @@ -514,14 +512,7 @@ fn test_graph_connectivity_after_concurrent_inserts() { for i in 0..10 { let query = random_vector(64); let results = index - .search( - &rtxn, - query.to_vec(), - 10, - "connectivity_test", - false, - &arena, - ) + .search(&rtxn, query.to_vec(), 10, "connectivity_test", &arena) .unwrap(); assert!( diff --git a/helix-db/src/helix_engine/tests/hnsw_tests.rs b/helix-db/src/helix_engine/tests/hnsw_tests.rs index 8563cff88..7b030cd6f 100644 --- a/helix-db/src/helix_engine/tests/hnsw_tests.rs +++ b/helix-db/src/helix_engine/tests/hnsw_tests.rs @@ -1,11 +1,9 @@ use bumpalo::Bump; -use heed3::{Env, EnvOpenOptions, RoTxn}; +use heed3::{Env, EnvOpenOptions}; use rand::Rng; use tempfile::TempDir; -use crate::helix_engine::vector_core::{HNSWConfig, HVector, VectorCore}; - -type Filter = fn(&HVector, &RoTxn) -> bool; +use crate::helix_engine::vector_core::{HNSWConfig, VectorCore}; fn setup_env() -> (Env, TempDir) { let temp_dir = tempfile::tempdir().unwrap(); @@ -59,7 +57,7 @@ fn test_hnsw_search_returns_results() { let txn = env.read_txn().unwrap(); let query = [0.5, 0.5, 0.5, 0.5]; let results = index - .search(&txn, query.to_vec(), 5, "vector", false, &arena) + .search(&txn, query.to_vec(), 5, "vector", &arena) .unwrap(); assert!(!results.nns.is_empty()); } diff --git 
a/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs index ce6343e81..555610417 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/drop_tests.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use bumpalo::Bump; -use heed3::RoTxn; use rand::Rng; use tempfile::TempDir; @@ -25,13 +24,10 @@ use crate::{ traversal_value::TraversalValue, }, types::GraphError, - vector_core::HVector, }, props, }; -type Filter = fn(&HVector, &RoTxn) -> bool; - fn setup_test_db() -> (TempDir, Arc) { let temp_dir = TempDir::new().unwrap(); let db_path = temp_dir.path().to_str().unwrap(); diff --git a/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs index dbad15e83..ec676a0f3 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/edge_traversal_tests.rs @@ -22,14 +22,10 @@ use crate::{ traversal_value::TraversalValue, }, types::GraphError, - vector_core::HVector, }, props, protocol::value::Value, }; -use heed3::RoTxn; - -type Filter = fn(&HVector, &RoTxn) -> bool; fn setup_test_db() -> (TempDir, Arc) { let temp_dir = TempDir::new().unwrap(); diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs index 386b43f27..8dcaa4fc7 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs @@ -15,7 +15,7 @@ pub trait SearchVAdapter<'db, 'arena, 'txn>: query: &'arena [f32], k: K, label: &'arena str, - filter: Option<&'arena [F]>, + filter: Option, ) -> RoTraversalIterator< 'db, 'arena, @@ -37,7 +37,7 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE query: &'arena [f32], k: K, label: &'arena str, - _filter: Option<&'arena [F]>, + filter: Option, ) -> RoTraversalIterator< 'db, 'arena, @@ -55,7 +55,6 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE query.to_vec(), k.try_into().unwrap(), label, - false, self.arena, ); @@ -67,13 +66,23 @@ impl<'db, 'arena, 'txn, I: Iterator, GraphE false, self.arena, ) { - Ok(hvectors) => hvectors - .into_iter() - .map(|vector| { - Ok::(TraversalValue::Vector(vector)) - }) - .collect::>() - .into_iter(), + Ok(hvectors) => match filter { + Some(filter) => hvectors + .into_iter() + .filter(|vector| filter(vector, self.txn)) + .map(|vector| { + Ok::(TraversalValue::Vector(vector)) + }) + .collect::>() + .into_iter(), + None => hvectors + .into_iter() + .map(|vector| { + Ok::(TraversalValue::Vector(vector)) + }) + .collect::>() + .into_iter(), + }, Err(err) => { let error = GraphError::VectorError(format!("{err}")); once(Err(error)).collect::>().into_iter() diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs index d72a90bdf..49065cc6a 100644 --- a/helix-db/src/helix_engine/vector_core/mod.rs +++ b/helix-db/src/helix_engine/vector_core/mod.rs @@ -28,7 +28,9 @@ use crate::{ }, }, protocol::{ - custom_serde::vector_serde::{VectoWithoutDataDeSeed, VectorDeSeed}, + custom_serde::vector_serde::{ + OptionPropertiesMapDeSeed, VectoWithoutDataDeSeed, VectorDeSeed, + }, value::Value, }, utils::{ @@ -123,10 +125,9 @@ impl<'arena> HVector<'arena> { } pub fn from_vec(label: &'arena str, data: bumpalo::collections::Vec<'arena, f32>) -> Self { - let id = v6_uuid(); HVector { - id, 
label, + id: v6_uuid(), version: 1, data: Some(Item::::from_vec(data)), distance: None, @@ -377,7 +378,6 @@ impl VectorCore { query: Vec, k: usize, label: &'arena str, - _should_trickle: bool, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { match self.label_to_index.read().unwrap().get(label) { @@ -430,7 +430,7 @@ impl VectorCore { txn: &mut RwTxn, label: &'arena str, data: &'arena [f32], - _properties: Option>, + properties: Option>, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { let writer = self.get_writer_or_create_index(label, data.len(), txn)?; @@ -442,7 +442,8 @@ impl VectorCore { let mut bump_vec = bumpalo::collections::Vec::new_in(arena); bump_vec.extend_from_slice(data); - let hvector = HVector::from_vec(label, bump_vec); + let mut hvector = HVector::from_vec(label, bump_vec); + hvector.properties = properties; self.global_to_local_id .write() @@ -461,6 +462,9 @@ impl VectorCore { let mut rng = StdRng::from_os_rng(); let mut builder = writer.builder(&mut rng); + self.vector_properties_db + .put(txn, &hvector.id, &bincode::serialize(&properties)?)?; + // FIXME: We shouldn't rebuild on every insertion builder .ef_construction(self.config.ef_construct) @@ -526,14 +530,25 @@ impl VectorCore { .get(txn, &item_id)? .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?; + let properties = match self.vector_properties_db.get(txn, &global_id)? { + Some(bytes) => bincode::options() + .with_fixint_encoding() + .allow_trailing_bytes() + .deserialize_seed(OptionPropertiesMapDeSeed { arena }, bytes) + .map_err(|e| { + VectorError::ConversionError(format!("Error deserializing vector: {e}")) + })?, + None => None, + }; + results.push(HVector { id: global_id, distance: Some(distance), label, + properties, deleted: false, level: None, version: 0, - properties: None, data: get_item(self.hsnw, index.id, txn, item_id).unwrap(), }); } @@ -544,13 +559,23 @@ impl VectorCore { .get(txn, &item_id)? .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?; + let properties = match self.vector_properties_db.get(txn, &global_id)? { + Some(bytes) => bincode::options() + .with_fixint_encoding() + .allow_trailing_bytes() + .deserialize_seed(OptionPropertiesMapDeSeed { arena }, bytes) + .map_err(|e| { + VectorError::ConversionError(format!("Error deserializing vector: {e}")) + })?, + None => None, + }; results.push(HVector { id: global_id, distance: Some(distance), label, deleted: false, version: 0, - properties: None, + properties, level: None, data: None, }); @@ -563,51 +588,68 @@ impl VectorCore { pub fn get_full_vector<'arena>( &self, txn: &RoTxn, - id: u128, + global_id: u128, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult> { let label_to_index = self.label_to_index.read().unwrap(); let global_to_local_id = self.global_to_local_id.read().unwrap(); let (item_id, label) = global_to_local_id - .get(&id) - .ok_or_else(|| VectorError::VectorNotFound(format!("Vector {id} not found")))?; + .get(&global_id) + .ok_or_else(|| VectorError::VectorNotFound(format!("Vector {global_id} not found")))?; let index = label_to_index.get(label).unwrap(); - let label = arena.alloc_str(label); - - let item = get_item(self.hsnw, index.id, txn, *item_id)?.map(|i| i.clone_in(arena)); + let properties = match self.vector_properties_db.get(txn, &global_id)? 
{ + Some(bytes) => bincode::options() + .with_fixint_encoding() + .allow_trailing_bytes() + .deserialize_seed(OptionPropertiesMapDeSeed { arena }, bytes) + .map_err(|e| { + VectorError::ConversionError(format!("Error deserializing vector: {e}")) + })?, + None => None, + }; Ok(HVector { - id, + id: global_id, + properties, distance: None, - label, + label: arena.alloc_str(label), deleted: false, version: 0, level: None, - properties: None, - data: item.clone(), + data: get_item(self.hsnw, index.id, txn, *item_id)?.map(|i| i.clone_in(arena)), }) } pub fn get_vector_properties<'arena>( &self, - _txn: &RoTxn, + txn: &RoTxn, id: u128, arena: &'arena bumpalo::Bump, ) -> VectorCoreResult>> { let global_to_local_id = self.global_to_local_id.read().unwrap(); let (_, label) = global_to_local_id.get(&id).unwrap(); - // todo: actually take properties + let properties = match self.vector_properties_db.get(txn, &id)? { + Some(bytes) => bincode::options() + .with_fixint_encoding() + .allow_trailing_bytes() + .deserialize_seed(OptionPropertiesMapDeSeed { arena }, bytes) + .map_err(|e| { + VectorError::ConversionError(format!("Error deserializing vector: {e}")) + })?, + None => None, + }; + Ok(Some(HVector { id, + properties, distance: None, label: arena.alloc_str(label.as_str()), deleted: false, version: 0, level: None, - properties: None, data: None, })) } diff --git a/helix-db/src/protocol/custom_serde/vector_serde.rs b/helix-db/src/protocol/custom_serde/vector_serde.rs index 2f184d622..68a88eb17 100644 --- a/helix-db/src/protocol/custom_serde/vector_serde.rs +++ b/helix-db/src/protocol/custom_serde/vector_serde.rs @@ -6,8 +6,8 @@ use serde::de::{DeserializeSeed, Visitor}; use std::fmt; /// Helper DeserializeSeed for Option -struct OptionPropertiesMapDeSeed<'arena> { - arena: &'arena bumpalo::Bump, +pub struct OptionPropertiesMapDeSeed<'arena> { + pub arena: &'arena bumpalo::Bump, } impl<'de, 'arena> DeserializeSeed<'de> for OptionPropertiesMapDeSeed<'arena> { diff --git a/hql-tests/tests/rerankers/queries.hx b/hql-tests/tests/rerankers/queries.hx index 152fb10be..70468441f 100644 --- a/hql-tests/tests/rerankers/queries.hx +++ b/hql-tests/tests/rerankers/queries.hx @@ -16,7 +16,7 @@ QUERY testRRFDefault(query_vec: [F32]) => RETURN results // Test 2: RerankRRF with custom k parameter -QUERY testRRFCustomK(query_vec: [F32], k_val: F64) => +QUERY testRRFCustomK(query_vec: [F32], k_val: F32) => results <- SearchV(query_vec, 100) ::RerankRRF(k: k_val) ::RANGE(0, 10) @@ -52,7 +52,7 @@ QUERY testChainedRerankers(query_vec: [F32]) => RETURN results // Test 7: MMR with variable lambda -QUERY testMMRVariableLambda(query_vec: [F32], lambda_val: F64) => +QUERY testMMRVariableLambda(query_vec: [F32], lambda_val: F32) => results <- SearchV(query_vec, 100) ::RerankMMR(lambda: lambda_val) ::RANGE(0, 10) From 7cc375d8192205c25b63e74211becf23337f88c5 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Mon, 24 Nov 2025 18:46:50 -0300 Subject: [PATCH 46/48] Use f32 as default value for traversal --- .../traversal_tests/shortest_path_tests.rs | 16 ++-- .../traversal_core/ops/util/paths.rs | 86 +++++++++---------- .../src/helixc/generator/math_functions.rs | 22 ++--- .../src/helixc/generator/traversal_steps.rs | 8 +- helix-db/src/protocol/value.rs | 5 ++ 5 files changed, 71 insertions(+), 66 deletions(-) diff --git a/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs index b84b063a1..f6ba53451 100644 --- 
a/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs @@ -261,8 +261,8 @@ fn test_dijkstra_custom_weight_function() { .as_ref() .and_then(|props| props.get("distance")) .and_then(|v| match v { - Value::F64(f) => Some(*f), - Value::F32(f) => Some(*f as f64), + Value::F64(f) => Some(*f as f32), + Value::F32(f) => Some(*f), _ => None, }) .ok_or_else(|| { @@ -399,8 +399,8 @@ fn test_dijkstra_multi_context_weight() { .as_ref() .and_then(|props| props.get("distance")) .and_then(|v| match v { - Value::F64(f) => Some(*f), - Value::F32(f) => Some(*f as f64), + Value::F64(f) => Some(*f as f32), + Value::F32(f) => Some(*f), _ => None, }) .ok_or_else(|| { @@ -414,8 +414,8 @@ fn test_dijkstra_multi_context_weight() { .as_ref() .and_then(|props| props.get("traffic_factor")) .and_then(|v| match v { - Value::F64(f) => Some(*f), - Value::F32(f) => Some(*f as f64), + Value::F64(f) => Some(*f as f32), + Value::F32(f) => Some(*f), _ => None, }) .ok_or_else(|| { @@ -879,13 +879,13 @@ fn test_astar_custom_weight_and_heuristic() { .ok_or(crate::helix_engine::types::GraphError::New( "distance property not found".to_string(), ))? - .as_f64(); + .as_f32(); let traffic = edge .get_property("traffic") .ok_or(crate::helix_engine::types::GraphError::New( "traffic property not found".to_string(), ))? - .as_f64(); + .as_f32(); Ok(distance * traffic) }; diff --git a/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs b/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs index afb2d8d21..aa1024ac1 100644 --- a/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs +++ b/helix-db/src/helix_engine/traversal_core/ops/util/paths.rs @@ -22,23 +22,23 @@ pub fn default_weight_fn<'arena>( edge: &Edge<'arena>, _src_node: &Node<'arena>, _dst_node: &Node<'arena>, -) -> Result { +) -> Result { Ok(edge .properties .as_ref() .and_then(|props| props.get("weight")) .and_then(|w| match w { - Value::F32(f) => Some(*f as f64), - Value::F64(f) => Some(*f), - Value::I8(i) => Some(*i as f64), - Value::I16(i) => Some(*i as f64), - Value::I32(i) => Some(*i as f64), - Value::I64(i) => Some(*i as f64), - Value::U8(i) => Some(*i as f64), - Value::U16(i) => Some(*i as f64), - Value::U32(i) => Some(*i as f64), - Value::U64(i) => Some(*i as f64), - Value::U128(i) => Some(*i as f64), + Value::F32(f) => Some(*f), + Value::F64(f) => Some(*f as f32), + Value::I8(i) => Some(*i as f32), + Value::I16(i) => Some(*i as f32), + Value::I32(i) => Some(*i as f32), + Value::I64(i) => Some(*i as f32), + Value::U8(i) => Some(*i as f32), + Value::U16(i) => Some(*i as f32), + Value::U32(i) => Some(*i as f32), + Value::U64(i) => Some(*i as f32), + Value::U128(i) => Some(*i as f32), _ => None, }) .unwrap_or(1.0)) @@ -49,22 +49,22 @@ pub fn default_weight_fn<'arena>( pub fn property_heuristic<'arena>( node: &Node<'arena>, property_name: &str, -) -> Result { +) -> Result { node.properties .as_ref() .and_then(|props| props.get(property_name)) .and_then(|v| match v { - Value::F64(f) => Some(*f), - Value::F32(f) => Some(*f as f64), - Value::I64(i) => Some(*i as f64), - Value::I32(i) => Some(*i as f64), - Value::I16(i) => Some(*i as f64), - Value::I8(i) => Some(*i as f64), - Value::U128(i) => Some(*i as f64), - Value::U64(i) => Some(*i as f64), - Value::U32(i) => Some(*i as f64), - Value::U16(i) => Some(*i as f64), - Value::U8(i) => Some(*i as f64), + Value::F32(f) => Some(*f), + Value::F64(f) => Some(*f as f32), + Value::I64(i) => Some(*i as f32), + Value::I32(i) 
=> Some(*i as f32), + Value::I16(i) => Some(*i as f32), + Value::I8(i) => Some(*i as f32), + Value::U128(i) => Some(*i as f32), + Value::U64(i) => Some(*i as f32), + Value::U32(i) => Some(*i as f32), + Value::U16(i) => Some(*i as f32), + Value::U8(i) => Some(*i as f32), _ => None, }) .ok_or_else(|| { @@ -94,12 +94,12 @@ pub struct ShortestPathIterator< 'txn, I, F, - H = fn(&Node<'arena>) -> Result, + H = fn(&Node<'arena>) -> Result, > where 'db: 'arena, 'arena: 'txn, - F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, - H: Fn(&Node<'arena>) -> Result, + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + H: Fn(&Node<'arena>) -> Result, { pub arena: &'arena bumpalo::Bump, pub iter: I, @@ -115,7 +115,7 @@ pub struct ShortestPathIterator< #[derive(Debug, Clone)] struct DijkstraState { node_id: u128, - distance: f64, + distance: f32, } impl Eq for DijkstraState {} @@ -145,8 +145,8 @@ impl PartialOrd for DijkstraState { #[derive(Debug, Clone)] struct AStarState { node_id: u128, - g_score: f64, - f_score: f64, + g_score: f32, + f_score: f32, } impl Eq for AStarState {} @@ -179,8 +179,8 @@ impl< 'arena: 'txn, 'txn, I: Iterator, GraphError>>, - F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, - H: Fn(&Node<'arena>) -> Result, + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + H: Fn(&Node<'arena>) -> Result, > Iterator for ShortestPathIterator<'db, 'arena, 'txn, I, F, H> { type Item = Result, GraphError>; @@ -208,8 +208,8 @@ impl< impl<'db, 'arena, 'txn, I, F, H> ShortestPathIterator<'db, 'arena, 'txn, I, F, H> where - F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, - H: Fn(&Node<'arena>) -> Result, + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + H: Fn(&Node<'arena>) -> Result, { fn reconstruct_path( &self, @@ -405,7 +405,7 @@ where }; let mut heap = BinaryHeap::new(); - let mut g_scores: HashMap = HashMap::with_capacity(64); + let mut g_scores: HashMap = HashMap::with_capacity(64); let mut parent: HashMap = HashMap::with_capacity(32); // Calculate initial heuristic for start node @@ -550,7 +550,7 @@ pub trait ShortestPathAdapter<'db, 'arena, 'txn, 's, I>: 'arena, 'txn, I, - fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, >, >; @@ -563,7 +563,7 @@ pub trait ShortestPathAdapter<'db, 'arena, 'txn, 's, I>: weight_fn: F, ) -> RoTraversalIterator<'db, 'arena, 'txn, ShortestPathIterator<'db, 'arena, 'txn, I, F>> where - F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result; + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result; fn shortest_path_astar( self, @@ -574,8 +574,8 @@ pub trait ShortestPathAdapter<'db, 'arena, 'txn, 's, I>: heuristic_fn: H, ) -> RoTraversalIterator<'db, 'arena, 'txn, ShortestPathIterator<'db, 'arena, 'txn, I, F, H>> where - F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, - H: Fn(&Node<'arena>) -> Result; + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + H: Fn(&Node<'arena>) -> Result; } impl<'db, 'arena, 'txn, 's, I: Iterator, GraphError>>> @@ -596,7 +596,7 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr 'arena, 'txn, I, - fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, >, > { self.shortest_path_with_algorithm( @@ -618,7 +618,7 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr weight_fn: F, ) -> RoTraversalIterator<'db, 'arena, 'txn, ShortestPathIterator<'db, 'arena, 'txn, I, F>> where - F: Fn(&Edge<'arena>, 
&Node<'arena>, &Node<'arena>) -> Result, + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, { RoTraversalIterator { arena: self.arena, @@ -652,8 +652,8 @@ impl<'db, 'arena, 'txn, 's, I: Iterator, Gr heuristic_fn: H, ) -> RoTraversalIterator<'db, 'arena, 'txn, ShortestPathIterator<'db, 'arena, 'txn, I, F, H>> where - F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, - H: Fn(&Node<'arena>) -> Result, + F: Fn(&Edge<'arena>, &Node<'arena>, &Node<'arena>) -> Result, + H: Fn(&Node<'arena>) -> Result, { RoTraversalIterator { arena: self.arena, diff --git a/helix-db/src/helixc/generator/math_functions.rs b/helix-db/src/helixc/generator/math_functions.rs index 4aa72d7c5..3aade7204 100644 --- a/helix-db/src/helixc/generator/math_functions.rs +++ b/helix-db/src/helixc/generator/math_functions.rs @@ -68,28 +68,28 @@ impl Display for PropertyAccess { PropertyContext::Edge => { write!( f, - "(edge.get_property({}).ok_or(GraphError::Default)?.as_f64())", + "(edge.get_property({}).ok_or(GraphError::Default)?.as_f32())", self.property ) } PropertyContext::SourceNode => { write!( f, - "(src_node.get_property({}).ok_or(GraphError::Default)?.as_f64())", + "(src_node.get_property({}).ok_or(GraphError::Default)?.as_f32())", self.property ) } PropertyContext::TargetNode => { write!( f, - "(dst_node.get_property({}).ok_or(GraphError::Default)?.as_f64())", + "(dst_node.get_property({}).ok_or(GraphError::Default)?.as_f32())", self.property ) } PropertyContext::Current => { write!( f, - "(v.get_property({}).ok_or(GraphError::Default)?.as_f64())", + "(v.get_property({}).ok_or(GraphError::Default)?.as_f32())", self.property ) } @@ -469,7 +469,7 @@ mod tests { }; assert_eq!( edge_prop.to_string(), - "(edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64())" + "(edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32())" ); // Test SourceNode context @@ -479,7 +479,7 @@ mod tests { }; assert_eq!( src_prop.to_string(), - "(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f64())" + "(src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f32())" ); // Test TargetNode context @@ -489,14 +489,14 @@ mod tests { }; assert_eq!( dst_prop.to_string(), - "(dst_node.get_property(\"popularity\").ok_or(GraphError::Default)?.as_f64())" + "(dst_node.get_property(\"popularity\").ok_or(GraphError::Default)?.as_f32())" ); } #[test] fn test_complex_weight_expression() { // Test: MUL(_::{distance}, POW(0.95, DIV(_::{days}, 30))) - // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f64()) * (0.95_f32).powf(((edge.get_property("days").ok_or(GraphError::Default)?.as_f64()) / 30_f32))) + // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f32()) * (0.95_f32).powf(((edge.get_property("days").ok_or(GraphError::Default)?.as_f32()) / 30_f32))) let expr = MathFunctionCallGen { function: MathFunction::Mul, args: vec![ @@ -525,14 +525,14 @@ mod tests { assert_eq!( expr.to_string(), - "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64()) * (0.95_f32).powf(((edge.get_property(\"days\").ok_or(GraphError::Default)?.as_f64()) / 30_f32)))" + "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32()) * (0.95_f32).powf(((edge.get_property(\"days\").ok_or(GraphError::Default)?.as_f32()) / 30_f32)))" ); } #[test] fn test_multi_context_expression() { // Test: MUL(_::{distance}, _::From::{traffic_factor}) - // Should generate: 
((edge.get_property("distance").ok_or(GraphError::Default)?.as_f64()) * (src_node.get_property("traffic_factor").ok_or(GraphError::Default)?.as_f64())) + // Should generate: ((edge.get_property("distance").ok_or(GraphError::Default)?.as_f32()) * (src_node.get_property("traffic_factor").ok_or(GraphError::Default)?.as_f32())) let expr = MathFunctionCallGen { function: MathFunction::Mul, args: vec![ @@ -549,7 +549,7 @@ mod tests { assert_eq!( expr.to_string(), - "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f64()) * (src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f64()))" + "((edge.get_property(\"distance\").ok_or(GraphError::Default)?.as_f32()) * (src_node.get_property(\"traffic_factor\").ok_or(GraphError::Default)?.as_f32()))" ); } } diff --git a/helix-db/src/helixc/generator/traversal_steps.rs b/helix-db/src/helixc/generator/traversal_steps.rs index 92a55c7c0..356c3f513 100644 --- a/helix-db/src/helixc/generator/traversal_steps.rs +++ b/helix-db/src/helixc/generator/traversal_steps.rs @@ -724,14 +724,14 @@ impl Display for ShortestPathDijkstras { WeightCalculation::Property(prop) => { write!( f, - "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f64()?) }}", + "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f32()?) }}", prop )?; } WeightCalculation::Expression(expr) => { write!( f, - "|edge, src_node, dst_node| -> Result {{ Ok({}) }}", + "|edge, src_node, dst_node| -> Result {{ Ok({}) }}", expr )?; } @@ -786,14 +786,14 @@ impl Display for ShortestPathAStar { WeightCalculation::Property(prop) => { write!( f, - "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f64()?) }}, ", + "|edge, _src_node, _dst_node| -> Result {{ Ok(edge.get_property({})?.as_f32()?) }}, ", prop )?; } WeightCalculation::Expression(expr) => { write!( f, - "|edge, src_node, dst_node| -> Result {{ Ok({}) }}, ", + "|edge, src_node, dst_node| -> Result {{ Ok({}) }}, ", expr )?; } diff --git a/helix-db/src/protocol/value.rs b/helix-db/src/protocol/value.rs index c1830ac22..88f2ed1e1 100644 --- a/helix-db/src/protocol/value.rs +++ b/helix-db/src/protocol/value.rs @@ -1634,6 +1634,11 @@ impl Value { pub fn as_f64(&self) -> f64 { *self.into_primitive() } + + #[inline(always)] + pub fn as_f32(&self) -> f32 { + *self.into_primitive() + } } #[cfg(test)] From 9b9d6a757f8faee6f90d850c0043831092244b13 Mon Sep 17 00:00:00 2001 From: Diego Reis Date: Mon, 24 Nov 2025 20:14:36 -0300 Subject: [PATCH 47/48] Fix some bugs What's life but a collection of tests to pass? 
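
Several of these fixes replace exact equality on f32 distances with tolerance
checks: with the un-normalized cosine distance, rounding in the dot product
and norms means assert_eq! against MIN_DISTANCE is flaky. The pattern, as a
standalone sketch (mirroring the 1e-6 bound used in the tests below):

    // Compare f32 distances with an absolute tolerance instead of assert_eq!.
    fn assert_close(actual: f32, expected: f32, eps: f32) {
        assert!(
            (actual - expected).abs() < eps,
            "distance {actual} is not within {eps} of {expected}"
        );
    }

    fn main() {
        // e.g. distance_to() on two identical vectors: ~0.0 up to rounding.
        let distance = 4.0e-7_f32;
        assert_close(distance, 0.0, 1e-6);
    }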
--- helix-cli/src/tests/init_tests.rs | 6 +- .../traversal_tests/shortest_path_tests.rs | 19 ++++--- .../traversal_tests/vector_traversal_tests.rs | 6 +- .../src/helix_engine/tests/vector_tests.rs | 10 +++- .../vector_core/distance/cosine.rs | 7 ++- helix-db/src/helix_engine/vector_core/mod.rs | 57 +++++++++++++++++-- .../custom_serde/error_handling_tests.rs | 3 +- 7 files changed, 83 insertions(+), 25 deletions(-) diff --git a/helix-cli/src/tests/init_tests.rs b/helix-cli/src/tests/init_tests.rs index 54b09e8bc..9955e592b 100644 --- a/helix-cli/src/tests/init_tests.rs +++ b/helix-cli/src/tests/init_tests.rs @@ -1,6 +1,6 @@ use crate::commands::init::run; use std::fs; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use tempfile::TempDir; /// Helper function to create a temporary test directory @@ -9,7 +9,7 @@ fn setup_test_dir() -> TempDir { } /// Helper function to check if helix.toml exists and is valid -fn assert_helix_config_exists(project_dir: &PathBuf) { +fn assert_helix_config_exists(project_dir: &Path) { let config_path = project_dir.join("helix.toml"); assert!( config_path.exists(), @@ -104,7 +104,7 @@ async fn test_init_with_default_path() { .await; assert!(result.is_ok(), "Init with default path should succeed"); - assert_helix_config_exists(&temp_dir.path().to_path_buf()); + assert_helix_config_exists(temp_dir.path()); } #[tokio::test] diff --git a/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs index f6ba53451..f87b02ed1 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/shortest_path_tests.rs @@ -837,7 +837,7 @@ fn test_astar_custom_weight_and_heuristic() { let start = G::new_mut(&storage, &arena, &mut txn) .add_n( "city", - props_option(&arena, props!("name" => "start", "h" => 10.0)), + props_option(&arena, props!("name" => "start", "h" => 10.0_f32)), None, ) .collect::, _>>() @@ -857,7 +857,10 @@ fn test_astar_custom_weight_and_heuristic() { G::new_mut(&storage, &arena, &mut txn) .add_edge( "road", - props_option(&arena, props!("distance" => 100.0, "traffic" => 0.5)), + props_option( + &arena, + props!("distance" => 100.0_f32, "traffic" => 0.5_f32), + ), start, end, false, @@ -874,12 +877,12 @@ fn test_astar_custom_weight_and_heuristic() { let custom_weight = |edge: &crate::utils::items::Edge, _src: &crate::utils::items::Node, _dst: &crate::utils::items::Node| { - let distance = edge - .get_property("distance") - .ok_or(crate::helix_engine::types::GraphError::New( - "distance property not found".to_string(), - ))? 
- .as_f32(); + let a = + edge.get_property("distance") + .ok_or(crate::helix_engine::types::GraphError::New( + "distance property not found".to_string(), + ))?; + let distance = a.as_f32(); let traffic = edge .get_property("traffic") .ok_or(crate::helix_engine::types::GraphError::New( diff --git a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs index dd42e380e..2988de022 100644 --- a/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs +++ b/helix-db/src/helix_engine/tests/traversal_tests/vector_traversal_tests.rs @@ -171,8 +171,8 @@ fn test_drop_vector_removes_edges() { txn.commit().unwrap(); let arena = Bump::new(); - let txn = storage.graph_env.read_txn().unwrap(); - let vectors = G::new(&storage, &txn, &arena) + let read_txn = storage.graph_env.read_txn().unwrap(); + let vectors = G::new(&storage, &read_txn, &arena) .search_v::(&[0.5, 0.5, 0.5], 10, "vector", None) .collect::, _>>() .unwrap(); @@ -188,6 +188,8 @@ fn test_drop_vector_removes_edges() { .unwrap(); txn.commit().unwrap(); + drop(read_txn); + let arena = Bump::new(); let txn = storage.graph_env.read_txn().unwrap(); let remaining = G::new(&storage, &txn, &arena) diff --git a/helix-db/src/helix_engine/tests/vector_tests.rs b/helix-db/src/helix_engine/tests/vector_tests.rs index 48547db57..1c7214171 100644 --- a/helix-db/src/helix_engine/tests/vector_tests.rs +++ b/helix-db/src/helix_engine/tests/vector_tests.rs @@ -32,7 +32,13 @@ fn test_hvector_distance_min() { let v1 = alloc_vector(&arena, &[1.0, 2.0, 3.0]); let v2 = alloc_vector(&arena, &[1.0, 2.0, 3.0]); let distance = v2.distance_to(&v1).unwrap(); - assert_eq!(distance, MIN_DISTANCE); + println!("Distance {}", distance); + assert!( + (distance - MIN_DISTANCE).abs() < 1e-6, + "Distance {} is not close enough to MIN_DISTANCE ({})", + distance, + MIN_DISTANCE + ); } #[test] @@ -99,5 +105,5 @@ fn test_hvector_cosine_similarity() { let arena2 = Bump::new(); let v2 = alloc_vector(&arena2, &[4.0, 5.0, 6.0]); let similarity = v1.distance_to(&v2).unwrap(); - assert!((similarity - (1.0 - 0.9746318461970762)).abs() < 1e-9); + assert!((similarity - (1.0 - 0.9746318461970762)).abs() < 1e-7); } diff --git a/helix-db/src/helix_engine/vector_core/distance/cosine.rs b/helix-db/src/helix_engine/vector_core/distance/cosine.rs index 76d2ca1da..5b790960f 100644 --- a/helix-db/src/helix_engine/vector_core/distance/cosine.rs +++ b/helix-db/src/helix_engine/vector_core/distance/cosine.rs @@ -4,7 +4,10 @@ use bytemuck::{Pod, Zeroable}; use serde::Serialize; use crate::helix_engine::vector_core::{ - distance::Distance, node::Item, spaces::simple::dot_product, unaligned_vector::UnalignedVector, + distance::{Distance, MAX_DISTANCE}, + node::Item, + spaces::simple::dot_product, + unaligned_vector::UnalignedVector, }; /// The Cosine similarity is a measure of similarity between two @@ -55,7 +58,7 @@ impl Distance for Cosine { // cos = 1. 
diff --git a/helix-db/src/helix_engine/vector_core/distance/cosine.rs b/helix-db/src/helix_engine/vector_core/distance/cosine.rs
index 76d2ca1da..5b790960f 100644
--- a/helix-db/src/helix_engine/vector_core/distance/cosine.rs
+++ b/helix-db/src/helix_engine/vector_core/distance/cosine.rs
@@ -4,7 +4,10 @@ use bytemuck::{Pod, Zeroable};
 use serde::Serialize;
 
 use crate::helix_engine::vector_core::{
-    distance::Distance, node::Item, spaces::simple::dot_product, unaligned_vector::UnalignedVector,
+    distance::{Distance, MAX_DISTANCE},
+    node::Item,
+    spaces::simple::dot_product,
+    unaligned_vector::UnalignedVector,
 };
 
 /// The Cosine similarity is a measure of similarity between two
@@ -55,7 +58,7 @@ impl Distance for Cosine {
             // cos = 1. -> 0.0
             1.0 - cos
         } else {
-            0.0
+            MAX_DISTANCE
         }
     }
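The cosine change above affects the fallback branch, reached when the cosine term cannot be used (e.g. a zero-norm vector, where cosine similarity is undefined). Returning 0.0 made such a vector look identical to every query; MAX_DISTANCE pushes it to the bottom of the ranking instead. A sketch of the edge case on plain slices — the real implementation operates on UnalignedVector items, and the concrete MAX_DISTANCE value (2.0, the upper bound of cosine distance) is assumed here for illustration:

// Assumed value; cosine distance 1 - cos ranges over [0.0, 2.0].
const MAX_DISTANCE: f32 = 2.0;

fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = norm_a * norm_b;
    if denom > 0.0 {
        1.0 - dot / denom
    } else {
        // A zero-norm vector has no direction: treat it as maximally far
        // away rather than identical (the old `0.0` return).
        MAX_DISTANCE
    }
}

fn main() {
    let zero = [0.0_f32; 3];
    let unit = [1.0_f32, 0.0, 0.0];
    assert_eq!(cosine_distance(&zero, &unit), MAX_DISTANCE);
    assert!(cosine_distance(&unit, &unit).abs() < 1e-6);
}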
diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs
index 49065cc6a..61cafc81d 100644
--- a/helix-db/src/helix_engine/vector_core/mod.rs
+++ b/helix-db/src/helix_engine/vector_core/mod.rs
@@ -203,7 +203,16 @@ impl<'arena> HVector<'arena> {
     pub fn distance_to(&self, rhs: &HVector<'arena>) -> VectorCoreResult<f32> {
         match (self.data.as_ref(), rhs.data.as_ref()) {
             (None, _) | (_, None) => Err(VectorError::HasNoData),
-            (Some(a), Some(b)) => Ok(Cosine::distance(a, b)),
+            (Some(a), Some(b)) => {
+                if a.vector.len() != b.vector.len() {
+                    return Err(VectorError::InvalidVecDimension {
+                        expected: a.vector.len(),
+                        received: b.vector.len(),
+                    });
+                }
+
+                Ok(Cosine::distance(a, b))
+            }
         }
     }
 
@@ -474,7 +483,8 @@
     }
 
     pub fn delete(&self, txn: &mut RwTxn, id: u128) -> VectorCoreResult<()> {
-        match self.global_to_local_id.read().unwrap().get(&id) {
+        let mut global_to_local_id = self.global_to_local_id.write().unwrap();
+        match global_to_local_id.get(&id) {
             Some(&(idx, ref label)) => {
                 let label_to_index = self.label_to_index.read().unwrap();
                 let index = label_to_index
@@ -485,8 +495,16 @@
                 // TODO: do we actually need to delete here?
                 self.local_to_global_id.delete(txn, &idx)?;
 
+                global_to_local_id.remove(&id);
                 index.num_vectors.fetch_sub(1, atomic::Ordering::SeqCst);
+
+                let mut rng = StdRng::from_os_rng();
+                let mut builder = writer.builder(&mut rng);
+
+                builder
+                    .ef_construction(self.config.ef_construct)
+                    .build(txn)?;
                 Ok(())
             }
             None => Err(VectorError::VectorNotFound(format!(
@@ -629,7 +647,9 @@
         arena: &'arena bumpalo::Bump,
     ) -> VectorCoreResult<Option<HVector<'arena>>> {
         let global_to_local_id = self.global_to_local_id.read().unwrap();
-        let (_, label) = global_to_local_id.get(&id).unwrap();
+        let (_, label) = global_to_local_id
+            .get(&id)
+            .ok_or_else(|| VectorError::VectorNotFound(format!("Vector not found: {}", id)))?;
 
         let properties = match self.vector_properties_db.get(txn, &id)? {
             Some(bytes) => bincode::options()
@@ -672,7 +692,10 @@
     ) -> VectorCoreResult<bumpalo::collections::Vec<'arena, HVector<'arena>>> {
         let mut result = bumpalo::collections::Vec::new_in(arena);
         let label_to_index = self.label_to_index.read().unwrap();
-        let index = label_to_index.get(label).unwrap();
+        let index = match label_to_index.get(label) {
+            Some(index) => index,
+            None => return Ok(bumpalo::collections::Vec::new_in(arena)),
+        };
 
         let reader = Reader::open(txn, index.id, self.hsnw)?;
         let mut iter = reader.iter(txn)?;
@@ -684,6 +707,17 @@
                 .get(txn, &key.item)?
                 .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?;
 
+            let properties = match self.vector_properties_db.get(txn, &id)? {
+                Some(bytes) => bincode::options()
+                    .with_fixint_encoding()
+                    .allow_trailing_bytes()
+                    .deserialize_seed(OptionPropertiesMapDeSeed { arena }, bytes)
+                    .map_err(|e| {
+                        VectorError::ConversionError(format!("Error deserializing vector: {e}"))
+                    })?,
+                None => None,
+            };
+
             result.push(HVector {
                 id,
                 label,
@@ -691,7 +725,7 @@
                 deleted: false,
                 level: Some(key.layer as usize),
                 version: 0,
-                properties: None,
+                properties,
                 data: Some(item.clone_in(arena)),
             });
         }
@@ -702,6 +736,17 @@
                 .get(txn, &key.item)?
                 .ok_or_else(|| VectorError::VectorNotFound("Vector not found".to_string()))?;
 
+            let properties = match self.vector_properties_db.get(txn, &id)? {
+                Some(bytes) => bincode::options()
+                    .with_fixint_encoding()
+                    .allow_trailing_bytes()
+                    .deserialize_seed(OptionPropertiesMapDeSeed { arena }, bytes)
+                    .map_err(|e| {
+                        VectorError::ConversionError(format!("Error deserializing vector: {e}"))
+                    })?,
+                None => None,
+            };
+
             result.push(HVector {
                 id,
                 label,
@@ -709,7 +754,7 @@
                 deleted: false,
                 level: Some(key.layer as usize),
                 version: 0,
-                properties: None,
+                properties,
                 data: None,
             });
         }
diff --git a/helix-db/src/protocol/custom_serde/error_handling_tests.rs b/helix-db/src/protocol/custom_serde/error_handling_tests.rs
index 5d761d603..064ce9cd2 100644
--- a/helix-db/src/protocol/custom_serde/error_handling_tests.rs
+++ b/helix-db/src/protocol/custom_serde/error_handling_tests.rs
@@ -219,8 +219,7 @@ mod error_handling_tests {
        let arena2 = Bump::new();
 
         let _result =
-            HVector::from_bincode_bytes(&arena2, Some(&props_bytes), empty_data, id, true);
-        // Should panic due to assertion in cast_raw_vector_data
+            HVector::from_bincode_bytes(&arena2, Some(&props_bytes), empty_data, id, true).unwrap();
     }
 
     #[test]

From 5a7bacc31aeb32f088f522b00ff2779fe7e5f1cf Mon Sep 17 00:00:00 2001
From: Diego Reis
Date: Mon, 24 Nov 2025 21:43:00 -0300
Subject: [PATCH 48/48] Unbind nns_to_hvectors from the 'txn lifetime

get_item bound HVector to the 'txn lifetime, which gave us zero-copy
reads but complicated some traversal operations. Now, when data is
required, we allocate it in the arena, binding HVector's lifetime to
the arena instead.
---
 .../traversal_core/ops/vectors/search.rs     |  4 +---
 helix-db/src/helix_engine/vector_core/mod.rs | 12 +++++-------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs
index 8dcaa4fc7..023dbdb32 100644
--- a/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs
+++ b/helix-db/src/helix_engine/traversal_core/ops/vectors/search.rs
@@ -25,8 +25,7 @@ pub trait SearchVAdapter<'db, 'arena, 'txn>:
     where
         F: Fn(&HVector, &RoTxn) -> bool,
         K: TryInto,
-        K::Error: std::fmt::Debug,
-        'txn: 'arena;
+        K::Error: std::fmt::Debug;
 }
 
 impl<'db, 'arena, 'txn, I: Iterator, GraphError>>>
@@ -48,7 +47,6 @@
         F: Fn(&HVector, &RoTxn) -> bool,
         K: TryInto,
         K::Error: std::fmt::Debug,
-        'txn: 'arena,
     {
         let vectors = self.storage.vectors.search(
             self.txn,
diff --git a/helix-db/src/helix_engine/vector_core/mod.rs b/helix-db/src/helix_engine/vector_core/mod.rs
index 61cafc81d..56e81e98d 100644
--- a/helix-db/src/helix_engine/vector_core/mod.rs
+++ b/helix-db/src/helix_engine/vector_core/mod.rs
@@ -514,16 +514,13 @@
         }
     }
 
-    pub fn nns_to_hvectors<'arena, 'txn>(
+    pub fn nns_to_hvectors<'arena>(
         &self,
-        txn: &'txn RoTxn,
+        txn: &RoTxn,
         nns: bumpalo::collections::Vec<'arena, (ItemId, f32)>,
         with_data: bool,
         arena: &'arena bumpalo::Bump,
-    ) -> VectorCoreResult<bumpalo::collections::Vec<'arena, HVector<'arena>>>
-    where
-        'txn: 'arena,
-    {
+    ) -> VectorCoreResult<bumpalo::collections::Vec<'arena, HVector<'arena>>> {
         let mut results = bumpalo::collections::Vec::<'arena, HVector<'arena>>::with_capacity_in(
             nns.len(),
             arena,
@@ -567,7 +564,8 @@
                 deleted: false,
                 level: None,
                 version: 0,
-                data: get_item(self.hsnw, index.id, txn, item_id).unwrap(),
+                data: get_item(self.hsnw, index.id, txn, item_id)?
+                    .map(|data| data.clone_in(arena)),
             });
         }
     } else {
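The clone_in call above is what the commit message describes: instead of handing out data borrowed from the LMDB read transaction (zero-copy, but tying HVector to 'txn), the bytes are copied into the caller's bump arena, so the result only borrows from 'arena. A simplified sketch of that decoupling (type and field names reduced for illustration; the real code clones UnalignedVector items, not plain slices):

use bumpalo::Bump;

// Stand-in for the real HVector: only the data field matters here.
struct HVector<'arena> {
    data: Option<&'arena [f32]>,
}

// Copying the raw data into the caller's arena means the returned HVector
// borrows from 'arena only, so the read transaction can end before the
// vector is used.
fn to_hvector<'arena>(raw: &[f32], arena: &'arena Bump) -> HVector<'arena> {
    HVector {
        data: Some(arena.alloc_slice_copy(raw)),
    }
}

fn main() {
    let arena = Bump::new();
    let hv;
    {
        // Stand-in for data borrowed from a short-lived read txn.
        let txn_scoped = vec![1.0_f32, 2.0, 3.0];
        hv = to_hvector(&txn_scoped, &arena);
    } // the "txn" data is gone here...
    assert_eq!(hv.data, Some(&[1.0_f32, 2.0, 3.0][..])); // ...but hv survives.
}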