Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/common/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ export interface UserConfig extends CliOptions {
maxBytesPerQuery: number;
atlasTemporaryDatabaseUserLifetimeMs: number;
voyageApiKey: string;
vectorSearchDimensions: number;
vectorSearchSimilarityFunction: "cosine" | "euclidean" | "dotProduct";
}

export const defaultUserConfig: UserConfig = {
Expand Down Expand Up @@ -214,6 +216,8 @@ export const defaultUserConfig: UserConfig = {
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
voyageApiKey: "",
vectorSearchDimensions: 1024,
vectorSearchSimilarityFunction: "euclidean",
};

export const config = setupUserConfig({
Expand Down
53 changes: 39 additions & 14 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,33 @@ export interface ConnectionState {
connectedAtlasCluster?: AtlasClusterConnectionInfo;
}

export interface ConnectionStateConnected extends ConnectionState {
tag: "connected";
serviceProvider: NodeDriverServiceProvider;
export class ConnectionStateConnected implements ConnectionState {
public tag = "connected" as const;

constructor(
public serviceProvider: NodeDriverServiceProvider,
public connectionStringAuthType?: ConnectionStringAuthType,
public connectedAtlasCluster?: AtlasClusterConnectionInfo
) {}

private _isSearchSupported?: boolean;

public async isSearchSupported(): Promise<boolean> {
if (this._isSearchSupported === undefined) {
try {
const dummyDatabase = "test";
const dummyCollection = "test";
// If a cluster supports search indexes, the call below will succeed
// with a cursor otherwise will throw an Error
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
this._isSearchSupported = true;
} catch {
this._isSearchSupported = false;
}
}

return this._isSearchSupported;
}
}

export interface ConnectionStateConnecting extends ConnectionState {
Expand Down Expand Up @@ -199,12 +223,10 @@ export class MCPConnectionManager extends ConnectionManager {
});
}

return this.changeState("connection-success", {
tag: "connected",
connectedAtlasCluster: settings.atlas,
serviceProvider: await serviceProvider,
connectionStringAuthType,
});
return this.changeState(
"connection-success",
new ConnectionStateConnected(await serviceProvider, connectionStringAuthType, settings.atlas)
);
} catch (error: unknown) {
const errorReason = error instanceof Error ? error.message : `${error as string}`;
this.changeState("connection-error", {
Expand Down Expand Up @@ -270,11 +292,14 @@ export class MCPConnectionManager extends ConnectionManager {
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-success", {
...this.currentConnectionState,
tag: "connected",
serviceProvider: await this.currentConnectionState.serviceProvider,
});
this.changeState(
"connection-success",
new ConnectionStateConnected(
await this.currentConnectionState.serviceProvider,
this.currentConnectionState.connectionStringAuthType,
this.currentConnectionState.connectedAtlasCluster
)
);
}

this.logger.info({
Expand Down
22 changes: 9 additions & 13 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ export class Session extends EventEmitter<SessionEvents> {
return this.connectionManager.currentConnectionState.tag === "connected";
}

isSearchSupported(): Promise<boolean> {
const state = this.connectionManager.currentConnectionState;
if (state.tag === "connected") {
return state.isSearchSupported();
}

return Promise.resolve(false);
}

get serviceProvider(): NodeDriverServiceProvider {
if (this.isConnectedToMongoDB) {
const state = this.connectionManager.currentConnectionState as ConnectionStateConnected;
Expand All @@ -153,17 +162,4 @@ export class Session extends EventEmitter<SessionEvents> {
get connectedAtlasCluster(): AtlasClusterConnectionInfo | undefined {
return this.connectionManager.currentConnectionState.connectedAtlasCluster;
}

async isSearchIndexSupported(): Promise<boolean> {
try {
const dummyDatabase = `search-index-test-db-${Date.now()}`;
const dummyCollection = `search-index-test-coll-${Date.now()}`;
// If a cluster supports search indexes, the call below will succeed
// with a cursor otherwise will throw an Error
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
return true;
} catch {
return false;
}
}
}
2 changes: 1 addition & 1 deletion src/resources/common/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export class DebugResource extends ReactiveResource<

switch (this.current.tag) {
case "connected": {
const searchIndexesSupported = await this.session.isSearchIndexSupported();
const searchIndexesSupported = await this.session.isSearchSupported();
result += `The user is connected to the MongoDB cluster${searchIndexesSupported ? " with support for search indexes" : " without any support for search indexes"}.`;
break;
}
Expand Down
140 changes: 130 additions & 10 deletions src/tools/mongodb/create/createIndex.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,158 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import type { ToolCategory } from "../../tool.js";
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
import type { IndexDirection } from "mongodb";

export class CreateIndexTool extends MongoDBToolBase {
private vectorSearchIndexDefinition = z.object({
type: z.literal("vectorSearch"),
fields: z
.array(
z.discriminatedUnion("type", [
z
.object({
type: z.literal("filter"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
})
.strict()
.describe("Definition for a field that will be used for pre-filtering results."),
z
.object({
type: z.literal("vector"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
numDimensions: z
.number()
.min(1)
.max(8192)
.default(this.config.vectorSearchDimensions)
.describe(
"Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time"
),
similarity: z
.enum(["cosine", "euclidean", "dotProduct"])
.default(this.config.vectorSearchSimilarityFunction)
.describe(
"Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields."
),
quantization: z
.enum(["none", "scalar", "binary"])
.optional()
.default("none")
.describe(
"Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors."
),
})
.strict()
.describe("Definition for a field that contains vector embeddings."),
])
)
.nonempty()
.refine((fields) => fields.some((f) => f.type === "vector"), {
message: "At least one vector field must be defined",
})
.describe(
"Definitions for the vector and filter fields to index, one definition per document. You must specify `vector` for fields that contain vector embeddings and `filter` for additional fields to filter on. At least one vector-type field definition is required."
),
});

public name = "create-index";
protected description = "Create an index for a collection";
protected argsShape = {
...DbOperationArgs,
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
name: z.string().optional().describe("The name of the index"),
definition: z
.array(
z.discriminatedUnion("type", [
z.object({
type: z.literal("classic"),
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
}),
...(this.isFeatureFlagEnabled(FeatureFlags.VectorSearch) ? [this.vectorSearchIndexDefinition] : []),
])
)
.describe(
"The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes"
),
};

public operationType: OperationType = "create";

protected async execute({
database,
collection,
keys,
name,
definition: definitions,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const indexes = await provider.createIndexes(database, collection, [
{
key: keys,
name,
},
]);
let indexes: string[] = [];
const definition = definitions[0];
if (!definition) {
throw new Error("Index definition not provided. Expected one of the following: `classic`, `vectorSearch`");
}

let responseClarification = "";

switch (definition.type) {
case "classic":
indexes = await provider.createIndexes(database, collection, [
{
key: definition.keys,
name,
},
]);
break;
case "vectorSearch":
{
const isVectorSearchSupported = await this.session.isSearchSupported();
if (!isVectorSearchSupported) {
// TODO: remove hacky casts once we merge the local dev tools
const isLocalAtlasAvailable =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are checking if a tool exists in a few places already in other places, maybe we can extract this to a function and refactor?

(this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory))
.length ?? 0) > 0;

const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI";
return {
content: [
{
text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`,
type: "text",
},
],
isError: true,
};
}

indexes = await provider.createSearchIndexes(database, collection, [
{
name,
definition: {
fields: definition.fields,
},
type: "vectorSearch",
},
]);

responseClarification =
" Since this is a vector search index, it may take a while for the index to build. Use the `list-indexes` tool to check the index status.";
}

break;
}

return {
content: [
{
text: `Created the index "${indexes[0]}" on collection "${collection}" in database "${database}"`,
text: `Created the index "${indexes[0]}" on collection "${collection}" in database "${database}".${responseClarification}`,
type: "text",
},
],
Expand Down
2 changes: 1 addition & 1 deletion src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export const DbOperationArgs = {
};

export abstract class MongoDBToolBase extends ToolBase {
private server?: Server;
protected server?: Server;
public category: ToolCategory = "mongodb";

protected async ensureConnected(): Promise<NodeDriverServiceProvider> {
Expand Down
14 changes: 14 additions & 0 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ export type ToolCallbackArgs<Args extends ZodRawShape> = Parameters<ToolCallback

export type ToolExecutionContext<Args extends ZodRawShape = ZodRawShape> = Parameters<ToolCallback<Args>>[1];

export const enum FeatureFlags {
VectorSearch = "vectorSearch",
}

/**
* The type of operation the tool performs. This is used when evaluating if a tool is allowed to run based on
* the config's `disabledTools` and `readOnly` settings.
Expand Down Expand Up @@ -314,6 +318,16 @@ export abstract class ToolBase {

this.telemetry.emitEvents([event]);
}

// TODO: Move this to a separate file
protected isFeatureFlagEnabled(flag: FeatureFlags): boolean {
switch (flag) {
case FeatureFlags.VectorSearch:
return this.config.voyageApiKey !== "";
default:
return false;
}
}
}

/**
Expand Down
Loading
Loading