Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ FIXES

* Fixing paths using in-built library instead of string manipulation. See [#143](https://github.com/hashicorp/terraform-mcp-server/pull/143)
* Explicitly setting destructive annotation to false. See [#143](https://github.com/hashicorp/terraform-mcp-server/pull/143)
* Fix provider search prioritization to show official providers first in search results. See [#179](https://github.com/hashicorp/terraform-mcp-server/pull/179)

SECURITY

Expand Down
9 changes: 8 additions & 1 deletion pkg/client/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func createHTTPClient(insecureSkipVerify bool, logger *log.Logger) *http.Client
return retryClient.StandardClient()
}

func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
// SendRegistryCallFn is a package-level function variable so callers (and tests)
// can override registry call behavior for testing.
var SendRegistryCallFn = func(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
ver := "v1"
if len(callOptions) > 0 {
ver = callOptions[0] // API version will be the first optional arg to this function
Expand Down Expand Up @@ -100,6 +102,11 @@ func SendRegistryCall(client *http.Client, method string, uri string, logger *lo
return body, nil
}

// Backwards-compatible wrapper
func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
return SendRegistryCallFn(client, method, uri, logger, callOptions...)
}

func SendPaginatedRegistryCall(client *http.Client, uriPrefix string, logger *log.Logger) ([]ProviderDocData, error) {
var results []ProviderDocData
page := 1
Expand Down
195 changes: 160 additions & 35 deletions pkg/tools/registry/search_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net/http"
"path"
"sort"
"strings"

"github.com/hashicorp/terraform-mcp-server/pkg/client"
Expand All @@ -19,6 +20,26 @@ import (
"github.com/mark3labs/mcp-go/server"
)

// sendRegistryCall is a package-level variable so tests can override registry calls.
var sendRegistryCall = client.SendRegistryCall

// tierOrder defines sorting priority for provider tiers.
var tierOrder = map[string]int{"official": 0, "partner": 1, "community": 2}

type providerMatch struct {
Namespace string
Name string
Tier string
DocMatch []client.ProviderDoc
}

// sortMatchesByTier sorts the matches slice in-place by tier using tierOrder.
func sortMatchesByTier(matches []providerMatch) {
sort.SliceStable(matches, func(i, j int) bool {
return tierOrder[strings.ToLower(matches[i].Tier)] < tierOrder[strings.ToLower(matches[j].Tier)]
})
}

// ResolveProviderDocID creates a tool to get provider details from registry.
func ResolveProviderDocID(logger *log.Logger) server.ServerTool {
return server.ServerTool{
Expand All @@ -27,8 +48,8 @@ func ResolveProviderDocID(logger *log.Logger) server.ServerTool {
You MUST call this function before 'get_provider_details' to obtain a valid tfprovider-compatible provider_doc_id.
Use the most relevant single word as the search query for service_slug, if unsure about the service_slug, use the provider_name for its value.
When selecting the best match, consider the following:
- Title similarity to the query
- Category relevance
- Title similarity to the query
- Category relevance
Return the selected provider_doc_id and explain your choice.
If there are multiple good matches, mention this but proceed with the most relevant one.`),
mcp.WithTitleAnnotation("Identify the most relevant provider document ID for a Terraform service"),
Expand Down Expand Up @@ -92,56 +113,161 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques
if utils.IsV2ProviderDataType(providerDetail.ProviderDataType) {
content, err := providerDetailsV2(httpClient, providerDetail, logger)
if err != nil {
errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`,
providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide)
errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`, providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide)
return nil, utils.LogAndReturnError(logger, errMessage, err)
}

fullContent := fmt.Sprintf("# %s provider docs\n\n%s",
providerDetail.ProviderName, content)
fullContent := fmt.Sprintf("# %s provider docs\n\n%s", providerDetail.ProviderName, content)

return mcp.NewToolResultText(fullContent), nil
}

// For resources/data-sources, use the v1 API for better performance (single response)
uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
response, err := client.SendRegistryCall(httpClient, "GET", uri, logger)
// Delegate to extracted helper so it can be unit-tested.
result, err := searchProvidersDocs(httpClient, providerDetail, serviceSlug, defaultErrorGuide, logger)
if err != nil {
return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil)
return nil, err
}
return mcp.NewToolResultText(result), nil
}

var providerDocs client.ProviderDocs
if err := json.Unmarshal(response, &providerDocs); err != nil {
return nil, utils.LogAndReturnError(logger, "unmarshalling provider docs", err)
// searchProvidersDocs contains the core provider-search and prioritization logic.
// It returns the textual result (same content as the tool would return) for easier unit testing.
func searchProvidersDocs(httpClient *http.Client, providerDetail client.ProviderDetail, serviceSlug string, defaultErrorGuide string, logger *log.Logger) (string, error) {
// Enhanced: Search all providers matching the name and prioritize by tier
searchUri := "providers?filter[name]=" + providerDetail.ProviderName
searchResp, err := sendRegistryCall(httpClient, "GET", searchUri, logger, "v2")
if err != nil {
return "", utils.LogAndReturnError(logger, "error searching providers in registry", err)
}

var builder strings.Builder
builder.WriteString(fmt.Sprintf("Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion))
builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
var providerList client.ProviderList
if err := json.Unmarshal(searchResp, &providerList); err != nil {
return "", utils.LogAndReturnError(logger, "unmarshalling provider list", err)
}

// If the registry search didn't return any providers, fall back to fetching
// the single provider directly (preserves previous behavior for cases where
// provider namespace defaults to hashicorp and the search endpoint may not
// return results matching our filter).
logger.Infof("provider search returned %d providers for name '%s'", len(providerList.Data), providerDetail.ProviderName)
if len(providerList.Data) == 0 {
logger.Infof("falling back to single-provider fetch for %s/%s@%s", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
response, err := sendRegistryCall(httpClient, "GET", uri, logger)
logger.Debugf("provider docs fetch URI: %s", uri)
if err != nil {
return "", utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil)
}
var providerDocs client.ProviderDocs
if err := json.Unmarshal(response, &providerDocs); err != nil {
return "", utils.LogAndReturnError(logger, "unmarshalling provider docs", err)
}
logger.Infof("provider docs returned %d docs for %s/%s@%s", len(providerDocs.Docs), providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
var builder strings.Builder
builder.WriteString(fmt.Sprintf("Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion))
builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
contentAvailable := false
for _, doc := range providerDocs.Docs {
if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType {
cs, err := utils.ContainsSlug(doc.Slug, serviceSlug)
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug)
if (cs || cs_pn) && err == nil && err_pn == nil {
contentAvailable = true
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
if err != nil {
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
}
builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet))
}
}
}

if !contentAvailable {
errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug)
return "", utils.LogAndReturnError(logger, errMessage, err)
}
return builder.String(), nil
}

var matches []providerMatch

contentAvailable := false
for _, doc := range providerDocs.Docs {
if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType {
cs, err := utils.ContainsSlug(doc.Slug, serviceSlug)
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug)
if (cs || cs_pn) && err == nil && err_pn == nil {
contentAvailable = true
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
if err != nil {
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
for _, pdata := range providerList.Data {
namespace := pdata.Attributes.Namespace
name := pdata.Attributes.Name
tier := pdata.Attributes.Tier
logger.Debugf("search provider entry: namespace=%s name=%s tier=%s", namespace, name, tier)

// Get docs for this provider. Try the requested version first; if that
// fails (for example the version doesn't exist in this namespace), try
// to resolve the latest version for that namespace/name and retry.
uri := path.Join("providers", namespace, name, providerDetail.ProviderVersion)
response, err := sendRegistryCall(httpClient, "GET", uri, logger)
if err != nil {
// Attempt to fetch the latest provider version for this namespace/name
latestVer, verErr := client.GetLatestProviderVersion(httpClient, namespace, name, logger)
if verErr != nil {
logger.Debugf("skipping provider %s/%s: error fetching docs: %v (also failed to get latest version: %v)", namespace, name, err, verErr)
continue // skip providers we can't fetch
}
uri = path.Join("providers", namespace, name, latestVer)
response, err = sendRegistryCall(httpClient, "GET", uri, logger)
if err != nil {
logger.Debugf("skipping provider %s/%s: error fetching docs with latest version %s: %v", namespace, name, latestVer, err)
continue
}
}
var providerDocs client.ProviderDocs
if err := json.Unmarshal(response, &providerDocs); err != nil {
logger.Debugf("skipping provider %s/%s: error unmarshalling docs: %v", namespace, name, err)
continue
}
logger.Debugf("fetched %d docs for provider %s/%s", len(providerDocs.Docs), namespace, name)
var docMatches []client.ProviderDoc
for _, doc := range providerDocs.Docs {
logger.Tracef("considering doc slug=%s title=%s category=%s language=%s", doc.Slug, doc.Title, doc.Category, doc.Language)
if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType {
cs, err := utils.ContainsSlug(doc.Slug, serviceSlug)
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", name, doc.Slug), serviceSlug)
if (cs || cs_pn) && err == nil && err_pn == nil {
logger.Debugf("matched doc %s for provider %s/%s (slug=%s)", doc.ID, namespace, name, doc.Slug)
docMatches = append(docMatches, doc)
}
builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet))
}
}
if len(docMatches) > 0 {
matches = append(matches, providerMatch{
Namespace: namespace,
Name: name,
Tier: tier,
DocMatch: docMatches,
})
}
}

// Check if the content data is not fulfilled
if !contentAvailable {
if len(matches) == 0 {
errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug)
return nil, utils.LogAndReturnError(logger, errMessage, err)
return "", utils.LogAndReturnError(logger, errMessage, err)
}

// Sort matches by tier
sortMatchesByTier(matches)

var builder strings.Builder
builder.WriteString("Available Documentation (prioritized by provider tier)\n\n")
builder.WriteString("Tier order: official > partner > community\n\n")
for _, match := range matches {
builder.WriteString(fmt.Sprintf("Provider: %s/%s (Tier: %s)\n", match.Namespace, match.Name, match.Tier))
for _, doc := range match.DocMatch {
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
if err != nil {
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
}
builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet))
}
builder.WriteString("\n")
}
return mcp.NewToolResultText(builder.String()), nil
return builder.String(), nil
}

func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) {
Expand Down Expand Up @@ -214,8 +340,7 @@ func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDe
return client.GetProviderOverviewDocs(httpClient, providerVersionID, logger)
}

uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl",
providerVersionID, category)
uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl", providerVersionID, category)

docs, err := client.SendPaginatedRegistryCall(httpClient, uriPrefix, logger)
if err != nil {
Expand Down Expand Up @@ -270,4 +395,4 @@ func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger
return desc[:300] + "...", nil
}
return desc, nil
}
}
59 changes: 59 additions & 0 deletions pkg/tools/registry/search_providers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package tools

import (
"net/http"
"strings"
"testing"

log "github.com/sirupsen/logrus"
"github.com/hashicorp/terraform-mcp-server/pkg/client"
)

func TestSearchProvidersPrioritizesOfficial(t *testing.T) {
// Backup original and restore
original := sendRegistryCall
defer func() { sendRegistryCall = original }()

// Fake responses
sendRegistryCall = func(httpClient *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
// provider list call
if strings.HasPrefix(uri, "providers?filter[name]=") {
// Return two providers: community then official (unordered)
// minimal JSON with attributes name, namespace, tier
return []byte(`{"data":[{"id":"1","attributes":{"name":"keycloak","namespace":"mrparkers","tier":"community"}},{"id":"2","attributes":{"name":"keycloak","namespace":"keycloak-official","tier":"official"}}]}`), nil
}

// provider docs calls: uri like providers/{namespace}/{name}/{version}
if strings.HasPrefix(uri, "providers/mrparkers/") {
return []byte(`{"docs":[{"id":"doc1","title":"Keycloak (community)","path":"","slug":"keycloak","category":"resources","language":"hcl"}]}`), nil
}
if strings.HasPrefix(uri, "providers/keycloak-official/") {
return []byte(`{"docs":[{"id":"doc2","title":"Keycloak (official)","path":"","slug":"keycloak","category":"resources","language":"hcl"}]}`), nil
}

return nil, nil
}

logger := log.New()
providerDetail := client.ProviderDetail{
ProviderName: "keycloak",
ProviderNamespace: "",
ProviderVersion: "latest",
ProviderDataType: "resources",
}

result, err := searchProvidersDocs(http.DefaultClient, providerDetail, "keycloak", "default guide", logger)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// official provider should appear before community provider in the output
officialIdx := strings.Index(result, "Provider: keycloak-official/keycloak (Tier: official)")
communityIdx := strings.Index(result, "Provider: mrparkers/keycloak (Tier: community)")
if officialIdx == -1 || communityIdx == -1 {
t.Fatalf("expected both providers in result, got: %s", result)
}
if officialIdx > communityIdx {
t.Fatalf("official provider found after community provider; result: %s", result)
}
}
19 changes: 19 additions & 0 deletions pkg/tools/registry/sort_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package tools

import (
"testing"
)

func TestSortMatchesByTier(t *testing.T) {
matches := []providerMatch{
{Namespace: "a", Name: "one", Tier: "community"},
{Namespace: "b", Name: "two", Tier: "partner"},
{Namespace: "c", Name: "three", Tier: "official"},
}

sortMatchesByTier(matches)

if matches[0].Tier != "official" || matches[1].Tier != "partner" || matches[2].Tier != "community" {
t.Fatalf("unexpected tier order: %v", []string{matches[0].Tier, matches[1].Tier, matches[2].Tier})
}
}
Loading