diff --git a/pkg/generate/code/resource_reference.go b/pkg/generate/code/resource_reference.go index 10ab31f7..beec4012 100644 --- a/pkg/generate/code/resource_reference.go +++ b/pkg/generate/code/resource_reference.go @@ -35,7 +35,9 @@ func ReferenceFieldsValidation( indentLevel int, ) string { out := "" - for _, field := range crd.Fields { + // Sorted fieldnames are used for consistent code-generation + for _, fieldName := range crd.SortedFieldNames() { + field := crd.Fields[fieldName] if field.HasReference() { indent := strings.Repeat("\t", indentLevel) // Validation to make sure both target field and reference are @@ -75,7 +77,9 @@ func ReferenceFieldsPresent( sourceVarName string, ) string { out := "false" - for _, field := range crd.Fields { + // Sorted fieldnames are used for consistent code-generation + for _, fieldName := range crd.SortedFieldNames() { + field := crd.Fields[fieldName] if field.IsReference() { out += fmt.Sprintf(" || %s.Spec.%s != nil", sourceVarName, field.Names.Camel) diff --git a/pkg/generate/code/resource_reference_test.go b/pkg/generate/code/resource_reference_test.go index b6b26431..bc82f67d 100644 --- a/pkg/generate/code/resource_reference_test.go +++ b/pkg/generate/code/resource_reference_test.go @@ -76,6 +76,12 @@ func Test_ReferenceFieldsValidation_SliceOfReferences(t *testing.T) { ` if ko.Spec.SecurityGroupRefs != nil && ko.Spec.SecurityGroupIDs != nil { return ackerr.ResourceReferenceAndIDNotSupportedFor("SecurityGroupIDs", "SecurityGroupRefs") } + if ko.Spec.SubnetRefs != nil && ko.Spec.SubnetIDs != nil { + return ackerr.ResourceReferenceAndIDNotSupportedFor("SubnetIDs", "SubnetRefs") + } + if ko.Spec.SubnetRefs == nil && ko.Spec.SubnetIDs == nil { + return ackerr.ResourceReferenceOrIDRequiredFor("SubnetIDs", "SubnetRefs") + } ` assert.Equal(expected, code.ReferenceFieldsValidation(crd, "ko", 1)) } @@ -123,6 +129,6 @@ func Test_ReferenceFieldsPresent_SliceOfReferences(t *testing.T) { // just to test code generation for slices of reference crd := testutil.GetCRDByName(t, g, "VpcLink") require.NotNil(crd) - expected := "false || ko.Spec.SecurityGroupRefs != nil" + expected := "false || ko.Spec.SecurityGroupRefs != nil || ko.Spec.SubnetRefs != nil" assert.Equal(expected, code.ReferenceFieldsPresent(crd, "ko")) } diff --git a/pkg/model/crd.go b/pkg/model/crd.go index c5b78c5b..1405e971 100644 --- a/pkg/model/crd.go +++ b/pkg/model/crd.go @@ -794,6 +794,39 @@ func (r *CRD) HasReferenceFields() bool { return false } +// ReferencedServiceNames returns the set of service names for ACK controllers +// whose resources are referenced inside the CRD. The service name is +// the go package name for the AWS service inside aws-sdk-go. +// +// If a CRD has no reference fields, nil is returned(zero vale of slice) +func (r *CRD) ReferencedServiceNames() (serviceNames []string) { + // We are using Map to implement a Set of service names + serviceNamesMap := make(map[string]struct{}) + existsValue := struct{}{} + + for _, field := range r.Fields { + if serviceName := field.ReferencedServiceName(); serviceName != "" { + serviceNamesMap[serviceName] = existsValue + } + } + + for serviceName, _ := range serviceNamesMap { + serviceNames = append(serviceNames, serviceName) + } + return serviceNames +} + +// SortedFieldNames returns the fieldNames of the CRD in a sorted +// order. +func (r *CRD) SortedFieldNames() []string { + fieldNames := make([]string, 0, len(r.Fields)) + for fieldName := range r.Fields { + fieldNames = append(fieldNames, fieldName) + } + sort.Strings(fieldNames) + return fieldNames +} + // NewCRD returns a pointer to a new `ackmodel.CRD` struct that describes a // single top-level resource in an AWS service API func NewCRD( diff --git a/pkg/model/model_apigwv2_test.go b/pkg/model/model_apigwv2_test.go index 21969cb8..5d468424 100644 --- a/pkg/model/model_apigwv2_test.go +++ b/pkg/model/model_apigwv2_test.go @@ -65,6 +65,8 @@ func TestAPIGatewayV2_Api(t *testing.T) { // The required property should get overriden for Name and ProtocolType fields. assert.False(crd.SpecFields["Name"].IsRequired()) assert.False(crd.SpecFields["ProtocolType"].IsRequired()) + + assert.Nil(crd.ReferencedServiceNames()) } func TestAPIGatewayV2_Route(t *testing.T) { @@ -174,6 +176,8 @@ func TestAPIGatewayV2_Route(t *testing.T) { "RouteID", } assert.Equal(expStatusFieldCamel, attrCamelNames(statusFields)) + + assert.Nil(crd.ReferencedServiceNames()) } func TestAPIGatewayV2_WithReference(t *testing.T) { @@ -188,26 +192,37 @@ func TestAPIGatewayV2_WithReference(t *testing.T) { require.Nil(err) // Single reference - crd := getCRDByName("Integration", crds) - require.NotNil(crd) + integrationCrd := getCRDByName("Integration", crds) + require.NotNil(integrationCrd) - assert.Equal("Integration", crd.Names.Camel) - assert.Equal("integration", crd.Names.CamelLower) - assert.Equal("integration", crd.Names.Snake) + assert.Equal("Integration", integrationCrd.Names.Camel) + assert.Equal("integration", integrationCrd.Names.CamelLower) + assert.Equal("integration", integrationCrd.Names.Snake) - assert.NotNil(crd.SpecFields["ApiId"]) - assert.NotNil(crd.SpecFields["ApiRef"]) - assert.Equal("*ackv1alpha1.AWSResourceReferenceWrapper", crd.SpecFields["ApiRef"].GoType) + assert.NotNil(integrationCrd.SpecFields["ApiId"]) + assert.NotNil(integrationCrd.SpecFields["ApiRef"]) + assert.Equal("*ackv1alpha1.AWSResourceReferenceWrapper", integrationCrd.SpecFields["ApiRef"].GoType) - // List of References - crd = getCRDByName("VpcLink", crds) - require.NotNil(crd) - - assert.Equal("VPCLink", crd.Names.Camel) - assert.Equal("vpcLink", crd.Names.CamelLower) - assert.Equal("vpc_link", crd.Names.Snake) + referencedServiceNames := integrationCrd.ReferencedServiceNames() + assert.NotNil(referencedServiceNames) + assert.Contains(referencedServiceNames, "apigatewayv2") + assert.Equal(1, len(referencedServiceNames)) - assert.NotNil(crd.SpecFields["SecurityGroupIds"]) - assert.NotNil(crd.SpecFields["SecurityGroupRefs"]) - assert.Equal("[]*ackv1alpha1.AWSResourceReferenceWrapper", crd.SpecFields["SecurityGroupRefs"].GoType) + // List of References + vpcLinkCrd := getCRDByName("VpcLink", crds) + require.NotNil(vpcLinkCrd) + + assert.Equal("VPCLink", vpcLinkCrd.Names.Camel) + assert.Equal("vpcLink", vpcLinkCrd.Names.CamelLower) + assert.Equal("vpc_link", vpcLinkCrd.Names.Snake) + + assert.NotNil(vpcLinkCrd.SpecFields["SecurityGroupIds"]) + assert.NotNil(vpcLinkCrd.SpecFields["SecurityGroupRefs"]) + assert.Equal("[]*ackv1alpha1.AWSResourceReferenceWrapper", vpcLinkCrd.SpecFields["SecurityGroupRefs"].GoType) + + referencedServiceNames = vpcLinkCrd.ReferencedServiceNames() + assert.NotNil(referencedServiceNames) + assert.Contains(referencedServiceNames, "ec2") + assert.Contains(referencedServiceNames, "ec2-modified") + assert.Equal(2, len(referencedServiceNames)) } diff --git a/pkg/testdata/models/apis/apigatewayv2/0000-00-00/generator-with-reference.yaml b/pkg/testdata/models/apis/apigatewayv2/0000-00-00/generator-with-reference.yaml index f3a78631..697e3782 100644 --- a/pkg/testdata/models/apis/apigatewayv2/0000-00-00/generator-with-reference.yaml +++ b/pkg/testdata/models/apis/apigatewayv2/0000-00-00/generator-with-reference.yaml @@ -9,8 +9,14 @@ resources: fields: SecurityGroupIds: references: - resource: API - path: Status.APIID + resource: SecurityGroup + path: Status.ID + service_name: ec2 + SubnetIds: + references: + resource: Subnet + path: Status.SubnetID + service_name: ec2-modified #This is a dummy service name to validate multiple service references ignore: resource_names: - ApiMapping diff --git a/templates/pkg/resource/references.go.tpl b/templates/pkg/resource/references.go.tpl index 661fcdaf..bd2fb20c 100644 --- a/templates/pkg/resource/references.go.tpl +++ b/templates/pkg/resource/references.go.tpl @@ -21,9 +21,9 @@ import ( {{ $servicePackageName := .ServicePackageName -}} {{ $apiVersion := .APIVersion -}} {{ if .CRD.HasReferenceFields -}} -{{ range $fieldName, $field := .CRD.Fields -}} -{{ if and $field.HasReference (not (eq $field.ReferencedServiceName $servicePackageName)) -}} - {{ $field.ReferencedServiceName }}apitypes "github.com/aws-controllers-k8s/{{ $field.ReferencedServiceName }}-controller/apis/{{ $apiVersion }}" +{{ range $referencedServiceName := .CRD.ReferencedServiceNames -}} +{{ if not (eq $referencedServiceName $servicePackageName) -}} + {{ $referencedServiceName }}apitypes "github.com/aws-controllers-k8s/{{ $referencedServiceName }}-controller/apis/{{ $apiVersion }}" {{- end }} {{- end }} {{- end }}