diff --git a/src/common/helpers.ts b/src/common/helpers.ts index f0a1bf1c6..6f9dd5c02 100644 --- a/src/common/helpers.ts +++ b/src/common/helpers.ts @@ -34,6 +34,7 @@ export function isSchemaResponse(resp?: Response): resp is SchemaResponse { export interface PagerInfo { name: string; op: Operation; + respField: boolean; } // returns true if the operation is pageable diff --git a/src/generator/helpers.ts b/src/generator/helpers.ts index 8a0ea40f1..2f0a3f7c0 100644 --- a/src/generator/helpers.ts +++ b/src/generator/helpers.ts @@ -4,8 +4,10 @@ *--------------------------------------------------------------------------------------------*/ import { Session } from '@azure-tools/autorest-extension-base'; -import { comment } from '@azure-tools/codegen'; -import { ArraySchema, CodeModel, DictionarySchema, Language, Parameter, Schema, SchemaType } from '@azure-tools/codemodel'; +import { values } from '@azure-tools/linq'; +import { comment, camelCase } from '@azure-tools/codegen'; +import { aggregateParameters } from '../common/helpers'; +import { ArraySchema, CodeModel, DictionarySchema, Language, Parameter, Schema, SchemaType, Operation, GroupProperty, ImplementationLocation } from '@azure-tools/codemodel'; // returns the common source-file preamble (license comment, package name etc) export async function contentPreamble(session: Session): Promise { @@ -73,3 +75,56 @@ export function substituteDiscriminator(schema: Schema): string { return schema.language.go!.name; } } + +// returns the parameters for the internal request creator method. +// e.g. "i int, s string" +export function getCreateRequestParametersSig(op: Operation): string { + const methodParams = getMethodParameters(op); + const params = new Array(); + for (const methodParam of values(methodParams)) { + params.push(`${camelCase(methodParam.language.go!.name)} ${formatParameterTypeName(methodParam)}`); + } + return params.join(', '); +} + +// returns the complete collection of method parameters +export function getMethodParameters(op: Operation): Parameter[] { + const params = new Array(); + const paramGroups = new Array(); + for (const param of values(aggregateParameters(op))) { + if (param.implementation === ImplementationLocation.Client) { + // client params are passed via the receiver + continue; + } else if (param.schema.type === SchemaType.Constant) { + // don't generate a parameter for a constant + continue; + } else if (param.language.go!.paramGroup) { + // param groups will be added after individual params + if (!paramGroups.includes(param.language.go!.paramGroup)) { + paramGroups.push(param.language.go!.paramGroup); + } + continue; + } + params.push(param); + } + // move global optional params to the end of the slice + params.sort(sortParametersByRequired); + // add any parameter groups. optional group goes last + paramGroups.sort((a: GroupProperty, b: GroupProperty) => { + if (a.required === b.required) { + return 0; + } + if (a.required && !b.required) { + return -1; + } + return 1; + }) + for (const paramGroup of values(paramGroups)) { + let name = camelCase(paramGroup.language.go!.name); + if (!paramGroup.required) { + name = 'options'; + } + params.push(paramGroup); + } + return params; +} diff --git a/src/generator/operations.ts b/src/generator/operations.ts index 450aeab26..9cd6f87bb 100644 --- a/src/generator/operations.ts +++ b/src/generator/operations.ts @@ -9,7 +9,7 @@ import { ArraySchema, ByteArraySchema, CodeModel, ConstantSchema, DateTimeSchema import { values } from '@azure-tools/linq'; import { aggregateParameters, isArraySchema, isPageableOperation, isSchemaResponse, PagerInfo, isLROOperation } from '../common/helpers'; import { OperationNaming } from '../transform/namer'; -import { contentPreamble, formatParameterTypeName, hasDescription, skipURLEncoding, sortAscending, sortParametersByRequired } from './helpers'; +import { contentPreamble, formatParameterTypeName, hasDescription, skipURLEncoding, sortAscending, sortParametersByRequired, getCreateRequestParametersSig, getMethodParameters } from './helpers'; import { ImportManager } from './imports'; const dateFormat = '2006-01-02'; @@ -292,14 +292,7 @@ function generateOperation(clientName: string, op: Operation, imports: ImportMan for (let i = 0; i < reqParams.length; ++i) { reqParams[i] = reqParams[i].trim().split(' ')[0]; } - // TODO Exception for Pageable LRO operations NYI if (isLROOperation(op)) { - // TODO remove LRO for pageable responses NYI - if (op.extensions!['x-ms-pageable']) { - text += `\treturn nil, nil\n`; - text += '}\n\n'; - return text; - } imports.add('time'); text += `\treq, err := client.${info.protocolNaming.requestMethod}(${reqParams.join(', ')})\n`; text += `\tif err != nil {\n`; @@ -332,15 +325,20 @@ function generateOperation(clientName: string, op: Operation, imports: ImportMan text += '\t}\n'; text += `\tpoller := &${camelCase(op.language.go!.pollerType.name)}{\n`; text += '\t\t\tpt: pt,\n'; + if (isPageableOperation(op)) { + text += `\t\t\trespHandler: client.${camelCase(op.language.go!.pageableType.name)}HandleResponse,\n`; + } text += '\t\t\tpipeline: client.p,\n'; text += '\t}\n'; text += '\tresult.Poller = poller\n'; - // http pollers will simply return an *http.Response - if (op.language.go!.pollerType.name === 'HTTPPoller') { - text += '\tresult.PollUntilDone = func(ctx context.Context, frequency time.Duration) (*http.Response, error) {\n'; - } else { - text += `\tresult.PollUntilDone = func(ctx context.Context, frequency time.Duration)(*${(op.responses![0]).schema.language.go!.responseType.name}, error) {\n`; + // determine the poller response based on the name and whether is is a pageable operation + let pollerResponse = '*http.Response'; + if (isPageableOperation(op)) { + pollerResponse = op.language.go!.pageableType.name; + } else if (isSchemaResponse(op.responses![0])) { + pollerResponse = '*' + (op.responses![0]).schema.language.go!.responseType.name; } + text += `\tresult.PollUntilDone = func(ctx context.Context, frequency time.Duration) (${pollerResponse}, error) {\n`; text += `\t\treturn poller.pollUntilDone(ctx, frequency)\n`; text += `\t}\n`; text += `\treturn result, nil\n`; @@ -653,9 +651,8 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa text += '}\n\n'; return text; } - const generateResponseUnmarshaller = function (response: Response): string { + const generateResponseUnmarshaller = function (response: Response, isLRO: boolean): string { let unmarshallerText = ''; - const isLRO = isLROOperation(op); if (!isSchemaResponse(response)) { if (isLRO) { unmarshallerText += '\treturn &HTTPPollerResponse{RawResponse: resp.Response}, nil\n'; @@ -698,8 +695,7 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa return unmarshallerText; } const schemaResponse = response; - // TODO remove paging skip when adding lro + pagers - if (isLRO && !isPageableOperation(op)) { + if (isLRO) { unmarshallerText += `\treturn &${schemaResponse.schema.language.go!.lroResponseType.language.go!.name}{RawResponse: resp.Response}, nil\n`; return unmarshallerText; } @@ -739,12 +735,28 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa text += `\tif !resp.HasStatusCode(${formatStatusCodes(statusCodes)}) {\n`; text += `\t\treturn nil, client.${info.protocolNaming.errorMethod}(resp)\n`; text += '\t}\n'; - text += generateResponseUnmarshaller(op.responses![0]); + if (isLROOperation(op) && isPageableOperation(op)) { + text += generateResponseUnmarshaller(op.responses![0], true); + text += '}\n\n'; + text += `${comment(name, '// ')} handles the ${info.name} response.\n`; + text += `func (client *${client}) ${camelCase(op.language.go!.pageableType.name)}HandleResponse(resp *azcore.Response) (*${(op.responses![0]).schema.language.go!.responseType.value}Response, error) {\n`; + const index = statusCodes.indexOf('204'); + if (index > -1) { + statusCodes.splice(index, 1); + } + statusCodes.push('200'); + text += `\tif !resp.HasStatusCode(${formatStatusCodes(statusCodes)}) {\n`; + text += `\t\treturn nil, client.${info.protocolNaming.errorMethod}(resp)\n`; + text += '\t}\n'; + text += generateResponseUnmarshaller(op.responses![0], false); + } else { + text += generateResponseUnmarshaller(op.responses![0], isLROOperation(op)); + } } else { text += '\tswitch resp.StatusCode {\n'; for (const response of values(op.responses)) { text += `\tcase ${formatStatusCodes(response.protocol.http!.statusCodes)}:\n` - text += generateResponseUnmarshaller(response); + text += generateResponseUnmarshaller(response, isLROOperation(op)); } text += '\tdefault:\n'; text += `\t\treturn nil, client.${info.protocolNaming.errorMethod}(resp)\n`; @@ -876,7 +888,7 @@ function createInterfaceDefinition(group: OperationGroup, imports: ImportManager const returns = generateReturnsInfo(op, false); interfaceText += `\t${opName}(${getAPIParametersSig(op, imports)}) (${returns.join(', ')})\n`; // Add resume LRO poller method for each Begin poller method - if (isLROOperation(op) && !op.extensions!['x-ms-pageable']) { + if (isLROOperation(op)) { interfaceText += `\t// Resume${op.language.go!.name} - Used to create a new instance of this poller from the resume token of a previous instance of this poller type.\n`; interfaceText += `\tResume${op.language.go!.name}(token string) (${op.language.go!.pollerType.name}, error)\n`; } @@ -971,7 +983,7 @@ function hasBinaryResponse(responses: Response[]): boolean { function getAPIParametersSig(op: Operation, imports: ImportManager): string { const methodParams = getMethodParameters(op); const params = new Array(); - if (!isPageableOperation(op)) { + if (!isPageableOperation(op) || isLROOperation(op)) { imports.add('context'); params.push('ctx context.Context'); } @@ -981,59 +993,6 @@ function getAPIParametersSig(op: Operation, imports: ImportManager): string { return params.join(', '); } -// returns the parameters for the internal request creator method. -// e.g. "i int, s string" -function getCreateRequestParametersSig(op: Operation): string { - const methodParams = getMethodParameters(op); - const params = new Array(); - for (const methodParam of values(methodParams)) { - params.push(`${camelCase(methodParam.language.go!.name)} ${formatParameterTypeName(methodParam)}`); - } - return params.join(', '); -} - -// returns the complete collection of method parameters -function getMethodParameters(op: Operation): Parameter[] { - const params = new Array(); - const paramGroups = new Array(); - for (const param of values(aggregateParameters(op))) { - if (param.implementation === ImplementationLocation.Client) { - // client params are passed via the receiver - continue; - } else if (param.schema.type === SchemaType.Constant) { - // don't generate a parameter for a constant - continue; - } else if (param.language.go!.paramGroup) { - // param groups will be added after individual params - if (!paramGroups.includes(param.language.go!.paramGroup)) { - paramGroups.push(param.language.go!.paramGroup); - } - continue; - } - params.push(param); - } - // move global optional params to the end of the slice - params.sort(sortParametersByRequired); - // add any parameter groups. optional group goes last - paramGroups.sort((a: GroupProperty, b: GroupProperty) => { - if (a.required === b.required) { - return 0; - } - if (a.required && !b.required) { - return -1; - } - return 1; - }) - for (const paramGroup of values(paramGroups)) { - let name = camelCase(paramGroup.language.go!.name); - if (!paramGroup.required) { - name = 'options'; - } - params.push(paramGroup); - } - return params; -} - // returns the return signature where each entry is the type name // e.g. [ '*string', 'error' ] function generateReturnsInfo(op: Operation, forHandler: boolean): string[] { @@ -1052,8 +1011,7 @@ function generateReturnsInfo(op: Operation, forHandler: boolean): string[] { returnType = op.language.go!.pageableType.name; } else if (isSchemaResponse(firstResp)) { returnType = '*' + firstResp.schema.language.go!.responseType.name; - // TODO remove paging skip when adding LRO + pagers - if (isLROOperation(op) && !isPageableOperation(op)) { + if (isLROOperation(op)) { returnType = '*' + firstResp.schema.language.go!.lroResponseType.language.go!.name; } } else if (isLROOperation(op)) { diff --git a/src/generator/pagers.ts b/src/generator/pagers.ts index c4910d766..7f587a296 100644 --- a/src/generator/pagers.ts +++ b/src/generator/pagers.ts @@ -32,6 +32,21 @@ export async function generatePagers(session: Session): Promiseop.responses![0]; + let text = 'if p.pt.pollerMethodVerb() == http.MethodPut || p.pt.pollerMethodVerb() == http.MethodPatch {'; + if (!isPageableOperation(op)) { + text += ` + res, err := p.handleResponse(p.pt.latestResponse()) + if err != nil { + return nil, err + } + `; + switch (respSchema.schema.type) { + case SchemaType.Array: + case SchemaType.Dictionary: + text += `if res != nil && res.${respSchema.schema.language.go!.responseType.value} != nil {`; + break; + case SchemaType.String: + text += `if res != nil && (*res.${respSchema.schema.language.go!.responseType.value} != "") {`; + break; + default: + text += `if res != nil && (*res.${respSchema.schema.language.go!.responseType.value} != ${respSchema.schema.language.go!.responseType.value}{}) {`; + } + text += ` return res, nil + }`; + } else { + text += 'return p.handleResponse(p.pt.latestResponse())'; + } + text += '}'; + return text; +} + +function generatePagerReturnInstance(op: Operation, imports: ImportManager): string { + let text = ''; + const info = op.language.go!; + // split param list into individual params + const reqParams = getCreateRequestParametersSig(op).split(','); + // keep the parameter names from the name/type tuples + for (let i = 0; i < reqParams.length; ++i) { + reqParams[i] = reqParams[i].trim().split(' ')[0]; + } + text += `\treturn &${camelCase(op.language.go!.pageableType.name)}{\n`; + text += `\t\tpipeline: p.pipeline,\n`; + text += `\t\tresp: resp,\n`; + text += `\t\tresponder: p.respHandler,\n`; + const pager = op.language.go!.pageableType; + const pagerSchema = pager.op.responses![0]; + if (op.language.go!.paging.member) { + // find the location of the nextLink param + const nextLinkOpParams = getMethodParameters(op.language.go!.paging.nextLinkOperation); + let found = false; + for (let i = 0; i < nextLinkOpParams.length; ++i) { + if (nextLinkOpParams[i].schema.type === SchemaType.String && nextLinkOpParams[i].language.go!.name.startsWith('next')) { + // found it + reqParams.splice(i, 0, `*resp.${pagerSchema.schema.language.go!.name}.${pager.op.language.go!.paging.nextLinkName}`); + found = true; + break; } } - `; + if (!found) { + throw console.error(`failed to find nextLink parameter for operation ${op.language.go!.paging.nextLinkOperation.language.go!.name}`); + } + text += `\t\tadvancer: func(resp *${pagerSchema.schema.language.go!.responseType.name}) (*azcore.Request, error) {\n`; + text += `\t\t\treturn client.${camelCase(op.language.go!.paging.member)}CreateRequest(${reqParams.join(', ')})\n`; + text += '\t\t},\n'; + } else { + imports.add('fmt'); + imports.add('net/url'); + let resultTypeName = pagerSchema.schema.language.go!.name; + if (pagerSchema.schema.serialization?.xml?.name) { + // xml can specifiy its own name, prefer that if available + resultTypeName = pagerSchema.schema.serialization.xml.name; + } + text += `\t\tadvancer: func(resp *${pagerSchema.schema.language.go!.responseType.name}) (*azcore.Request, error) {\n`; + text += `\t\t\tu, err := url.Parse(*resp.${resultTypeName}.${pager.op.language.go!.paging.nextLinkName})\n`; + text += `\t\t\tif err != nil {\n`; + text += `\t\t\t\treturn nil, fmt.Errorf("invalid ${pager.op.language.go!.paging.nextLinkName}: %w", err)\n`; + text += `\t\t\t}\n`; + text += `\t\t\tif u.Scheme == "" {\n`; + text += `\t\t\t\treturn nil, fmt.Errorf("no scheme detected in ${pager.op.language.go!.paging.nextLinkName} %s", *resp.${resultTypeName}.${pager.op.language.go!.paging.nextLinkName})\n`; + text += `\t\t\t}\n`; + text += `\t\t\treturn azcore.NewRequest(http.MethodGet, *u), nil\n`; + text += `\t\t},\n`; + } + text += `\t}, nil`; return text; } @@ -70,18 +132,52 @@ export async function generatePollers(session: Session): Promisepoller.op.responses![0]; let unmarshalResponse = 'nil'; - if (isSchemaResponse(schemaResponse) && schemaResponse.schema.language.go!.responseType.name !== undefined) { + let pagerFields = ''; + let finalResponseCheckNeeded = false; + if (isPageableOperation(poller.op)) { + responseType = poller.op.language.go!.pageableType.name; + pollUntilDoneResponse = `(${responseType}, error)`; + pollUntilDoneReturn = 'p.FinalResponse(ctx)'; + // for operations that do return a model add a final response method that handles the final get URL scenario + finalResponseDeclaration = `FinalResponse(ctx context.Context) (${responseType}, error)`; + pagerFields = ` + respHandler ${camelCase(poller.op.language.go!.pageableType.op.responses![0].schema.language.go!.name)}HandleResponse`; + handleResponse = ` + func (p *${pollerName}) handleResponse(resp *azcore.Response) (${responseType}, error) { + ${generatePagerReturnInstance(poller.op, imports)} + } + `; + finalResponse = `${finalResponseDeclaration} {`; + finalResponseCheckNeeded = true; + } else if (isSchemaResponse(schemaResponse) && schemaResponse.schema.language.go!.responseType.name !== undefined) { responseType = schemaResponse.schema.language.go!.responseType.name; pollUntilDoneResponse = `(*${responseType}, error)`; pollUntilDoneReturn = 'p.FinalResponse(ctx)'; unmarshalResponse = `resp.UnmarshalAsJSON(&result.${schemaResponse.schema.language.go!.responseType.value})`; // for operations that do return a model add a final response method that handles the final get URL scenario finalResponseDeclaration = `FinalResponse(ctx context.Context) (*${responseType}, error)`; - finalResponse = `FinalResponse(ctx context.Context) (*${responseType}, error) { + handleResponse = ` + func (p *${pollerName}) handleResponse(resp *azcore.Response) (*${responseType}, error) { + result := ${responseType}{RawResponse: resp.Response} + if resp.HasStatusCode(http.StatusNoContent) { + return &result, nil + } + if !resp.HasStatusCode(pollingCodes[:]...) { + return nil, p.pt.handleError(resp) + } + return &result, ${unmarshalResponse} + } + `; + finalResponse = `FinalResponse(ctx context.Context) (*${responseType}, error) {`; + finalResponseCheckNeeded = true; + } + if (finalResponseCheckNeeded) { + finalResponse += ` if !p.Done() { return nil, errors.New("cannot return a final response from a poller in a non-terminal state") } - ${getPutCheck(schemaResponse)}// checking if there was a FinalStateVia configuration to re-route the final GET + ${getPutCheck(poller.op)} + // checking if there was a FinalStateVia configuration to re-route the final GET // request to the value specified in the FinalStateVia property on the poller err := p.pt.setFinalState() if err != nil { @@ -114,18 +210,6 @@ export async function generatePollers(session: Session): Promise): Promise>codeModel.language.go!.pageableTypes; for (const pager of values(pagers)) { if (pager.name === name) { + // this LRO check is necessary for operations that synchronously and asynchronously return a pager + // this will ensure that pagers that are used with pollers will have the response field included + if (isLROOperation(op)) { + pager.respField = true; + } // found a match, hook it up to the method op.language.go!.pageableType = pager; skipAddPager = true; @@ -584,6 +585,7 @@ function createResponseType(codeModel: CodeModel, group: OperationGroup, op: Ope const pager = { name: name, op: op, + respField: isLROOperation(op), }; pagers.push(pager); op.language.go!.pageableType = pager; @@ -591,15 +593,14 @@ function createResponseType(codeModel: CodeModel, group: OperationGroup, op: Ope } // create poller type info if (isLROOperation(op)) { + // create the poller response envelope + generateLROResponseType(response, op, codeModel); if (codeModel.language.go!.pollerTypes === undefined) { codeModel.language.go!.pollerTypes = new Array(); } // Determine the type of poller that needs to be added based on whether a schema is specified in the response // if there is no schema specified for the operation response then a simple HTTP poller will be instantiated - let name = 'HTTPPoller'; - if (isSchemaResponse(response) && response.schema.language.go!.responseType.value) { - name = generateLROPollerName(response); - } + const name = generateLROPollerName(response, op); const pollers = >codeModel.language.go!.pollerTypes; let skipAddLRO = false; for (const poller of values(pollers)) { @@ -743,11 +744,14 @@ function generateResponseTypeName(schema: Schema): Language { // generate LRO response type name is separate from the general response type name // generation, since it requires returning the poller response envelope -function generateLROResponseTypeName(response: Response): Language { +function generateLROResponseTypeName(response: Response, op: Operation): Language { // default to generic response envelope - let name = 'HTTPPollerResponse' + let name = 'HTTPPollerResponse'; let desc = `${name} contains the asynchronous HTTP response from the call to the service endpoint.`; - if (isSchemaResponse(response)) { + if (isPageableOperation(op)) { + name = `${op.language.go!.pageableType.name}PollerResponse`; + desc = `${name} is the response envelope for operations that asynchronously return a ${op.language.go!.pageableType.name} type.`; + } else if (isSchemaResponse(response)) { // create a type-specific response envelope const typeName = recursiveTypeName(response.schema) + 'Poller'; name = `${typeName}Response`; @@ -760,7 +764,14 @@ function generateLROResponseTypeName(response: Response): Language { }; } -function generateLROPollerName(schemaResp: SchemaResponse): string { +function generateLROPollerName(response: Response, op: Operation): string { + if (!isSchemaResponse(response)) { + return 'HTTPPoller'; + } + const schemaResp = response; + if (isPageableOperation(op)) { + return `${op.language.go!.pageableType.name}Poller`; + } if (schemaResp.schema.language.go!.responseType.value === scalarResponsePropName) { // for scalar responses, use the underlying type name for the poller return `${pascalCase(schemaResp.schema.language.go!.name)}Poller`; @@ -769,7 +780,7 @@ function generateLROPollerName(schemaResp: SchemaResponse): string { } function generateLROResponseType(response: Response, op: Operation, codeModel: CodeModel) { - const respTypeName = generateLROResponseTypeName(response); + const respTypeName = generateLROResponseTypeName(response, op); if (responseExists(codeModel, respTypeName.name)) { return; } @@ -793,7 +804,7 @@ function generateLROResponseType(response: Response, op: Operation, codeModel: C response.schema.language.go!.lroResponseType = respTypeObject; } else { pollerResponse = `*${response.schema.language.go!.responseType.name}`; - pollerTypeName = generateLROPollerName(response); + pollerTypeName = generateLROPollerName(response, op); response.schema.language.go!.isLRO = true; response.schema.language.go!.lroResponseType = respTypeObject; } diff --git a/test/autorest/generated/paginggroup/models.go b/test/autorest/generated/paginggroup/models.go index 10a8b3cf3..b87da1717 100644 --- a/test/autorest/generated/paginggroup/models.go +++ b/test/autorest/generated/paginggroup/models.go @@ -5,7 +5,11 @@ package paginggroup -import "net/http" +import ( + "context" + "net/http" + "time" +) // CustomParameterGroup contains a group of parameters for the Paging.GetMultiplePagesFragmentWithGroupingNextLink method. type CustomParameterGroup struct { @@ -85,6 +89,19 @@ type ProductResult struct { Values *[]Product `json:"values,omitempty"` } +// ProductResultPagerPollerResponse is the response envelope for operations that asynchronously return a ProductResultPager +// type. +type ProductResultPagerPollerResponse struct { + // PollUntilDone will poll the service endpoint until a terminal state is reached or an error is received + PollUntilDone func(ctx context.Context, frequency time.Duration) (ProductResultPager, error) + + // Poller contains an initialized poller. + Poller ProductResultPagerPoller + + // RawResponse contains the underlying HTTP response. + RawResponse *http.Response +} + // ProductResultResponse is the response envelope for operations that return a ProductResult type. type ProductResultResponse struct { ProductResult *ProductResult diff --git a/test/autorest/generated/paginggroup/pagers.go b/test/autorest/generated/paginggroup/pagers.go index 5119e9461..afa56463b 100644 --- a/test/autorest/generated/paginggroup/pagers.go +++ b/test/autorest/generated/paginggroup/pagers.go @@ -106,6 +106,8 @@ type productResultPager struct { current *ProductResultResponse // any error encountered err error + // previous response from the endpoint + resp *azcore.Response } func (p *productResultPager) Err() error { @@ -124,7 +126,13 @@ func (p *productResultPager) NextPage(ctx context.Context) bool { } p.request = req } - resp, err := p.pipeline.Do(ctx, p.request) + resp := p.resp + var err error + if resp == nil { + resp, err = p.pipeline.Do(ctx, p.request) + } else { + p.resp = nil + } if err != nil { p.err = err return false diff --git a/test/autorest/generated/paginggroup/paging.go b/test/autorest/generated/paginggroup/paging.go index 9cb5f745e..46524cafe 100644 --- a/test/autorest/generated/paginggroup/paging.go +++ b/test/autorest/generated/paginggroup/paging.go @@ -15,6 +15,7 @@ import ( "net/url" "strconv" "strings" + "time" ) // PagingOperations contains the methods for the Paging group. @@ -30,7 +31,9 @@ type PagingOperations interface { // GetMultiplePagesFragmentWithGroupingNextLink - A paging operation that doesn't return a full URL, just a fragment with parameters grouped GetMultiplePagesFragmentWithGroupingNextLink(customParameterGroup CustomParameterGroup) (OdataProductResultPager, error) // BeginGetMultiplePagesLro - A long-running paging operation that includes a nextLink that has 10 pages - BeginGetMultiplePagesLro(pagingGetMultiplePagesLroOptions *PagingGetMultiplePagesLroOptions) (*ProductResultResponse, error) + BeginGetMultiplePagesLro(ctx context.Context, pagingGetMultiplePagesLroOptions *PagingGetMultiplePagesLroOptions) (*ProductResultPagerPollerResponse, error) + // ResumeGetMultiplePagesLro - Used to create a new instance of this poller from the resume token of a previous instance of this poller type. + ResumeGetMultiplePagesLro(token string) (ProductResultPagerPoller, error) // GetMultiplePagesRetryFirst - A paging operation that fails on the first call with 500 and then retries and then get a response including a nextLink that has 10 pages GetMultiplePagesRetryFirst() (ProductResultPager, error) // GetMultiplePagesRetrySecond - A paging operation that includes a nextLink that has 10 pages, of which the 2nd call fails first with 500. The client should retry and finish all 10 pages eventually. @@ -337,8 +340,45 @@ func (client *pagingOperations) getMultiplePagesFragmentWithGroupingNextLinkHand } // GetMultiplePagesLro - A long-running paging operation that includes a nextLink that has 10 pages -func (client *pagingOperations) BeginGetMultiplePagesLro(pagingGetMultiplePagesLroOptions *PagingGetMultiplePagesLroOptions) (*ProductResultResponse, error) { - return nil, nil +func (client *pagingOperations) BeginGetMultiplePagesLro(ctx context.Context, pagingGetMultiplePagesLroOptions *PagingGetMultiplePagesLroOptions) (*ProductResultPagerPollerResponse, error) { + req, err := client.getMultiplePagesLroCreateRequest(pagingGetMultiplePagesLroOptions) + if err != nil { + return nil, err + } + // send the first request to initialize the poller + resp, err := client.p.Do(ctx, req) + if err != nil { + return nil, err + } + result, err := client.getMultiplePagesLroHandleResponse(resp) + if err != nil { + return nil, err + } + pt, err := createPollingTracker("pagingOperations.GetMultiplePagesLro", "", resp, client.getMultiplePagesLroHandleError) + if err != nil { + return nil, err + } + poller := &productResultPagerPoller{ + pt: pt, + respHandler: client.productResultPagerHandleResponse, + pipeline: client.p, + } + result.Poller = poller + result.PollUntilDone = func(ctx context.Context, frequency time.Duration) (ProductResultPager, error) { + return poller.pollUntilDone(ctx, frequency) + } + return result, nil +} + +func (client *pagingOperations) ResumeGetMultiplePagesLro(token string) (ProductResultPagerPoller, error) { + pt, err := resumePollingTracker("pagingOperations.GetMultiplePagesLro", token, client.getMultiplePagesLroHandleError) + if err != nil { + return nil, err + } + return &productResultPagerPoller{ + pipeline: client.p, + pt: pt, + }, nil } // getMultiplePagesLroCreateRequest creates the GetMultiplePagesLro request. @@ -362,10 +402,18 @@ func (client *pagingOperations) getMultiplePagesLroCreateRequest(pagingGetMultip } // getMultiplePagesLroHandleResponse handles the GetMultiplePagesLro response. -func (client *pagingOperations) getMultiplePagesLroHandleResponse(resp *azcore.Response) (*ProductResultResponse, error) { +func (client *pagingOperations) getMultiplePagesLroHandleResponse(resp *azcore.Response) (*ProductResultPagerPollerResponse, error) { if !resp.HasStatusCode(http.StatusAccepted, http.StatusNoContent) { return nil, client.getMultiplePagesLroHandleError(resp) } + return &ProductResultPagerPollerResponse{RawResponse: resp.Response}, nil +} + +// getMultiplePagesLroHandleResponse handles the GetMultiplePagesLro response. +func (client *pagingOperations) productResultPagerHandleResponse(resp *azcore.Response) (*ProductResultResponse, error) { + if !resp.HasStatusCode(http.StatusAccepted, http.StatusOK) { + return nil, client.getMultiplePagesLroHandleError(resp) + } result := ProductResultResponse{RawResponse: resp.Response} return &result, resp.UnmarshalAsJSON(&result.ProductResult) } diff --git a/test/autorest/generated/paginggroup/pollers.go b/test/autorest/generated/paginggroup/pollers.go index 6aed59307..2bbfa8b9c 100644 --- a/test/autorest/generated/paginggroup/pollers.go +++ b/test/autorest/generated/paginggroup/pollers.go @@ -9,51 +9,47 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "net/http" "net/url" "time" ) -// ProductResultPoller provides polling facilities until the operation completes -type ProductResultPoller interface { +// ProductResultPagerPoller provides polling facilities until the operation completes +type ProductResultPagerPoller interface { Done() bool Poll(ctx context.Context) (*http.Response, error) - FinalResponse(ctx context.Context) (*ProductResultResponse, error) + FinalResponse(ctx context.Context) (ProductResultPager, error) ResumeToken() (string, error) } -type productResultPoller struct { +type productResultPagerPoller struct { // the client for making the request - pipeline azcore.Pipeline - pt pollingTracker + pipeline azcore.Pipeline + respHandler productResultHandleResponse + pt pollingTracker } // Done returns true if there was an error or polling has reached a terminal state -func (p *productResultPoller) Done() bool { +func (p *productResultPagerPoller) Done() bool { return p.pt.hasTerminated() } // Poll will send poll the service endpoint and return an http.Response or error received from the service -func (p *productResultPoller) Poll(ctx context.Context) (*http.Response, error) { +func (p *productResultPagerPoller) Poll(ctx context.Context) (*http.Response, error) { if lroPollDone(ctx, p.pipeline, p.pt) { return p.pt.latestResponse().Response, p.pt.pollingError() } return nil, p.pt.pollingError() } -func (p *productResultPoller) FinalResponse(ctx context.Context) (*ProductResultResponse, error) { +func (p *productResultPagerPoller) FinalResponse(ctx context.Context) (ProductResultPager, error) { if !p.Done() { return nil, errors.New("cannot return a final response from a poller in a non-terminal state") } if p.pt.pollerMethodVerb() == http.MethodPut || p.pt.pollerMethodVerb() == http.MethodPatch { - res, err := p.handleResponse(p.pt.latestResponse()) - if err != nil { - return nil, err - } - if res != nil && (*res.ProductResult != ProductResult{}) { - return res, nil - } + return p.handleResponse(p.pt.latestResponse()) } // checking if there was a FinalStateVia configuration to re-route the final GET // request to the value specified in the FinalStateVia property on the poller @@ -89,9 +85,9 @@ func (p *productResultPoller) FinalResponse(ctx context.Context) (*ProductResult return p.handleResponse(resp) } -// ResumeToken generates the string token that can be used with the ResumeProductResultPoller method +// ResumeToken generates the string token that can be used with the ResumeProductResultPagerPoller method // on the client to create a new poller from the data held in the current poller type -func (p *productResultPoller) ResumeToken() (string, error) { +func (p *productResultPagerPoller) ResumeToken() (string, error) { if p.pt.hasTerminated() { return "", errors.New("cannot create a ResumeToken from a poller in a terminal state") } @@ -102,7 +98,7 @@ func (p *productResultPoller) ResumeToken() (string, error) { return string(js), nil } -func (p *productResultPoller) pollUntilDone(ctx context.Context, frequency time.Duration) (*ProductResultResponse, error) { +func (p *productResultPagerPoller) pollUntilDone(ctx context.Context, frequency time.Duration) (ProductResultPager, error) { for { resp, err := p.Poll(ctx) if err != nil { @@ -120,13 +116,20 @@ func (p *productResultPoller) pollUntilDone(ctx context.Context, frequency time. return p.FinalResponse(ctx) } -func (p *productResultPoller) handleResponse(resp *azcore.Response) (*ProductResultResponse, error) { - result := ProductResultResponse{RawResponse: resp.Response} - if resp.HasStatusCode(http.StatusNoContent) { - return &result, nil - } - if !resp.HasStatusCode(pollingCodes[:]...) { - return nil, p.pt.handleError(resp) - } - return &result, resp.UnmarshalAsJSON(&result.ProductResult) +func (p *productResultPagerPoller) handleResponse(resp *azcore.Response) (ProductResultPager, error) { + return &productResultPager{ + pipeline: p.pipeline, + resp: resp, + responder: p.respHandler, + advancer: func(resp *ProductResultResponse) (*azcore.Request, error) { + u, err := url.Parse(*resp.ProductResult.NextLink) + if err != nil { + return nil, fmt.Errorf("invalid NextLink: %w", err) + } + if u.Scheme == "" { + return nil, fmt.Errorf("no scheme detected in NextLink %s", *resp.ProductResult.NextLink) + } + return azcore.NewRequest(http.MethodGet, *u), nil + }, + }, nil } diff --git a/test/autorest/paginggroup/paginggroup_test.go b/test/autorest/paginggroup/paginggroup_test.go index f8a1a098d..c85a9bff7 100644 --- a/test/autorest/paginggroup/paginggroup_test.go +++ b/test/autorest/paginggroup/paginggroup_test.go @@ -172,7 +172,39 @@ func TestGetMultiplePagesFragmentWithGroupingNextLink(t *testing.T) { // GetMultiplePagesLro - A long-running paging operation that includes a nextLink that has 10 pages func TestGetMultiplePagesLro(t *testing.T) { - t.Skip("LRO NYI") + client := getPagingOperations(t) + resp, err := client.BeginGetMultiplePagesLro(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + poller := resp.Poller + rt, err := poller.ResumeToken() + if err != nil { + t.Fatal(err) + } + poller, err = client.ResumeGetMultiplePagesLro(rt) + if err != nil { + t.Fatal(err) + } + pager, err := resp.PollUntilDone(context.Background(), 1*time.Millisecond) + if err != nil { + t.Fatal(err) + } + count := 0 + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if len(*resp.ProductResult.Values) == 0 { + t.Fatal("missing payload") + } + count++ + } + if err = pager.Err(); err != nil { + t.Fatal(err) + } + const pageCount = 10 + if count != pageCount { + helpers.DeepEqualOrFatal(t, count, pageCount) + } } // GetMultiplePagesRetryFirst - A paging operation that fails on the first call with 500 and then retries and then get a response including a nextLink that has 10 pages diff --git a/test/storage/2019-07-07/azblob/enums.go b/test/storage/2019-07-07/azblob/enums.go index 04959cc86..4337c7179 100644 --- a/test/storage/2019-07-07/azblob/enums.go +++ b/test/storage/2019-07-07/azblob/enums.go @@ -50,9 +50,11 @@ func (c AccessTier) ToPtr() *AccessTier { type AccountKind string const ( - AccountKindStorage AccountKind = "Storage" - AccountKindBlobStorage AccountKind = "BlobStorage" - AccountKindStorageV2 AccountKind = "StorageV2" + AccountKindStorage AccountKind = "Storage" + AccountKindBlobStorage AccountKind = "BlobStorage" + AccountKindStorageV2 AccountKind = "StorageV2" + AccountKindFileStorage AccountKind = "FileStorage" + AccountKindBlockBlobStorage AccountKind = "BlockBlobStorage" ) func PossibleAccountKindValues() []AccountKind { @@ -60,6 +62,8 @@ func PossibleAccountKindValues() []AccountKind { AccountKindStorage, AccountKindBlobStorage, AccountKindStorageV2, + AccountKindFileStorage, + AccountKindBlockBlobStorage, } }