Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ FIXES

* Fixing paths using in-built library instead of string manipulation. See [#143](https://github.com/hashicorp/terraform-mcp-server/pull/143)
* Explicitly setting destructive annotation to false. See [#143](https://github.com/hashicorp/terraform-mcp-server/pull/143)
* Fix provider search prioritization to show official providers first in search results. See [#179](https://github.com/hashicorp/terraform-mcp-server/pull/179)

SECURITY

Expand Down
9 changes: 8 additions & 1 deletion pkg/client/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func createHTTPClient(insecureSkipVerify bool, logger *log.Logger) *http.Client
return retryClient.StandardClient()
}

func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
// SendRegistryCallFn is a package-level function variable so callers (and tests)
// can override registry call behavior for testing.
var SendRegistryCallFn = func(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
ver := "v1"
if len(callOptions) > 0 {
ver = callOptions[0] // API version will be the first optional arg to this function
Expand Down Expand Up @@ -100,6 +102,11 @@ func SendRegistryCall(client *http.Client, method string, uri string, logger *lo
return body, nil
}

// Backwards-compatible wrapper
func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
return SendRegistryCallFn(client, method, uri, logger, callOptions...)
}

func SendPaginatedRegistryCall(client *http.Client, uriPrefix string, logger *log.Logger) ([]ProviderDocData, error) {
var results []ProviderDocData
page := 1
Expand Down
195 changes: 160 additions & 35 deletions pkg/tools/registry/search_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net/http"
"path"
"sort"
"strings"

"github.com/hashicorp/terraform-mcp-server/pkg/client"
Expand All @@ -19,6 +20,26 @@ import (
"github.com/mark3labs/mcp-go/server"
)

// sendRegistryCall is a package-level variable so tests can override registry calls.
var sendRegistryCall = client.SendRegistryCall

// tierOrder defines sorting priority for provider tiers.
var tierOrder = map[string]int{"official": 0, "partner": 1, "community": 2}

type providerMatch struct {
Namespace string
Name string
Tier string
DocMatch []client.ProviderDoc
}

// sortMatchesByTier sorts the matches slice in-place by tier using tierOrder.
func sortMatchesByTier(matches []providerMatch) {
sort.SliceStable(matches, func(i, j int) bool {
return tierOrder[strings.ToLower(matches[i].Tier)] < tierOrder[strings.ToLower(matches[j].Tier)]
})
}

// ResolveProviderDocID creates a tool to get provider details from registry.
func ResolveProviderDocID(logger *log.Logger) server.ServerTool {
return server.ServerTool{
Expand All @@ -27,8 +48,8 @@ func ResolveProviderDocID(logger *log.Logger) server.ServerTool {
You MUST call this function before 'get_provider_details' to obtain a valid tfprovider-compatible provider_doc_id.
Use the most relevant single word as the search query for service_slug, if unsure about the service_slug, use the provider_name for its value.
When selecting the best match, consider the following:
- Title similarity to the query
- Category relevance
- Title similarity to the query
- Category relevance
Return the selected provider_doc_id and explain your choice.
If there are multiple good matches, mention this but proceed with the most relevant one.`),
mcp.WithTitleAnnotation("Identify the most relevant provider document ID for a Terraform service"),
Expand Down Expand Up @@ -92,56 +113,161 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques
if utils.IsV2ProviderDataType(providerDetail.ProviderDataType) {
content, err := providerDetailsV2(httpClient, providerDetail, logger)
if err != nil {
errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`,
providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide)
errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`, providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide)
return nil, utils.LogAndReturnError(logger, errMessage, err)
}

fullContent := fmt.Sprintf("# %s provider docs\n\n%s",
providerDetail.ProviderName, content)
fullContent := fmt.Sprintf("# %s provider docs\n\n%s", providerDetail.ProviderName, content)

return mcp.NewToolResultText(fullContent), nil
}

// For resources/data-sources, use the v1 API for better performance (single response)
uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
response, err := client.SendRegistryCall(httpClient, "GET", uri, logger)
// Delegate to extracted helper so it can be unit-tested.
result, err := searchProvidersDocs(httpClient, providerDetail, serviceSlug, defaultErrorGuide, logger)
if err != nil {
return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil)
return nil, err
}
return mcp.NewToolResultText(result), nil
}

var providerDocs client.ProviderDocs
if err := json.Unmarshal(response, &providerDocs); err != nil {
return nil, utils.LogAndReturnError(logger, "unmarshalling provider docs", err)
// searchProvidersDocs contains the core provider-search and prioritization logic.
// It returns the textual result (same content as the tool would return) for easier unit testing.
func searchProvidersDocs(httpClient *http.Client, providerDetail client.ProviderDetail, serviceSlug string, defaultErrorGuide string, logger *log.Logger) (string, error) {
// Enhanced: Search all providers matching the name and prioritize by tier
searchUri := "providers?filter[name]=" + providerDetail.ProviderName
searchResp, err := sendRegistryCall(httpClient, "GET", searchUri, logger, "v2")
if err != nil {
return "", utils.LogAndReturnError(logger, "error searching providers in registry", err)
}

var builder strings.Builder
builder.WriteString(fmt.Sprintf("Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion))
builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
var providerList client.ProviderList
if err := json.Unmarshal(searchResp, &providerList); err != nil {
return "", utils.LogAndReturnError(logger, "unmarshalling provider list", err)
}

// If the registry search didn't return any providers, fall back to fetching
// the single provider directly (preserves previous behavior for cases where
// provider namespace defaults to hashicorp and the search endpoint may not
// return results matching our filter).
logger.Infof("provider search returned %d providers for name '%s'", len(providerList.Data), providerDetail.ProviderName)
if len(providerList.Data) == 0 {
logger.Infof("falling back to single-provider fetch for %s/%s@%s", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
response, err := sendRegistryCall(httpClient, "GET", uri, logger)
logger.Debugf("provider docs fetch URI: %s", uri)
if err != nil {
return "", utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil)
}
var providerDocs client.ProviderDocs
if err := json.Unmarshal(response, &providerDocs); err != nil {
return "", utils.LogAndReturnError(logger, "unmarshalling provider docs", err)
}
logger.Infof("provider docs returned %d docs for %s/%s@%s", len(providerDocs.Docs), providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
var builder strings.Builder
builder.WriteString(fmt.Sprintf("Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion))
builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
contentAvailable := false
for _, doc := range providerDocs.Docs {
if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType {
cs, err := utils.ContainsSlug(doc.Slug, serviceSlug)
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug)
if (cs || cs_pn) && err == nil && err_pn == nil {
contentAvailable = true
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
if err != nil {
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
}
builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet))
}
}
}

if !contentAvailable {
errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug)
return "", utils.LogAndReturnError(logger, errMessage, err)
}
return builder.String(), nil
}

var matches []providerMatch

contentAvailable := false
for _, doc := range providerDocs.Docs {
if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType {
cs, err := utils.ContainsSlug(doc.Slug, serviceSlug)
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug)
if (cs || cs_pn) && err == nil && err_pn == nil {
contentAvailable = true
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
if err != nil {
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
for _, pdata := range providerList.Data {
namespace := pdata.Attributes.Namespace
name := pdata.Attributes.Name
tier := pdata.Attributes.Tier
logger.Debugf("search provider entry: namespace=%s name=%s tier=%s", namespace, name, tier)

// Get docs for this provider. Try the requested version first; if that
// fails (for example the version doesn't exist in this namespace), try
// to resolve the latest version for that namespace/name and retry.
uri := path.Join("providers", namespace, name, providerDetail.ProviderVersion)
response, err := sendRegistryCall(httpClient, "GET", uri, logger)
if err != nil {
// Attempt to fetch the latest provider version for this namespace/name
latestVer, verErr := client.GetLatestProviderVersion(httpClient, namespace, name, logger)
if verErr != nil {
logger.Debugf("skipping provider %s/%s: error fetching docs: %v (also failed to get latest version: %v)", namespace, name, err, verErr)
continue // skip providers we can't fetch
}
uri = path.Join("providers", namespace, name, latestVer)
response, err = sendRegistryCall(httpClient, "GET", uri, logger)
if err != nil {
logger.Debugf("skipping provider %s/%s: error fetching docs with latest version %s: %v", namespace, name, latestVer, err)
continue
}
}
var providerDocs client.ProviderDocs
if err := json.Unmarshal(response, &providerDocs); err != nil {
logger.Debugf("skipping provider %s/%s: error unmarshalling docs: %v", namespace, name, err)
continue
}
logger.Debugf("fetched %d docs for provider %s/%s", len(providerDocs.Docs), namespace, name)
var docMatches []client.ProviderDoc
for _, doc := range providerDocs.Docs {
logger.Tracef("considering doc slug=%s title=%s category=%s language=%s", doc.Slug, doc.Title, doc.Category, doc.Language)
if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType {
cs, err := utils.ContainsSlug(doc.Slug, serviceSlug)
cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", name, doc.Slug), serviceSlug)
if (cs || cs_pn) && err == nil && err_pn == nil {
logger.Debugf("matched doc %s for provider %s/%s (slug=%s)", doc.ID, namespace, name, doc.Slug)
docMatches = append(docMatches, doc)
}
builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet))
}
}
if len(docMatches) > 0 {
matches = append(matches, providerMatch{
Namespace: namespace,
Name: name,
Tier: tier,
DocMatch: docMatches,
})
}
}

// Check if the content data is not fulfilled
if !contentAvailable {
if len(matches) == 0 {
errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug)
return nil, utils.LogAndReturnError(logger, errMessage, err)
return "", utils.LogAndReturnError(logger, errMessage, err)
}

// Sort matches by tier
sortMatchesByTier(matches)

var builder strings.Builder
builder.WriteString("Available Documentation (prioritized by provider tier)\n\n")
builder.WriteString("Tier order: official > partner > community\n\n")
for _, match := range matches {
builder.WriteString(fmt.Sprintf("Provider: %s/%s (Tier: %s)\n", match.Namespace, match.Name, match.Tier))
for _, doc := range match.DocMatch {
descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
if err != nil {
logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
}
builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet))
}
builder.WriteString("\n")
}
return mcp.NewToolResultText(builder.String()), nil
return builder.String(), nil
}

func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) {
Expand Down Expand Up @@ -214,8 +340,7 @@ func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDe
return client.GetProviderOverviewDocs(httpClient, providerVersionID, logger)
}

uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl",
providerVersionID, category)
uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl", providerVersionID, category)

docs, err := client.SendPaginatedRegistryCall(httpClient, uriPrefix, logger)
if err != nil {
Expand Down Expand Up @@ -270,4 +395,4 @@ func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger
return desc[:300] + "...", nil
}
return desc, nil
}
}
59 changes: 59 additions & 0 deletions pkg/tools/registry/search_providers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package tools

import (
"net/http"
"strings"
"testing"

log "github.com/sirupsen/logrus"
"github.com/hashicorp/terraform-mcp-server/pkg/client"
)

func TestSearchProvidersPrioritizesOfficial(t *testing.T) {
// Backup original and restore
original := sendRegistryCall
defer func() { sendRegistryCall = original }()

// Fake responses
sendRegistryCall = func(httpClient *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) {
// provider list call
if strings.HasPrefix(uri, "providers?filter[name]=") {
// Return two providers: community then official (unordered)
// minimal JSON with attributes name, namespace, tier
return []byte(`{"data":[{"id":"1","attributes":{"name":"keycloak","namespace":"mrparkers","tier":"community"}},{"id":"2","attributes":{"name":"keycloak","namespace":"keycloak-official","tier":"official"}}]}`), nil
}

// provider docs calls: uri like providers/{namespace}/{name}/{version}
if strings.HasPrefix(uri, "providers/mrparkers/") {
return []byte(`{"docs":[{"id":"doc1","title":"Keycloak (community)","path":"","slug":"keycloak","category":"resources","language":"hcl"}]}`), nil
}
if strings.HasPrefix(uri, "providers/keycloak-official/") {
return []byte(`{"docs":[{"id":"doc2","title":"Keycloak (official)","path":"","slug":"keycloak","category":"resources","language":"hcl"}]}`), nil
}

return nil, nil
}

logger := log.New()
providerDetail := client.ProviderDetail{
ProviderName: "keycloak",
ProviderNamespace: "",
ProviderVersion: "latest",
ProviderDataType: "resources",
}

result, err := searchProvidersDocs(http.DefaultClient, providerDetail, "keycloak", "default guide", logger)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// official provider should appear before community provider in the output
officialIdx := strings.Index(result, "Provider: keycloak-official/keycloak (Tier: official)")
communityIdx := strings.Index(result, "Provider: mrparkers/keycloak (Tier: community)")
if officialIdx == -1 || communityIdx == -1 {
t.Fatalf("expected both providers in result, got: %s", result)
}
if officialIdx > communityIdx {
t.Fatalf("official provider found after community provider; result: %s", result)
}
}
19 changes: 19 additions & 0 deletions pkg/tools/registry/sort_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package tools

import (
"testing"
)

func TestSortMatchesByTier(t *testing.T) {
matches := []providerMatch{
{Namespace: "a", Name: "one", Tier: "community"},
{Namespace: "b", Name: "two", Tier: "partner"},
{Namespace: "c", Name: "three", Tier: "official"},
}

sortMatchesByTier(matches)

if matches[0].Tier != "official" || matches[1].Tier != "partner" || matches[2].Tier != "community" {
t.Fatalf("unexpected tier order: %v", []string{matches[0].Tier, matches[1].Tier, matches[2].Tier})
}
}
Loading