diff --git a/backend/operations_scanner.go b/backend/operations_scanner.go index 00e50076b..487b90599 100644 --- a/backend/operations_scanner.go +++ b/backend/operations_scanner.go @@ -165,7 +165,7 @@ func (s *OperationsScanner) pollClusterOperation(ctx context.Context, logger *sl if err != nil { var ocmError *ocmerrors.Error if errors.As(err, &ocmError) && ocmError.Status() == http.StatusNotFound && doc.Request == database.OperationRequestDelete { - err = s.withSubscriptionLock(ctx, logger, doc.OperationID.SubscriptionID, func(ctx context.Context) error { + err = s.withSubscriptionLock(ctx, logger, doc.ExternalID.SubscriptionID, func(ctx context.Context) error { return s.deleteOperationCompleted(ctx, logger, doc) }) if err == nil { @@ -180,7 +180,7 @@ func (s *OperationsScanner) pollClusterOperation(ctx context.Context, logger *sl logger.Warn(err.Error()) err = nil } else { - err = s.withSubscriptionLock(ctx, logger, doc.OperationID.SubscriptionID, func(ctx context.Context) error { + err = s.withSubscriptionLock(ctx, logger, doc.ExternalID.SubscriptionID, func(ctx context.Context) error { return s.updateOperationStatus(ctx, logger, doc, opStatus, opError) }) } diff --git a/backend/operations_scanner_test.go b/backend/operations_scanner_test.go index a84c5118f..0606bb657 100644 --- a/backend/operations_scanner_test.go +++ b/backend/operations_scanner_test.go @@ -15,6 +15,7 @@ import ( "github.com/Azure/ARO-HCP/internal/api/arm" "github.com/Azure/ARO-HCP/internal/database" + "github.com/Azure/ARO-HCP/internal/ocm" ) func TestDeleteOperationCompleted(t *testing.T) { @@ -48,6 +49,12 @@ func TestDeleteOperationCompleted(t *testing.T) { }, } + // Placeholder InternalID for NewOperationDocument + internalID, err := ocm.NewInternalID("/api/clusters_mgmt/v1/clusters/placeholder") + if err != nil { + t.Fatal(err) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var request *http.Request @@ -71,8 +78,7 @@ func TestDeleteOperationCompleted(t *testing.T) { notificationClient: server.Client(), } - operationDoc := database.NewOperationDocument(database.OperationRequestDelete) - operationDoc.ExternalID = resourceID + operationDoc := database.NewOperationDocument(database.OperationRequestDelete, resourceID, internalID) operationDoc.NotificationURI = server.URL operationDoc.Status = tt.operationStatus @@ -190,6 +196,12 @@ func TestUpdateOperationStatus(t *testing.T) { }, } + // Placeholder InternalID for NewOperationDocument + internalID, err := ocm.NewInternalID("/api/clusters_mgmt/v1/clusters/placeholder") + if err != nil { + t.Fatal(err) + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var request *http.Request @@ -213,8 +225,7 @@ func TestUpdateOperationStatus(t *testing.T) { notificationClient: server.Client(), } - operationDoc := database.NewOperationDocument(database.OperationRequestCreate) - operationDoc.ExternalID = resourceID + operationDoc := database.NewOperationDocument(database.OperationRequestCreate, resourceID, internalID) operationDoc.NotificationURI = server.URL operationDoc.Status = tt.currentOperationStatus diff --git a/frontend/pkg/frontend/frontend.go b/frontend/pkg/frontend/frontend.go index ea7becf46..be8e3450a 100644 --- a/frontend/pkg/frontend/frontend.go +++ b/frontend/pkg/frontend/frontend.go @@ -192,17 +192,29 @@ func (f *Frontend) ArmResourceList(writer http.ResponseWriter, request *http.Req return } - documentList, continuationToken, err := f.dbClient.ListResourceDocs(ctx, prefix, &api.ClusterResourceType, pageSizeHint, continuationToken) - if err != nil { - f.logger.Error(err.Error()) - arm.WriteInternalServerError(writer) - return - } + iterator := f.dbClient.ListResourceDocs(ctx, prefix, pageSizeHint, continuationToken) // Build a map of cluster documents by Cluster Service cluster ID. documentMap := make(map[string]*database.ResourceDocument) - for _, doc := range documentList { - documentMap[doc.InternalID.ID()] = doc + for item := range iterator.Items(ctx) { + var doc database.ResourceDocument + + err = json.Unmarshal(item, &doc) + if err != nil { + f.logger.Error(err.Error()) + arm.WriteInternalServerError(writer) + return + } + + if strings.EqualFold(doc.Key.ResourceType.String(), api.ClusterResourceType.String()) { + documentMap[doc.InternalID.ID()] = &doc + } + } + + err = iterator.GetError() + if err != nil { + f.logger.Error(err.Error()) + arm.WriteInternalServerError(writer) } // Build a Cluster Service query that looks for @@ -240,13 +252,11 @@ func (f *Frontend) ArmResourceList(writer http.ResponseWriter, request *http.Req } } - if continuationToken != nil { - err = pagedResponse.SetNextLink(request.Referer(), *continuationToken) - if err != nil { - f.logger.Error(err.Error()) - arm.WriteInternalServerError(writer) - return - } + err = pagedResponse.SetNextLink(request.Referer(), iterator.GetContinuationToken()) + if err != nil { + f.logger.Error(err.Error()) + arm.WriteInternalServerError(writer) + return } _, err = arm.WriteJSONResponse(writer, http.StatusOK, pagedResponse) @@ -464,9 +474,18 @@ func (f *Frontend) ArmResourceCreateOrUpdate(writer http.ResponseWriter, request } } - operationDoc, err := f.StartOperation(writer, request, doc, operationRequest) + operationDoc := database.NewOperationDocument(operationRequest, doc.Key, doc.InternalID) + + err = f.dbClient.CreateOperationDoc(ctx, operationDoc) if err != nil { - f.logger.Error(fmt.Sprintf("failed to write operation document: %v", err)) + f.logger.Error(err.Error()) + arm.WriteInternalServerError(writer) + return + } + + err = f.ExposeOperation(writer, request, operationDoc.ID) + if err != nil { + f.logger.Error(err.Error()) arm.WriteInternalServerError(writer) return } @@ -533,6 +552,8 @@ func (f *Frontend) ArmResourceCreateOrUpdate(writer http.ResponseWriter, request // * 202 if an asynchronous delete is initiated // * 204 if a well-formed request attempts to delete a nonexistent resource func (f *Frontend) ArmResourceDelete(writer http.ResponseWriter, request *http.Request) { + const operationRequest = database.OperationRequestDelete + ctx := request.Context() versionedInterface, err := VersionFromContext(ctx) @@ -551,70 +572,45 @@ func (f *Frontend) ArmResourceDelete(writer http.ResponseWriter, request *http.R f.logger.Info(fmt.Sprintf("%s: ArmResourceDelete", versionedInterface)) - resourceDoc, cloudError := f.DeleteResource(ctx, resourceID) - if cloudError != nil { + resourceDoc, err := f.dbClient.GetResourceDoc(ctx, resourceID) + if err != nil { // For resource not found errors on deletion, ARM requires // us to simply return 204 No Content and no response body. - if cloudError.StatusCode == http.StatusNotFound { + if errors.Is(err, database.ErrNotFound) { writer.WriteHeader(http.StatusNoContent) } else { - arm.WriteCloudError(writer, cloudError) + f.logger.Error(err.Error()) + arm.WriteInternalServerError(writer) } return } - operationRequest := database.OperationRequestDelete - // CheckForProvisioningStateConflict does not log conflict errors // but does log unexpected errors like database failures. - cloudError = f.CheckForProvisioningStateConflict(ctx, operationRequest, resourceDoc) + cloudError := f.CheckForProvisioningStateConflict(ctx, operationRequest, resourceDoc) if cloudError != nil { arm.WriteCloudError(writer, cloudError) return } - err = f.clusterServiceClient.DeleteCSCluster(ctx, resourceDoc.InternalID) - if err != nil { - f.logger.Error(fmt.Sprintf("failed to delete cluster %s: %v", resourceID, err)) - arm.WriteInternalServerError(writer) - return - } - - // Deletion is underway; mark any active operation as canceled. - if resourceDoc.ActiveOperationID != "" { - updated, err := f.dbClient.UpdateOperationDoc(ctx, resourceDoc.ActiveOperationID, func(updateDoc *database.OperationDocument) bool { - return updateDoc.UpdateStatus(arm.ProvisioningStateCanceled, nil) - }) - if err != nil { - f.logger.Error(err.Error()) - arm.WriteInternalServerError(writer) - return - } - if updated { - f.logger.Info(fmt.Sprintf("canceled operation '%s'", resourceDoc.ActiveOperationID)) + operationID, cloudError := f.DeleteResource(ctx, resourceDoc) + if cloudError != nil { + // For resource not found errors on deletion, ARM requires + // us to simply return 204 No Content and no response body. + if cloudError.StatusCode == http.StatusNotFound { + writer.WriteHeader(http.StatusNoContent) + } else { + arm.WriteCloudError(writer, cloudError) } - } - - operationDoc, err := f.StartOperation(writer, request, resourceDoc, operationRequest) - if err != nil { - f.logger.Error(fmt.Sprintf("failed to write operation document: %v", err)) - arm.WriteInternalServerError(writer) return } - updated, err := f.dbClient.UpdateResourceDoc(ctx, resourceID, func(updateDoc *database.ResourceDocument) bool { - updateDoc.ActiveOperationID = operationDoc.ID - updateDoc.ProvisioningState = operationDoc.Status - return true - }) + err = f.ExposeOperation(writer, request, operationID) if err != nil { f.logger.Error(err.Error()) arm.WriteInternalServerError(writer) return } - if updated { - f.logger.Info(fmt.Sprintf("document updated for %s", resourceID)) - } writer.WriteHeader(http.StatusAccepted) } @@ -730,7 +726,16 @@ func (f *Frontend) ArmSubscriptionPut(writer http.ResponseWriter, request *http. "state": string(subscription.State), }) - _, err = arm.WriteJSONResponse(writer, http.StatusCreated, subscription) + // Clean up resources if subscription is deleted. + if subscription.State == arm.SubscriptionStateDeleted { + cloudError := f.DeleteAllResources(ctx, subscriptionID) + if cloudError != nil { + arm.WriteCloudError(writer, cloudError) + return + } + } + + _, err = arm.WriteJSONResponse(writer, http.StatusOK, subscription) if err != nil { f.logger.Error(err.Error()) } diff --git a/frontend/pkg/frontend/frontend_test.go b/frontend/pkg/frontend/frontend_test.go index ad991a25b..8efbe0bf7 100644 --- a/frontend/pkg/frontend/frontend_test.go +++ b/frontend/pkg/frontend/frontend_test.go @@ -139,7 +139,7 @@ func TestSubscriptionsPUT(t *testing.T) { Properties: nil, }, subDoc: nil, - expectedStatusCode: http.StatusCreated, + expectedStatusCode: http.StatusOK, }, { name: "PUT Subscription - Doc Exists", @@ -159,7 +159,7 @@ func TestSubscriptionsPUT(t *testing.T) { Properties: nil, }, }, - expectedStatusCode: http.StatusCreated, + expectedStatusCode: http.StatusOK, }, { name: "PUT Subscription - Invalid Subscription", diff --git a/frontend/pkg/frontend/helpers.go b/frontend/pkg/frontend/helpers.go index 11dd7fadf..dfbe73635 100644 --- a/frontend/pkg/frontend/helpers.go +++ b/frontend/pkg/frontend/helpers.go @@ -5,6 +5,7 @@ package frontend import ( "context" + "encoding/json" "errors" "fmt" "net/http" @@ -69,39 +70,157 @@ func (f *Frontend) CheckForProvisioningStateConflict(ctx context.Context, operat return nil } -func (f *Frontend) DeleteResource(ctx context.Context, resourceID *arm.ResourceID) (*database.ResourceDocument, *arm.CloudError) { - doc, err := f.dbClient.GetResourceDoc(ctx, resourceID) +func (f *Frontend) DeleteAllResources(ctx context.Context, subscriptionID string) *arm.CloudError { + prefix, err := arm.ParseResourceID("/subscriptions/" + subscriptionID) if err != nil { - if errors.Is(err, database.ErrNotFound) { - return nil, arm.NewResourceNotFoundError(resourceID) - } else { + f.logger.Error(err.Error()) + return arm.NewInternalServerError() + } + + dbIterator := f.dbClient.ListResourceDocs(ctx, prefix, -1, nil) + + // Start a deletion operation for all clusters under the subscription. + // Cluster Service will delete all node pools belonging to these clusters + // so we don't need to explicitly delete node pools here. + for item := range dbIterator.Items(ctx) { + var resourceDoc *database.ResourceDocument + + err = json.Unmarshal(item, &resourceDoc) + if err != nil { f.logger.Error(err.Error()) - return nil, arm.NewInternalServerError() + return arm.NewInternalServerError() + } + + if !strings.EqualFold(resourceDoc.Key.ResourceType.String(), api.ClusterResourceType.String()) { + continue + } + + // Allow this method to be idempotent. + if resourceDoc.ProvisioningState != arm.ProvisioningStateDeleting { + _, cloudError := f.DeleteResource(ctx, resourceDoc) + if cloudError != nil { + return cloudError + } } } - switch doc.InternalID.Kind() { + return nil +} + +func (f *Frontend) DeleteResource(ctx context.Context, resourceDoc *database.ResourceDocument) (string, *arm.CloudError) { + const operationRequest = database.OperationRequestDelete + var err error + + switch resourceDoc.InternalID.Kind() { case cmv1.ClusterKind: - err = f.clusterServiceClient.DeleteCSCluster(ctx, doc.InternalID) + err = f.clusterServiceClient.DeleteCSCluster(ctx, resourceDoc.InternalID) case cmv1.NodePoolKind: - err = f.clusterServiceClient.DeleteCSNodePool(ctx, doc.InternalID) + err = f.clusterServiceClient.DeleteCSNodePool(ctx, resourceDoc.InternalID) default: - f.logger.Error(fmt.Sprintf("unsupported Cluster Service path: %s", doc.InternalID)) - return nil, arm.NewInternalServerError() + f.logger.Error(fmt.Sprintf("unsupported Cluster Service path: %s", resourceDoc.InternalID)) + return "", arm.NewInternalServerError() } if err != nil { var ocmError *ocmerrors.Error if errors.As(err, &ocmError) && ocmError.Status() == http.StatusNotFound { - return nil, arm.NewResourceNotFoundError(resourceID) + return "", arm.NewResourceNotFoundError(resourceDoc.Key) } f.logger.Error(err.Error()) - return nil, arm.NewInternalServerError() + return "", arm.NewInternalServerError() + } + + // Cluster Service will take care of canceling any ongoing operations + // on the resource or child resources, but we need to do some database + // bookkeeping to reflect that. + + // FIXME This would be a good place to use Cosmos DB's transactional batch + // operations to ensure all these write operations succeed together + // or roll back. We would need two parallel transactions: one for + // the Operations container and another for the Resources container. + // But we're stymied currently by the DBClient interface, and I have + // no desire to implement this in the in-memory cache. DBClient has + // served us well up to this point, but I think it's time to bid it + // farewell and switch to gomock in unit tests. + + err = f.CancelActiveOperation(ctx, resourceDoc) + if err != nil { + f.logger.Error(err.Error()) + return "", arm.NewInternalServerError() + } + + operationDoc := database.NewOperationDocument(operationRequest, resourceDoc.Key, resourceDoc.InternalID) + + err = f.dbClient.CreateOperationDoc(ctx, operationDoc) + if err != nil { + f.logger.Error(err.Error()) + return "", arm.NewInternalServerError() + } + + _, err = f.dbClient.UpdateResourceDoc(ctx, resourceDoc.Key, func(updateDoc *database.ResourceDocument) bool { + updateDoc.ActiveOperationID = operationDoc.ID + updateDoc.ProvisioningState = operationDoc.Status + return true + }) + if err != nil { + f.logger.Error(err.Error()) + return "", arm.NewInternalServerError() + } + + iterator := f.dbClient.ListResourceDocs(ctx, resourceDoc.Key, -1, nil) + + for item := range iterator.Items(ctx) { + // Anonymous function avoids repetitive error handling. + err = func() error { + var child database.ResourceDocument + + err = json.Unmarshal(item, &child) + if err != nil { + return err + } + + err = f.CancelActiveOperation(ctx, &child) + if err != nil { + return err + } + + // This operation is not accessible through any REST endpoint. + // Its purpose is to cause the backend to delete the resource + // document once resource deletion completes. + + childOperationDoc := database.NewOperationDocument(operationRequest, child.Key, child.InternalID) + + err = f.dbClient.CreateOperationDoc(ctx, childOperationDoc) + if err != nil { + return err + } + + _, err = f.dbClient.UpdateResourceDoc(ctx, child.Key, func(updateDoc *database.ResourceDocument) bool { + updateDoc.ActiveOperationID = childOperationDoc.ID + updateDoc.ProvisioningState = childOperationDoc.Status + return true + }) + if err != nil { + return err + } + + return nil + }() + if err != nil { + f.logger.Error(err.Error()) + return "", arm.NewInternalServerError() + } + } + + err = iterator.GetError() + if err != nil { + f.logger.Error(err.Error()) + return "", arm.NewInternalServerError() } - return doc, nil + return operationDoc.ID, nil } func (f *Frontend) MarshalResource(ctx context.Context, resourceID *arm.ResourceID, versionedInterface api.Version) ([]byte, *arm.CloudError) { diff --git a/frontend/pkg/frontend/node_pool.go b/frontend/pkg/frontend/node_pool.go index 615e72e6b..fe08e6c39 100644 --- a/frontend/pkg/frontend/node_pool.go +++ b/frontend/pkg/frontend/node_pool.go @@ -193,9 +193,18 @@ func (f *Frontend) CreateOrUpdateNodePool(writer http.ResponseWriter, request *h } } - operationDoc, err := f.StartOperation(writer, request, doc, operationRequest) + operationDoc := database.NewOperationDocument(operationRequest, doc.Key, doc.InternalID) + + err = f.dbClient.CreateOperationDoc(ctx, operationDoc) + if err != nil { + f.logger.Error(err.Error()) + arm.WriteInternalServerError(writer) + return + } + + err = f.ExposeOperation(writer, request, operationDoc.ID) if err != nil { - f.logger.Error(fmt.Sprintf("failed to write operation document: %v", err)) + f.logger.Error(err.Error()) arm.WriteInternalServerError(writer) return } diff --git a/frontend/pkg/frontend/operations.go b/frontend/pkg/frontend/operations.go index 4678a5f3f..f31078167 100644 --- a/frontend/pkg/frontend/operations.go +++ b/frontend/pkg/frontend/operations.go @@ -4,6 +4,8 @@ package frontend // Licensed under the Apache License 2.0. import ( + "context" + "errors" "fmt" "net/http" "net/url" @@ -77,47 +79,70 @@ func (f *Frontend) AddLocationHeader(writer http.ResponseWriter, request *http.R writer.Header().Set("Location", u.String()) } -func (f *Frontend) StartOperation(writer http.ResponseWriter, request *http.Request, resourceDoc *database.ResourceDocument, operationRequest database.OperationRequest) (*database.OperationDocument, error) { +// ExposeOperation fully initiates a new asynchronous operation by enriching +// the operation database item and adding the necessary response headers. +func (f *Frontend) ExposeOperation(writer http.ResponseWriter, request *http.Request, operationID string) error { ctx := request.Context() - operationDoc := database.NewOperationDocument(operationRequest) - - operationID, err := arm.ParseResourceID(path.Join("/", - "subscriptions", resourceDoc.Key.SubscriptionID, - "providers", api.ProviderNamespace, - "locations", f.location, - api.OperationStatusResourceTypeName, operationDoc.ID)) - if err != nil { - return nil, err - } - - operationDoc.TenantID = request.Header.Get(arm.HeaderNameHomeTenantID) - operationDoc.ClientID = request.Header.Get(arm.HeaderNameClientObjectID) - operationDoc.ExternalID = resourceDoc.Key - operationDoc.InternalID = resourceDoc.InternalID - operationDoc.OperationID = operationID - operationDoc.NotificationURI = request.Header.Get(arm.HeaderNameAsyncNotificationURI) - - err = f.dbClient.CreateOperationDoc(ctx, operationDoc) + _, err := f.dbClient.UpdateOperationDoc(ctx, operationID, func(updateDoc *database.OperationDocument) bool { + // There is no way to propagate a parse error here but it should + // never fail since we are building a trusted resource ID string. + operationID, err := arm.ParseResourceID(path.Join("/", + "subscriptions", updateDoc.ExternalID.SubscriptionID, + "providers", api.ProviderNamespace, + "locations", f.location, + api.OperationStatusResourceTypeName, operationID)) + if err != nil { + f.logger.Error(err.Error()) + return false + } + + updateDoc.TenantID = request.Header.Get(arm.HeaderNameHomeTenantID) + updateDoc.ClientID = request.Header.Get(arm.HeaderNameClientObjectID) + updateDoc.OperationID = operationID + updateDoc.NotificationURI = request.Header.Get(arm.HeaderNameAsyncNotificationURI) + + // If ARM passed a notification URI, acknowledge it. + if updateDoc.NotificationURI != "" { + writer.Header().Set(arm.HeaderNameAsyncNotification, "Enabled") + } + + // Add callback header(s) based on the request method. + switch request.Method { + case http.MethodDelete, http.MethodPatch: + f.AddLocationHeader(writer, request, updateDoc) + fallthrough + case http.MethodPut: + f.AddAsyncOperationHeader(writer, request, updateDoc) + } + + return true + }) if err != nil { - return nil, err + // Delete any response headers that may have been added. + writer.Header().Del(arm.HeaderNameAsyncNotification) + writer.Header().Del(arm.HeaderNameAsyncOperation) + writer.Header().Del("Location") } - // If ARM passed a notification URI, acknowledge it. - if operationDoc.NotificationURI != "" { - writer.Header().Set(arm.HeaderNameAsyncNotification, "Enabled") - } + return err +} - // Add callback header(s) based on the request method. - switch request.Method { - case http.MethodDelete, http.MethodPatch: - f.AddLocationHeader(writer, request, operationDoc) - fallthrough - case http.MethodPut: - f.AddAsyncOperationHeader(writer, request, operationDoc) +// CancelActiveOperation marks the status of any active operation on the resource as canceled. +func (f *Frontend) CancelActiveOperation(ctx context.Context, resourceDoc *database.ResourceDocument) error { + if resourceDoc.ActiveOperationID != "" { + updated, err := f.dbClient.UpdateOperationDoc(ctx, resourceDoc.ActiveOperationID, func(updateDoc *database.OperationDocument) bool { + return updateDoc.UpdateStatus(arm.ProvisioningStateCanceled, nil) + }) + // Disregard "not found" errors; a missing operation is effectively canceled. + if err != nil && !errors.Is(err, database.ErrNotFound) { + return err + } + if updated { + f.logger.Info(fmt.Sprintf("Canceled operation '%s'", resourceDoc.ActiveOperationID)) + } } - - return operationDoc, nil + return nil } // OperationIsVisible returns true if the request is being called from the same @@ -129,18 +154,23 @@ func (f *Frontend) OperationIsVisible(request *http.Request, doc *database.Opera clientID := request.Header.Get(arm.HeaderNameClientObjectID) subscriptionID := request.PathValue(PathSegmentSubscriptionID) - if doc.TenantID != "" && !strings.EqualFold(tenantID, doc.TenantID) { - f.logger.Info(fmt.Sprintf("Unauthorized tenant '%s' in status request for operation '%s'", tenantID, doc.ID)) - visible = false - } - - if doc.ClientID != "" && !strings.EqualFold(clientID, doc.ClientID) { - f.logger.Info(fmt.Sprintf("Unauthorized client '%s' in status request for operation '%s'", clientID, doc.ID)) - visible = false - } - - if !strings.EqualFold(subscriptionID, doc.OperationID.SubscriptionID) { - f.logger.Info(fmt.Sprintf("Unauthorized subscription '%s' in status request for operation '%s'", subscriptionID, doc.ID)) + if doc.OperationID != nil { + if doc.TenantID != "" && !strings.EqualFold(tenantID, doc.TenantID) { + f.logger.Info(fmt.Sprintf("Unauthorized tenant '%s' in status request for operation '%s'", tenantID, doc.ID)) + visible = false + } + + if doc.ClientID != "" && !strings.EqualFold(clientID, doc.ClientID) { + f.logger.Info(fmt.Sprintf("Unauthorized client '%s' in status request for operation '%s'", clientID, doc.ID)) + visible = false + } + + if !strings.EqualFold(subscriptionID, doc.OperationID.SubscriptionID) { + f.logger.Info(fmt.Sprintf("Unauthorized subscription '%s' in status request for operation '%s'", subscriptionID, doc.ID)) + visible = false + } + } else { + f.logger.Info(fmt.Sprintf("Status request for implicit operation '%s'", doc.ID)) visible = false } diff --git a/internal/database/cache.go b/internal/database/cache.go index 63b6ca620..a3679e397 100644 --- a/internal/database/cache.go +++ b/internal/database/cache.go @@ -9,8 +9,6 @@ import ( "iter" "strings" - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/ARO-HCP/internal/api/arm" ) @@ -24,14 +22,14 @@ type Cache struct { subscription map[string]*SubscriptionDocument } -type operationCacheIterator struct { - operation map[string]*OperationDocument - err error +type cacheIterator struct { + docs []any + err error } -func (iter operationCacheIterator) Items(ctx context.Context) iter.Seq[[]byte] { +func (iter cacheIterator) Items(ctx context.Context) iter.Seq[[]byte] { return func(yield func([]byte) bool) { - for _, doc := range iter.operation { + for _, doc := range iter.docs { // Marshalling the document struct only to immediately unmarshal // it back to a document struct is a little silly but this is to // conform to the DBClientIterator interface. @@ -48,7 +46,11 @@ func (iter operationCacheIterator) Items(ctx context.Context) iter.Seq[[]byte] { } } -func (iter operationCacheIterator) GetError() error { +func (iter cacheIterator) GetContinuationToken() string { + return "" +} + +func (iter cacheIterator) GetError() error { return iter.err } @@ -108,21 +110,19 @@ func (c *Cache) DeleteResourceDoc(ctx context.Context, resourceID *arm.ResourceI return nil } -func (c *Cache) ListResourceDocs(ctx context.Context, prefix *arm.ResourceID, resourceType *azcorearm.ResourceType, pageSizeHint int32, continuationToken *string) ([]*ResourceDocument, *string, error) { - var resourceList []*ResourceDocument +func (c *Cache) ListResourceDocs(ctx context.Context, prefix *arm.ResourceID, maxItems int32, continuationToken *string) DBClientIterator { + var iterator cacheIterator // Make sure key prefix is lowercase. prefixString := strings.ToLower(prefix.String() + "/") for key, doc := range c.resource { if strings.HasPrefix(key, prefixString) { - if resourceType == nil || strings.EqualFold(resourceType.String(), doc.Key.ResourceType.String()) { - resourceList = append(resourceList, doc) - } + iterator.docs = append(iterator.docs, doc) } } - return resourceList, nil, nil + return iterator } func (c *Cache) GetOperationDoc(ctx context.Context, operationID string) (*OperationDocument, error) { @@ -164,7 +164,11 @@ func (c *Cache) DeleteOperationDoc(ctx context.Context, operationID string) erro } func (c *Cache) ListAllOperationDocs(ctx context.Context) DBClientIterator { - return operationCacheIterator{operation: c.operation} + var iterator cacheIterator + for _, doc := range c.operation { + iterator.docs = append(iterator.docs, doc) + } + return iterator } func (c *Cache) GetSubscriptionDoc(ctx context.Context, subscriptionID string) (*SubscriptionDocument, error) { diff --git a/internal/database/database.go b/internal/database/database.go index 85b7c7348..f7575fdd0 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -13,7 +13,6 @@ import ( "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/ARO-HCP/internal/api/arm" @@ -51,6 +50,7 @@ func isResponseError(err error, statusCode int) bool { type DBClientIterator interface { Items(ctx context.Context) iter.Seq[[]byte] + GetContinuationToken() string GetError() error } @@ -71,7 +71,7 @@ type DBClient interface { // DeleteResourceDoc deletes a ResourceDocument from the database given the resourceID // of a Microsoft.RedHatOpenShift/HcpOpenShiftClusters resource or NodePools child resource. DeleteResourceDoc(ctx context.Context, resourceID *arm.ResourceID) error - ListResourceDocs(ctx context.Context, prefix *arm.ResourceID, resourceType *azcorearm.ResourceType, pageSizeHint int32, continuationToken *string) ([]*ResourceDocument, *string, error) + ListResourceDocs(ctx context.Context, prefix *arm.ResourceID, maxItems int32, continuationToken *string) DBClientIterator GetOperationDoc(ctx context.Context, operationID string) (*OperationDocument, error) CreateOperationDoc(ctx context.Context, doc *OperationDocument) error @@ -269,13 +269,23 @@ func (d *CosmosDBClient) DeleteResourceDoc(ctx context.Context, resourceID *arm. return nil } -func (d *CosmosDBClient) ListResourceDocs(ctx context.Context, prefix *arm.ResourceID, resourceType *azcorearm.ResourceType, pageSizeHint int32, continuationToken *string) ([]*ResourceDocument, *string, error) { +// ListResourceDocs searches for resource documents that match the given resource ID prefix. +// maxItems can limit the number of items returned at once. A negative value will cause the +// returned iterator to yield all matching items. A positive value will cause the returned +// iterator to include a continuation token if additional items are available. +func (d *CosmosDBClient) ListResourceDocs(ctx context.Context, prefix *arm.ResourceID, maxItems int32, continuationToken *string) DBClientIterator { // Make sure partition key is lowercase. pk := azcosmos.NewPartitionKeyString(strings.ToLower(prefix.SubscriptionID)) + // XXX The Cosmos DB REST API gives special meaning to -1 for "x-ms-max-item-count" + // but it's not clear if it treats all negative values equivalently. The Go SDK + // passes the PageSizeHint value as provided so normalize negative values to -1 + // to be safe. + maxItems = max(maxItems, -1) + query := "SELECT * FROM c WHERE STARTSWITH(c.key, @prefix, true)" opt := azcosmos.QueryOptions{ - PageSizeHint: pageSizeHint, + PageSizeHint: maxItems, ContinuationToken: continuationToken, QueryParameters: []azcosmos.QueryParameter{ { @@ -285,39 +295,13 @@ func (d *CosmosDBClient) ListResourceDocs(ctx context.Context, prefix *arm.Resou }, } - var response azcosmos.QueryItemsResponse - resourceDocs := make([]*ResourceDocument, 0, pageSizeHint) - - // Loop until we fill the pre-allocated resourceDocs slice, - // or until we run out of items from the resources container. - for opt.PageSizeHint > 0 { - var err error - - response, err = d.resources.NewQueryItemsPager(query, pk, &opt).NextPage(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to advance page while querying Resources container for items with a key prefix of '%s': %w", prefix, err) - } - - for _, item := range response.Items { - var doc ResourceDocument - err = json.Unmarshal(item, &doc) - if err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal item while querying Resources container for items with a key prefix of '%s': %w", prefix, err) - } - if resourceType == nil || strings.EqualFold(resourceType.String(), doc.Key.ResourceType.String()) { - resourceDocs = append(resourceDocs, &doc) - } - } - - if response.ContinuationToken == nil { - break - } + pager := d.resources.NewQueryItemsPager(query, pk, &opt) - opt.PageSizeHint = int32(cap(resourceDocs) - len(resourceDocs)) - opt.ContinuationToken = response.ContinuationToken + if maxItems > 0 { + return NewQueryItemsSinglePageIterator(pager) + } else { + return NewQueryItemsIterator(pager) } - - return resourceDocs, response.ContinuationToken, nil } // GetOperationDoc retrieves the asynchronous operation document for the given diff --git a/internal/database/document.go b/internal/database/document.go index f40021592..e19d50082 100644 --- a/internal/database/document.go +++ b/internal/database/document.go @@ -77,7 +77,8 @@ type OperationDocument struct { ExternalID *arm.ResourceID `json:"externalId,omitempty"` // InternalID is the Cluster Service resource identifier in the form of a URL path InternalID ocm.InternalID `json:"internalId,omitempty"` - // OperationID is the Azure resource ID of the operation's status + // OperationID is the Azure resource ID of the operation status (may be nil if the + // operation was implicit, such as deleting a child resource along with the parent) OperationID *arm.ResourceID `json:"operationId,omitempty"` // NotificationURI is provided by the Azure-AsyncNotificationUri header if the // Async Operation Callbacks ARM feature is enabled @@ -94,13 +95,15 @@ type OperationDocument struct { Error *arm.CloudErrorBody `json:"error,omitempty"` } -func NewOperationDocument(request OperationRequest) *OperationDocument { +func NewOperationDocument(request OperationRequest, externalID *arm.ResourceID, internalID ocm.InternalID) *OperationDocument { now := time.Now().UTC() doc := &OperationDocument{ BaseDocument: newBaseDocument(), PartitionKey: operationsPartitionKey, Request: request, + ExternalID: externalID, + InternalID: internalID, StartTime: now, LastTransitionTime: now, Status: arm.ProvisioningStateAccepted, diff --git a/internal/database/util.go b/internal/database/util.go index 2ccbec4e4..df8abf006 100644 --- a/internal/database/util.go +++ b/internal/database/util.go @@ -12,8 +12,10 @@ import ( ) type QueryItemsIterator struct { - pager *runtime.Pager[azcosmos.QueryItemsResponse] - err error + pager *runtime.Pager[azcosmos.QueryItemsResponse] + singlePage bool + continuationToken string + err error } // NewQueryItemsIterator is a failable push iterator for a paged query response. @@ -21,6 +23,13 @@ func NewQueryItemsIterator(pager *runtime.Pager[azcosmos.QueryItemsResponse]) Qu return QueryItemsIterator{pager: pager} } +// NewQueryItemsSinglePageIterator is a failable push iterator for a paged +// query response that stops at the end of the first page and includes a +// continuation token if additional items are available. +func NewQueryItemsSinglePageIterator(pager *runtime.Pager[azcosmos.QueryItemsResponse]) QueryItemsIterator { + return QueryItemsIterator{pager: pager, singlePage: true} +} + // Items returns a push iterator that can be used directly in for/range loops. // If an error occurs during paging, iteration stops and the error is recorded. func (iter QueryItemsIterator) Items(ctx context.Context) iter.Seq[[]byte] { @@ -31,15 +40,28 @@ func (iter QueryItemsIterator) Items(ctx context.Context) iter.Seq[[]byte] { iter.err = err return } + if iter.singlePage && response.ContinuationToken != nil { + iter.continuationToken = *response.ContinuationToken + } for _, item := range response.Items { if !yield(item) { return } } + if iter.singlePage { + return + } } } } +// GetContinuationToken returns a continuation token that can be used to obtain +// the next page of results. This is only set when the iterator was created with +// NewQueryItemsSinglePageIterator and additional items are available. +func (iter QueryItemsIterator) GetContinuationToken() string { + return iter.continuationToken +} + // GetError returns any error that occurred during iteration. Call this after the // for/range loop that calls Items() to check if iteration completed successfully. func (iter QueryItemsIterator) GetError() error {