Skip to content

Commit

Permalink
LRO+pager work (#427)
Browse files Browse the repository at this point in the history
* Adding poller response envelope transformations and generation

* Initial LRO+pager work

* Adding pager poller response envelope corrections

* Adding pager response type to begin lro method

* Updating PollUntilDone pointer

* Adding handler for inner pager result value

* Updating pager test

* simplifying code

* Updating poller pager handle response url

* Adding page count to lro pager test

* Removing extra spacing

* Adding azcore.Response to pager in order to not repeat get in LRO case

* bug fix

* set response on pager poller with LRO response

clear pager poller response after consumption

* simplify put check for pagers

* consolidating lro response name in one place

* Improving pager resp field code and adding comment

* refactoring polling naming to one function

* code improvements

* refactor

Co-authored-by: Joel Hendrix <[email protected]>
Co-authored-by: Catalina Peralta <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2020
1 parent 7d05c7f commit 7f7a951
Show file tree
Hide file tree
Showing 12 changed files with 408 additions and 172 deletions.
1 change: 1 addition & 0 deletions src/common/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export function isSchemaResponse(resp?: Response): resp is SchemaResponse {
export interface PagerInfo {
name: string;
op: Operation;
respField: boolean;
}

// returns true if the operation is pageable
Expand Down
59 changes: 57 additions & 2 deletions src/generator/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
*--------------------------------------------------------------------------------------------*/

import { Session } from '@azure-tools/autorest-extension-base';
import { comment } from '@azure-tools/codegen';
import { ArraySchema, CodeModel, DictionarySchema, Language, Parameter, Schema, SchemaType } from '@azure-tools/codemodel';
import { values } from '@azure-tools/linq';
import { comment, camelCase } from '@azure-tools/codegen';
import { aggregateParameters } from '../common/helpers';
import { ArraySchema, CodeModel, DictionarySchema, Language, Parameter, Schema, SchemaType, Operation, GroupProperty, ImplementationLocation } from '@azure-tools/codemodel';

// returns the common source-file preamble (license comment, package name etc)
export async function contentPreamble(session: Session<CodeModel>): Promise<string> {
Expand Down Expand Up @@ -73,3 +75,56 @@ export function substituteDiscriminator(schema: Schema): string {
return schema.language.go!.name;
}
}

// returns the parameters for the internal request creator method.
// e.g. "i int, s string"
export function getCreateRequestParametersSig(op: Operation): string {
const methodParams = getMethodParameters(op);
const params = new Array<string>();
for (const methodParam of values(methodParams)) {
params.push(`${camelCase(methodParam.language.go!.name)} ${formatParameterTypeName(methodParam)}`);
}
return params.join(', ');
}

// returns the complete collection of method parameters
export function getMethodParameters(op: Operation): Parameter[] {
const params = new Array<Parameter>();
const paramGroups = new Array<GroupProperty>();
for (const param of values(aggregateParameters(op))) {
if (param.implementation === ImplementationLocation.Client) {
// client params are passed via the receiver
continue;
} else if (param.schema.type === SchemaType.Constant) {
// don't generate a parameter for a constant
continue;
} else if (param.language.go!.paramGroup) {
// param groups will be added after individual params
if (!paramGroups.includes(param.language.go!.paramGroup)) {
paramGroups.push(param.language.go!.paramGroup);
}
continue;
}
params.push(param);
}
// move global optional params to the end of the slice
params.sort(sortParametersByRequired);
// add any parameter groups. optional group goes last
paramGroups.sort((a: GroupProperty, b: GroupProperty) => {
if (a.required === b.required) {
return 0;
}
if (a.required && !b.required) {
return -1;
}
return 1;
})
for (const paramGroup of values(paramGroups)) {
let name = camelCase(paramGroup.language.go!.name);
if (!paramGroup.required) {
name = 'options';
}
params.push(paramGroup);
}
return params;
}
110 changes: 34 additions & 76 deletions src/generator/operations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { ArraySchema, ByteArraySchema, CodeModel, ConstantSchema, DateTimeSchema
import { values } from '@azure-tools/linq';
import { aggregateParameters, isArraySchema, isPageableOperation, isSchemaResponse, PagerInfo, isLROOperation } from '../common/helpers';
import { OperationNaming } from '../transform/namer';
import { contentPreamble, formatParameterTypeName, hasDescription, skipURLEncoding, sortAscending, sortParametersByRequired } from './helpers';
import { contentPreamble, formatParameterTypeName, hasDescription, skipURLEncoding, sortAscending, sortParametersByRequired, getCreateRequestParametersSig, getMethodParameters } from './helpers';
import { ImportManager } from './imports';

const dateFormat = '2006-01-02';
Expand Down Expand Up @@ -292,14 +292,7 @@ function generateOperation(clientName: string, op: Operation, imports: ImportMan
for (let i = 0; i < reqParams.length; ++i) {
reqParams[i] = reqParams[i].trim().split(' ')[0];
}
// TODO Exception for Pageable LRO operations NYI
if (isLROOperation(op)) {
// TODO remove LRO for pageable responses NYI
if (op.extensions!['x-ms-pageable']) {
text += `\treturn nil, nil\n`;
text += '}\n\n';
return text;
}
imports.add('time');
text += `\treq, err := client.${info.protocolNaming.requestMethod}(${reqParams.join(', ')})\n`;
text += `\tif err != nil {\n`;
Expand Down Expand Up @@ -332,15 +325,20 @@ function generateOperation(clientName: string, op: Operation, imports: ImportMan
text += '\t}\n';
text += `\tpoller := &${camelCase(op.language.go!.pollerType.name)}{\n`;
text += '\t\t\tpt: pt,\n';
if (isPageableOperation(op)) {
text += `\t\t\trespHandler: client.${camelCase(op.language.go!.pageableType.name)}HandleResponse,\n`;
}
text += '\t\t\tpipeline: client.p,\n';
text += '\t}\n';
text += '\tresult.Poller = poller\n';
// http pollers will simply return an *http.Response
if (op.language.go!.pollerType.name === 'HTTPPoller') {
text += '\tresult.PollUntilDone = func(ctx context.Context, frequency time.Duration) (*http.Response, error) {\n';
} else {
text += `\tresult.PollUntilDone = func(ctx context.Context, frequency time.Duration)(*${(<SchemaResponse>op.responses![0]).schema.language.go!.responseType.name}, error) {\n`;
// determine the poller response based on the name and whether is is a pageable operation
let pollerResponse = '*http.Response';
if (isPageableOperation(op)) {
pollerResponse = op.language.go!.pageableType.name;
} else if (isSchemaResponse(op.responses![0])) {
pollerResponse = '*' + (<SchemaResponse>op.responses![0]).schema.language.go!.responseType.name;
}
text += `\tresult.PollUntilDone = func(ctx context.Context, frequency time.Duration) (${pollerResponse}, error) {\n`;
text += `\t\treturn poller.pollUntilDone(ctx, frequency)\n`;
text += `\t}\n`;
text += `\treturn result, nil\n`;
Expand Down Expand Up @@ -653,9 +651,8 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa
text += '}\n\n';
return text;
}
const generateResponseUnmarshaller = function (response: Response): string {
const generateResponseUnmarshaller = function (response: Response, isLRO: boolean): string {
let unmarshallerText = '';
const isLRO = isLROOperation(op);
if (!isSchemaResponse(response)) {
if (isLRO) {
unmarshallerText += '\treturn &HTTPPollerResponse{RawResponse: resp.Response}, nil\n';
Expand Down Expand Up @@ -698,8 +695,7 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa
return unmarshallerText;
}
const schemaResponse = <SchemaResponse>response;
// TODO remove paging skip when adding lro + pagers
if (isLRO && !isPageableOperation(op)) {
if (isLRO) {
unmarshallerText += `\treturn &${schemaResponse.schema.language.go!.lroResponseType.language.go!.name}{RawResponse: resp.Response}, nil\n`;
return unmarshallerText;
}
Expand Down Expand Up @@ -739,12 +735,28 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa
text += `\tif !resp.HasStatusCode(${formatStatusCodes(statusCodes)}) {\n`;
text += `\t\treturn nil, client.${info.protocolNaming.errorMethod}(resp)\n`;
text += '\t}\n';
text += generateResponseUnmarshaller(op.responses![0]);
if (isLROOperation(op) && isPageableOperation(op)) {
text += generateResponseUnmarshaller(op.responses![0], true);
text += '}\n\n';
text += `${comment(name, '// ')} handles the ${info.name} response.\n`;
text += `func (client *${client}) ${camelCase(op.language.go!.pageableType.name)}HandleResponse(resp *azcore.Response) (*${(<SchemaResponse>op.responses![0]).schema.language.go!.responseType.value}Response, error) {\n`;
const index = statusCodes.indexOf('204');
if (index > -1) {
statusCodes.splice(index, 1);
}
statusCodes.push('200');
text += `\tif !resp.HasStatusCode(${formatStatusCodes(statusCodes)}) {\n`;
text += `\t\treturn nil, client.${info.protocolNaming.errorMethod}(resp)\n`;
text += '\t}\n';
text += generateResponseUnmarshaller(op.responses![0], false);
} else {
text += generateResponseUnmarshaller(op.responses![0], isLROOperation(op));
}
} else {
text += '\tswitch resp.StatusCode {\n';
for (const response of values(op.responses)) {
text += `\tcase ${formatStatusCodes(response.protocol.http!.statusCodes)}:\n`
text += generateResponseUnmarshaller(response);
text += generateResponseUnmarshaller(response, isLROOperation(op));
}
text += '\tdefault:\n';
text += `\t\treturn nil, client.${info.protocolNaming.errorMethod}(resp)\n`;
Expand Down Expand Up @@ -876,7 +888,7 @@ function createInterfaceDefinition(group: OperationGroup, imports: ImportManager
const returns = generateReturnsInfo(op, false);
interfaceText += `\t${opName}(${getAPIParametersSig(op, imports)}) (${returns.join(', ')})\n`;
// Add resume LRO poller method for each Begin poller method
if (isLROOperation(op) && !op.extensions!['x-ms-pageable']) {
if (isLROOperation(op)) {
interfaceText += `\t// Resume${op.language.go!.name} - Used to create a new instance of this poller from the resume token of a previous instance of this poller type.\n`;
interfaceText += `\tResume${op.language.go!.name}(token string) (${op.language.go!.pollerType.name}, error)\n`;
}
Expand Down Expand Up @@ -971,7 +983,7 @@ function hasBinaryResponse(responses: Response[]): boolean {
function getAPIParametersSig(op: Operation, imports: ImportManager): string {
const methodParams = getMethodParameters(op);
const params = new Array<string>();
if (!isPageableOperation(op)) {
if (!isPageableOperation(op) || isLROOperation(op)) {
imports.add('context');
params.push('ctx context.Context');
}
Expand All @@ -981,59 +993,6 @@ function getAPIParametersSig(op: Operation, imports: ImportManager): string {
return params.join(', ');
}

// returns the parameters for the internal request creator method.
// e.g. "i int, s string"
function getCreateRequestParametersSig(op: Operation): string {
const methodParams = getMethodParameters(op);
const params = new Array<string>();
for (const methodParam of values(methodParams)) {
params.push(`${camelCase(methodParam.language.go!.name)} ${formatParameterTypeName(methodParam)}`);
}
return params.join(', ');
}

// returns the complete collection of method parameters
function getMethodParameters(op: Operation): Parameter[] {
const params = new Array<Parameter>();
const paramGroups = new Array<GroupProperty>();
for (const param of values(aggregateParameters(op))) {
if (param.implementation === ImplementationLocation.Client) {
// client params are passed via the receiver
continue;
} else if (param.schema.type === SchemaType.Constant) {
// don't generate a parameter for a constant
continue;
} else if (param.language.go!.paramGroup) {
// param groups will be added after individual params
if (!paramGroups.includes(param.language.go!.paramGroup)) {
paramGroups.push(param.language.go!.paramGroup);
}
continue;
}
params.push(param);
}
// move global optional params to the end of the slice
params.sort(sortParametersByRequired);
// add any parameter groups. optional group goes last
paramGroups.sort((a: GroupProperty, b: GroupProperty) => {
if (a.required === b.required) {
return 0;
}
if (a.required && !b.required) {
return -1;
}
return 1;
})
for (const paramGroup of values(paramGroups)) {
let name = camelCase(paramGroup.language.go!.name);
if (!paramGroup.required) {
name = 'options';
}
params.push(paramGroup);
}
return params;
}

// returns the return signature where each entry is the type name
// e.g. [ '*string', 'error' ]
function generateReturnsInfo(op: Operation, forHandler: boolean): string[] {
Expand All @@ -1052,8 +1011,7 @@ function generateReturnsInfo(op: Operation, forHandler: boolean): string[] {
returnType = op.language.go!.pageableType.name;
} else if (isSchemaResponse(firstResp)) {
returnType = '*' + firstResp.schema.language.go!.responseType.name;
// TODO remove paging skip when adding LRO + pagers
if (isLROOperation(op) && !isPageableOperation(op)) {
if (isLROOperation(op)) {
returnType = '*' + firstResp.schema.language.go!.lroResponseType.language.go!.name;
}
} else if (isLROOperation(op)) {
Expand Down
21 changes: 18 additions & 3 deletions src/generator/pagers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ export async function generatePagers(session: Session<CodeModel>): Promise<strin
const responseType = schemaResponse.schema.language.go!.responseType.name;
const resultType = schemaResponse.schema.language.go!.name;
let resultTypeName = resultType;
let pollerRespField = '';
let respFieldCheck = '\tresp, err := p.pipeline.Do(ctx, p.request)';
if (pager.respField) {
pollerRespField = `
// previous response from the endpoint
resp *azcore.Response`;
respFieldCheck =
`resp := p.resp
var err error
if resp == nil {
resp, err = p.pipeline.Do(ctx, p.request)
} else {
p.resp = nil
}`;
}
if (schemaResponse.schema.serialization?.xml?.name) {
// xml can specifiy its own name, prefer that if available
resultTypeName = schemaResponse.schema.serialization.xml.name;
Expand Down Expand Up @@ -67,7 +82,7 @@ type ${pagerType} struct {
// contains the current response
current *${responseType}
// any error encountered
err error
err error${pollerRespField}
}
func (p *${pagerType}) Err() error {
Expand All @@ -85,8 +100,8 @@ func (p *${pagerType}) NextPage(ctx context.Context) bool {
return false
}
p.request = req
}
resp, err := p.pipeline.Do(ctx, p.request)
}
${respFieldCheck}
if err != nil {
p.err = err
return false
Expand Down
Loading

0 comments on commit 7f7a951

Please sign in to comment.