From b5fe01f5814b04586e649f6f5153b87a7e425041 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 19 May 2020 08:11:40 -0700 Subject: [PATCH] Refactor discriminated type generation (#397) * Refactor discriminated type generation Use composition for discriminated types hierarchy. Change marker method to an exported GetType method used to retrieve the base-type content. Removed internal enums for discriminator values. Split marshallers and unmarshallers for discriminated types separate from counterparts for time and XML wrappers. * Fix time.Time marshalling for discriminated types * throw an error if failed to find discriminated type * early exit for unmarshaller corner case --- src/generator/helpers.ts | 5 +- src/generator/models.ts | 478 +++++++++------ src/generator/operations.ts | 2 +- src/generator/polymorphics.ts | 19 +- src/transform/namer.ts | 2 +- src/transform/transform.ts | 109 ++-- .../autorest/complexgroup/inheritance_test.go | 68 ++- .../complexgroup/polymorphicrecursive_test.go | 224 ++++--- .../complexgroup/polymorphism_test.go | 303 ++++++---- .../autorest/generated/complexgroup/models.go | 555 +++++++----------- .../complexgroup/polymorphic_helpers.go | 41 +- .../httpinfrastructuregroup/models.go | 2 +- test/autorest/generated/lrogroup/models.go | 18 +- 13 files changed, 930 insertions(+), 896 deletions(-) diff --git a/src/generator/helpers.ts b/src/generator/helpers.ts index a94b27259..62a40de90 100644 --- a/src/generator/helpers.ts +++ b/src/generator/helpers.ts @@ -7,7 +7,6 @@ import { Session } from '@azure-tools/autorest-extension-base'; import { comment } from '@azure-tools/codegen'; import { ArraySchema, CodeModel, DictionarySchema, Language, Parameter, Schema, SchemaType } from '@azure-tools/codemodel'; - // returns the common source-file preamble (license comment, package name etc) export async function contentPreamble(session: Session): Promise { const headerText = comment(await session.getValue('header-text', 'MISSING LICENSE HEADER'), '// '); @@ -70,8 +69,8 @@ export function substituteDiscriminator(schema: Schema): string { const dictElem = dictSchema.elementType; return `map[string]${substituteDiscriminator(dictElem)}`; case SchemaType.Object: - if (schema.language.go!.discriminator) { - return schema.language.go!.discriminator; + if (schema.language.go!.discriminatorInterface) { + return schema.language.go!.discriminatorInterface; } return schema.language.go!.name; default: diff --git a/src/generator/models.ts b/src/generator/models.ts index 5c06de8e3..57bc48682 100644 --- a/src/generator/models.ts +++ b/src/generator/models.ts @@ -19,7 +19,9 @@ export async function generateModels(session: Session): Promise>session.model.language.go!.responseSchemas; for (const schema of values(responseSchemas)) { - structs.push(generateStruct(schema.language.go!.responseType, schema.language.go!.properties)); + const respType = generateStruct(schema.language.go!.responseType, schema.language.go!.properties); + generateUnmarshallerForResponseEnvelope(respType); + structs.push(respType); } const paramGroups = >session.model.language.go!.parameterGroups; for (const paramGroup of values(paramGroups)) { @@ -36,10 +38,9 @@ export async function generateModels(session: Session): Promise { return sortAscending(a.name, b.name) }); for (const method of values(struct.Methods)) { - text += method; + text += method.text; } } return text; @@ -48,12 +49,18 @@ export async function generateModels(session: Session): Promise { return sortAscending(a.language.go!.name, b.language.go!.name); }); } - this.Methods = new Array(); + this.Methods = new Array(); + this.ComposedOf = new Array(); } text(): string { @@ -74,6 +82,10 @@ class StructDef { text += `${comment(this.Language.description, '// ')}\n`; } text += `type ${this.Language.name} struct {\n`; + // any composed types go first + for (const comp of values(this.ComposedOf)) { + text += `\t${comp.language.go!.name}\n`; + } // used to track when to add an extra \n between fields that have comments let first = true; for (const prop of values(this.Properties)) { @@ -144,7 +156,7 @@ class StructDef { tag = ''; } let pointer = '*'; - if (prop.schema.language.go!.discriminator) { + if (prop.schema.language.go!.discriminatorInterface) { // pointer-to-interface introduces very clunky code pointer = ''; } @@ -162,7 +174,7 @@ class StructDef { text += `\t${comment(param.language.go!.description, '// ')}\n`; } let pointer = '*'; - if (param.required || param.schema.language.go!.discriminator) { + if (param.required || param.schema.language.go!.discriminatorInterface) { // pointer-to-interface introduces very clunky code pointer = ''; } @@ -172,185 +184,19 @@ class StructDef { return text; } - // creates a custom marshaller for this type - marshaller(): string { - // only needed for discriminated types, types with time.Time or where the XML name doesn't match the type name - if (this.Language.needsDateTimeMarshalling === undefined && this.Language.xmlWrapperName === undefined && - this.Language.discriminatorEnum === undefined) { - return ''; - } - const receiver = this.Language.name[0].toLowerCase(); - let formatSig = 'JSON() ([]byte, error)'; - if (this.Language.marshallingFormat === 'xml') { - formatSig = 'XML(e *xml.Encoder, start xml.StartElement) error' - } - let text = `func (${receiver} ${this.Language.name}) Marshal${formatSig} {\n`; - if (this.Language.xmlWrapperName) { - text += `\tstart.Name.Local = "${this.Language.xmlWrapperName}"\n`; - } else if (this.Language.discriminatorEnum) { - // find the discriminator property - for (const prop of values(this.Properties)) { - if (prop.isDiscriminator) { - text += `\t${receiver}.${prop.language.go!.name} = `; - if (this.Language.discriminatorRealEnum) { - text += `${this.Language.discriminatorEnum}.ToPtr()\n`; - } else { - text += `strptr(${this.Language.discriminatorEnum})\n`; - } - break; - } - } - } - text += this.generateAliasType(receiver, true); - if (this.Language.marshallingFormat === 'json') { - text += '\treturn json.Marshal(aux)\n'; - } else { - text += '\treturn e.EncodeElement(aux, start)\n'; - } - text += '}\n\n'; - return text; - } - - // creates a custom unmarshaller for this type - unmarshaller(): string { - // only needed for discriminated types, types containing discriminated types, or types with time.Time - const hasPolymorphicField = values(this.Properties).first((each: Property) => { - if (isObjectSchema(each.schema)) { - return each.schema.discriminator !== undefined; - } - return false; - }); - if (this.Language.discriminatorEnum === undefined && !hasPolymorphicField && this.Language.needsDateTimeMarshalling === undefined) { - return ''; - } - const receiver = this.Language.name[0].toLowerCase(); - let formatSig = 'JSON(data []byte)'; - if (this.Language.marshallingFormat === 'xml') { - formatSig = 'XML(d *xml.Decoder, start xml.StartElement)'; - } - let text = `func (${receiver} *${this.Language.name}) Unmarshal${formatSig} error {\n`; - if (this.Language.discriminatorEnum || hasPolymorphicField) { - if (this.Language.responseType === true) { - // add a custom unmarshaller to the response envelope - // find the discriminated type field - let field = 'FIND'; - let type = 'FIND'; - for (const prop of values(this.Properties)) { - if (prop.isDiscriminator) { - field = prop.language.go!.name; - type = prop.schema.language.go!.discriminator; - break; - } - } - text += `\tt, err := unmarshal${type}(data)\n`; - text += '\tif err != nil {\n'; - text += '\t\treturn err\n'; - text += '\t}\n'; - text += `\t${receiver}.${field} = t\n`; - } else { - // polymorphic type, or type containing a polymorphic type - text += '\tvar rawMsg map[string]*json.RawMessage\n'; - text += '\tif err := json.Unmarshal(data, &rawMsg); err != nil {\n'; - text += '\t\treturn err\n'; - text += '\t}\n'; - text += '\tfor k, v := range rawMsg {\n'; - text += '\t\tvar err error\n'; - text += '\t\tswitch k {\n'; - // unmarshal each field one by one - for (const prop of values(this.Properties)) { - text += `\t\tcase "${prop.serializedName}":\n`; - text += '\t\t\tif v != nil {\n'; - if (prop.schema.language.go!.discriminator) { - text += `\t\t\t\t${receiver}.${prop.language.go!.name}, err = unmarshal${prop.schema.language.go!.discriminator}(*v)\n`; - } else if (isArraySchema(prop.schema) && prop.schema.elementType.language.go!.discriminator) { - text += `\t\t\t\t${receiver}.${prop.language.go!.name}, err = unmarshal${prop.schema.elementType.language.go!.discriminator}Array(*v)\n`; - } else if (prop.schema.language.go!.internalTimeType) { - text += `\t\t\t\tvar aux ${prop.schema.language.go!.internalTimeType}\n`; - text += '\t\t\t\terr = json.Unmarshal(*v, &aux)\n'; - text += `\t\t\t\t${receiver}.${prop.language.go!.name} = (*time.Time)(&aux)\n`; - } else { - text += `\t\t\t\terr = json.Unmarshal(*v, &${receiver}.${prop.language.go!.name})\n`; - } - text += '\t\t\t}\n'; - } - text += '\t\t}\n'; - text += '\t\tif err != nil {\n'; - text += '\t\t\treturn err\n'; - text += '\t\t}\n'; - text += '\t}\n'; - } - } else { - // non-polymorphic case, must be something with time.Time - text += this.generateAliasType(receiver, false); - if (this.Language.marshallingFormat === 'json') { - text += '\tif err := json.Unmarshal(data, aux); err != nil {\n'; - text += '\t\treturn err\n'; - text += '\t}\n'; - } else { - text += '\tif err := d.DecodeElement(aux, &start); err != nil {\n'; - text += '\t\treturn err\n'; - text += '\t}\n'; - } - for (const prop of values(this.Properties)) { - if (prop.schema.type !== SchemaType.DateTime) { - continue; - } - text += `\t${receiver}.${prop.language.go!.name} = (*time.Time)(aux.${prop.language.go!.name})\n`; - } - } - text += '\treturn nil\n'; - text += '}\n\n'; - return text; - } - discriminator(): string { - if (!this.Language.discriminator) { + if (!this.Language.discriminatorInterface) { return ''; } - let text = `// ${this.Language.discriminator} provides polymorphic access to related types.\n`; - text += `type ${this.Language.discriminator} interface {\n`; + let text = `// ${this.Language.discriminatorInterface} provides polymorphic access to related types.\n`; + text += `type ${this.Language.discriminatorInterface} interface {\n`; if (this.Language.discriminatorParent) { text += `\t${this.Language.discriminatorParent}\n`; } - text += `\t${camelCase(this.Language.discriminator)}()\n`; + text += `\tGet${this.Language.name}() *${this.Language.name}\n`; text += '}\n\n'; return text; } - - // generates an alias type used by custom marshaller/unmarshaller - private generateAliasType(receiver: string, forMarshal: boolean): string { - let text = `\ttype alias ${this.Language.name}\n`; - text += `\taux := &struct {\n`; - text += `\t\t*alias\n`; - for (const prop of values(this.Properties)) { - if (prop.schema.type !== SchemaType.DateTime) { - continue; - } - let sn = prop.serializedName; - if (prop.schema.serialization?.xml?.name) { - // xml can specifiy its own name, prefer that if available - sn = prop.schema.serialization.xml.name; - } - text += `\t\t${prop.language.go!.name} *${prop.schema.language.go!.internalTimeType} \`${this.Language.marshallingFormat}:"${sn}"\`\n`; - } - text += `\t}{\n`; - let rec = receiver; - if (forMarshal) { - rec = '&' + rec; - } - text += `\t\talias: (*alias)(${rec}),\n`; - if (forMarshal) { - // emit code to initialize time fields - for (const prop of values(this.Properties)) { - if (prop.schema.type !== SchemaType.DateTime) { - continue; - } - text += `\t\t${prop.language.go!.name}: (*${prop.schema.language.go!.internalTimeType})(${receiver}.${prop.language.go!.name}),\n`; - } - } - text += `\t}\n`; - return text; - } } function generateStructs(objects?: ObjectSchema[]): StructDef[] { @@ -361,15 +207,21 @@ function generateStructs(objects?: ObjectSchema[]): StructDef[] { for (const prop of values(obj.properties)) { props.push(prop); } - // now add all parent properties - for (const parent of values(obj.parents?.all)) { + const structDef = generateStruct(obj.language.go!, props); + // now add the parent type + let parentType: ObjectSchema | undefined; + for (const parent of values(obj.parents?.immediate)) { if (isObjectSchema(parent)) { - for (const prop of values(parent.properties)) { - props.push(prop); - } + parentType = parent; + structDef.ComposedOf.push(parent); } } - const structDef = generateStruct(obj.language.go!, props); + const hasPolymorphicField = values(obj.properties).first((each: Property) => { + if (isObjectSchema(each.schema)) { + return each.schema.discriminator !== undefined; + } + return false; + }); if (obj.language.go!.errorType) { // add Error() method let text = `func (e ${obj.language.go!.name}) Error() string {\n`; @@ -384,16 +236,28 @@ function generateStructs(objects?: ObjectSchema[]): StructDef[] { text += '\t}\n'; text += '\treturn msg\n'; text += '}\n\n'; - structDef.Methods.push(text); - } else if (obj.language.go!.polymorphicInterfaces) { - // generate interface method(s) - const interfaces = >obj.language.go!.polymorphicInterfaces; - interfaces.sort(sortAscending); - for (const iface of values(interfaces)) { - const marker = `func (*${obj.language.go!.name}) ${camelCase(iface)}() {}\n\n`; - structDef.Methods.push(marker); + structDef.Methods.push({ name: 'Error', text: text }); + } else if (obj.discriminator) { + // only need to generate interface method and internal marshaller for discriminators (Fish, Salmon, Shark) + generateDiscriminatorMethods(obj, structDef, parentType!); + // the root type doesn't get a marshaller as callers don't instantiate instances of it + if (!obj.language.go!.rootDiscriminator) { + generateDiscriminatedTypeMarshaller(obj, structDef, parentType!); + } + generateDiscriminatedTypeUnmarshaller(obj, structDef, parentType!); + } else if (obj.discriminatorValue) { + generateDiscriminatedTypeMarshaller(obj, structDef, parentType!); + generateDiscriminatedTypeUnmarshaller(obj, structDef, parentType!); + } else if (hasPolymorphicField) { + generateDiscriminatedTypeUnmarshaller(obj, structDef, parentType!); + } else if (obj.language.go!.needsDateTimeMarshalling || obj.language.go!.xmlWrapperName) { + // TODO: unify marshalling schemes? + generateMarshaller(structDef); + if (obj.language.go!.needsDateTimeMarshalling) { + generateUnmarshaller(structDef); } } + structDef.ComposedOf.sort((a: ObjectSchema, b: ObjectSchema) => { return sortAscending(a.language.go!.name, b.language.go!.name); }); structTypes.push(structDef); } return structTypes; @@ -423,3 +287,233 @@ function generateParamGroupStruct(lang: Language, params: Parameter[]): StructDe } return st; } + +function generateUnmarshallerForResponseEnvelope(structDef: StructDef) { + // if the response envelope contains a discriminated type we need an unmarshaller + let found = false; + for (const prop of values(structDef.Properties)) { + if (prop.isDiscriminator) { + found = true; + break; + } + } + if (!found) { + return; + } + const receiver = structDef.Language.name[0].toLowerCase(); + let unmarshaller = `func (${receiver} *${structDef.Language.name}) UnmarshalJSON(data []byte) error {\n`; + // add a custom unmarshaller to the response envelope + // find the discriminated type field + let field = ''; + let type = ''; + for (const prop of values(structDef.Properties)) { + if (prop.isDiscriminator) { + field = prop.language.go!.name; + type = prop.schema.language.go!.discriminatorInterface; + break; + } + } + if (field === '' || type === '') { + throw console.error(`failed to the discriminated type field for response envelope ${structDef.Language.name}`); + } + unmarshaller += `\tt, err := unmarshal${type}(data)\n`; + unmarshaller += '\tif err != nil {\n'; + unmarshaller += '\t\treturn err\n'; + unmarshaller += '\t}\n'; + unmarshaller += `\t${receiver}.${field} = t\n`; + unmarshaller += '\treturn nil\n'; + unmarshaller += '}\n\n'; + structDef.Methods.push({ name: 'UnmarshalJSON', text: unmarshaller }); +} + +function generateDiscriminatorMethods(obj: ObjectSchema, structDef: StructDef, parentType: ObjectSchema) { + const typeName = obj.language.go!.name; + const receiver = typeName[0].toLowerCase(); + // generate interface method + const interfaceMethod = `Get${typeName}`; + const method = `func (${receiver} *${typeName}) ${interfaceMethod}() *${typeName} { return ${receiver} }\n\n`; + structDef.Methods.push({ name: interfaceMethod, text: method }); + // generate internal marshaller method + const paramType = obj.discriminator!.property.schema.language.go!.name; + const paramName = 'discValue'; + let marshalInteral = `func (${receiver} ${typeName}) marshalInternal(${paramName} ${paramType}) map[string]interface{} {\n`; + if (parentType) { + marshalInteral += `\tobjectMap := ${receiver}.${parentType.language.go!.name}.marshalInternal(${paramName})\n`; + } else { + marshalInteral += '\tobjectMap := make(map[string]interface{})\n'; + } + for (const prop of values(structDef.Properties)) { + if (prop.isDiscriminator) { + marshalInteral += `\t${receiver}.${prop.language.go!.name} = &${paramName}\n`; + marshalInteral += `\tobjectMap["${prop.serializedName}"] = ${receiver}.${prop.language.go!.name}\n`; + } else { + marshalInteral += `\tif ${receiver}.${prop.language.go!.name} != nil {\n`; + if (prop.schema.language.go!.internalTimeType) { + marshalInteral += `\t\tobjectMap["${prop.serializedName}"] = (*${prop.schema.language.go!.internalTimeType})(${receiver}.${prop.language.go!.name})\n`; + } else { + marshalInteral += `\t\tobjectMap["${prop.serializedName}"] = ${receiver}.${prop.language.go!.name}\n`; + } + marshalInteral += `\t}\n`; + } + } + marshalInteral += '\treturn objectMap\n'; + marshalInteral += '}\n\n'; + structDef.Methods.push({ name: 'marshalInternal', text: marshalInteral }); +} + +function generateDiscriminatedTypeMarshaller(obj: ObjectSchema, structDef: StructDef, parentType: ObjectSchema) { + const typeName = structDef.Language.name; + const receiver = typeName[0].toLowerCase(); + // generate marshaller method + let marshaller = `func (${receiver} ${typeName}) MarshalJSON() ([]byte, error) {\n`; + marshaller += `\tobjectMap := ${receiver}.${parentType!.language.go!.name}.marshalInternal(${obj.discriminatorValue})\n`; + for (const prop of values(structDef.Properties)) { + marshaller += `\tif ${receiver}.${prop.language.go!.name} != nil {\n`; + if (prop.schema.language.go!.internalTimeType) { + marshaller += `\t\tobjectMap["${prop.serializedName}"] = (*${prop.schema.language.go!.internalTimeType})(${receiver}.${prop.language.go!.name})\n`; + } else { + marshaller += `\t\tobjectMap["${prop.serializedName}"] = ${receiver}.${prop.language.go!.name}\n`; + } + marshaller += `\t}\n`; + } + marshaller += '\treturn json.Marshal(objectMap)\n'; + marshaller += '}\n\n'; + structDef.Methods.push({ name: 'MarshalJSON', text: marshaller }); +} + +function generateDiscriminatedTypeUnmarshaller(obj: ObjectSchema, structDef: StructDef, parentType?: ObjectSchema) { + // there's a corner-case where a derived type might not add any new fields (Cookiecuttershark). + // in this case skip adding the unmarshaller as it's not necessary and doesn't compile. + if (!structDef.Properties || structDef.Properties.length === 0) { + return; + } + const typeName = structDef.Language.name; + const receiver = typeName[0].toLowerCase(); + let unmarshaller = `func (${receiver} *${typeName}) UnmarshalJSON(data []byte) error {\n`; + // polymorphic type, or type containing a polymorphic type + unmarshaller += '\tvar rawMsg map[string]*json.RawMessage\n'; + unmarshaller += '\tif err := json.Unmarshal(data, &rawMsg); err != nil {\n'; + unmarshaller += '\t\treturn err\n'; + unmarshaller += '\t}\n'; + unmarshaller += '\tfor k, v := range rawMsg {\n'; + unmarshaller += '\t\tvar err error\n'; + unmarshaller += '\t\tswitch k {\n'; + // unmarshal each field one by one + for (const prop of values(structDef.Properties)) { + unmarshaller += `\t\tcase "${prop.serializedName}":\n`; + unmarshaller += '\t\t\tif v != nil {\n'; + if (prop.schema.language.go!.discriminatorInterface) { + unmarshaller += `\t\t\t\t${receiver}.${prop.language.go!.name}, err = unmarshal${prop.schema.language.go!.discriminatorInterface}(*v)\n`; + } else if (isArraySchema(prop.schema) && prop.schema.elementType.language.go!.discriminatorInterface) { + unmarshaller += `\t\t\t\t${receiver}.${prop.language.go!.name}, err = unmarshal${prop.schema.elementType.language.go!.discriminatorInterface}Array(*v)\n`; + } else if (prop.schema.language.go!.internalTimeType) { + unmarshaller += `\t\t\t\tvar aux ${prop.schema.language.go!.internalTimeType}\n`; + unmarshaller += '\t\t\t\terr = json.Unmarshal(*v, &aux)\n'; + unmarshaller += `\t\t\t\t${receiver}.${prop.language.go!.name} = (*time.Time)(&aux)\n`; + } else { + unmarshaller += `\t\t\t\terr = json.Unmarshal(*v, &${receiver}.${prop.language.go!.name})\n`; + } + unmarshaller += '\t\t\t}\n'; + } + unmarshaller += '\t\t}\n'; + unmarshaller += '\t\tif err != nil {\n'; + unmarshaller += '\t\t\treturn err\n'; + unmarshaller += '\t\t}\n'; + unmarshaller += '\t}\n'; + if (!obj.language.go!.rootDiscriminator && parentType) { + unmarshaller += `\treturn json.Unmarshal(data, &${receiver}.${parentType.language.go!.name})\n`; + } else { + unmarshaller += '\treturn nil\n'; + } + unmarshaller += '}\n\n'; + structDef.Methods.push({ name: 'UnmarshalJSON', text: unmarshaller }); +} + +function generateMarshaller(structDef: StructDef) { + // only needed for types with time.Time or where the XML name doesn't match the type name + const receiver = structDef.Language.name[0].toLowerCase(); + let formatSig = 'JSON() ([]byte, error)'; + let methodName = 'MarshalJSON'; + if (structDef.Language.marshallingFormat === 'xml') { + formatSig = 'XML(e *xml.Encoder, start xml.StartElement) error' + methodName = 'MarshalXML'; + } + let text = `func (${receiver} ${structDef.Language.name}) Marshal${formatSig} {\n`; + if (structDef.Language.xmlWrapperName) { + text += `\tstart.Name.Local = "${structDef.Language.xmlWrapperName}"\n`; + } + text += generateAliasType(structDef, receiver, true); + if (structDef.Language.marshallingFormat === 'json') { + text += '\treturn json.Marshal(aux)\n'; + } else { + text += '\treturn e.EncodeElement(aux, start)\n'; + } + text += '}\n\n'; + structDef.Methods.push({ name: methodName, text: text }); +} + +function generateUnmarshaller(structDef: StructDef) { + // non-polymorphic case, must be something with time.Time + const receiver = structDef.Language.name[0].toLowerCase(); + let formatSig = 'JSON(data []byte)'; + let methodName = 'UnmarshalJSON'; + if (structDef.Language.marshallingFormat === 'xml') { + formatSig = 'XML(d *xml.Decoder, start xml.StartElement)'; + methodName = 'UnmarshalXML'; + } + let text = `func (${receiver} *${structDef.Language.name}) Unmarshal${formatSig} error {\n`; + text += generateAliasType(structDef, receiver, false); + if (structDef.Language.marshallingFormat === 'json') { + text += '\tif err := json.Unmarshal(data, aux); err != nil {\n'; + text += '\t\treturn err\n'; + text += '\t}\n'; + } else { + text += '\tif err := d.DecodeElement(aux, &start); err != nil {\n'; + text += '\t\treturn err\n'; + text += '\t}\n'; + } + for (const prop of values(structDef.Properties)) { + if (prop.schema.type !== SchemaType.DateTime) { + continue; + } + text += `\t${receiver}.${prop.language.go!.name} = (*time.Time)(aux.${prop.language.go!.name})\n`; + } + text += '\treturn nil\n'; + text += '}\n\n'; + structDef.Methods.push({ name: methodName, text: text }); +} + +// generates an alias type used by custom marshaller/unmarshaller +function generateAliasType(structDef: StructDef, receiver: string, forMarshal: boolean): string { + let text = `\ttype alias ${structDef.Language.name}\n`; + text += `\taux := &struct {\n`; + text += `\t\t*alias\n`; + for (const prop of values(structDef.Properties)) { + if (prop.schema.type !== SchemaType.DateTime) { + continue; + } + let sn = prop.serializedName; + if (prop.schema.serialization?.xml?.name) { + // xml can specifiy its own name, prefer that if available + sn = prop.schema.serialization.xml.name; + } + text += `\t\t${prop.language.go!.name} *${prop.schema.language.go!.internalTimeType} \`${structDef.Language.marshallingFormat}:"${sn}"\`\n`; + } + text += `\t}{\n`; + let rec = receiver; + if (forMarshal) { + rec = '&' + rec; + } + text += `\t\talias: (*alias)(${rec}),\n`; + if (forMarshal) { + // emit code to initialize time fields + for (const prop of values(structDef.Properties)) { + if (prop.schema.type !== SchemaType.DateTime) { + continue; + } + text += `\t\t${prop.language.go!.name}: (*${prop.schema.language.go!.internalTimeType})(${receiver}.${prop.language.go!.name}),\n`; + } + } + text += `\t}\n`; + return text; +} diff --git a/src/generator/operations.ts b/src/generator/operations.ts index 334d0b348..9a3baca2a 100644 --- a/src/generator/operations.ts +++ b/src/generator/operations.ts @@ -673,7 +673,7 @@ function createProtocolResponse(client: string, op: Operation, imports: ImportMa } let target = `result.${schemaResponse.schema.language.go!.responseType.value}`; // when unmarshalling a wrapped XML array or discriminated type, unmarshal into the response type, not the field - if ((mediaType === 'XML' && schemaResponse.schema.type === SchemaType.Array) || schemaResponse.schema.language.go!.discriminator) { + if ((mediaType === 'XML' && schemaResponse.schema.type === SchemaType.Array) || schemaResponse.schema.language.go!.discriminatorInterface) { target = 'result'; } text += `\treturn &result, resp.UnmarshalAs${mediaType}(&${target})\n`; diff --git a/src/generator/polymorphics.ts b/src/generator/polymorphics.ts index 71000494a..96f0a11f9 100644 --- a/src/generator/polymorphics.ts +++ b/src/generator/polymorphics.ts @@ -4,7 +4,6 @@ *--------------------------------------------------------------------------------------------*/ import { Session } from '@azure-tools/autorest-extension-base'; -import { camelCase, pascalCase } from '@azure-tools/codegen'; import { CodeModel, ObjectSchema } from '@azure-tools/codemodel'; import { values } from '@azure-tools/linq'; import { contentPreamble, sortAscending } from './helpers'; @@ -21,24 +20,12 @@ export async function generatePolymorphicHelpers(session: Session): P imports.add('encoding/json'); text += imports.text(); const discriminators = >session.model.language.go!.discriminators; - discriminators.sort((a: ObjectSchema, b: ObjectSchema) => { return sortAscending(a.language.go!.discriminator, b.language.go!.discriminator) }); + discriminators.sort((a: ObjectSchema, b: ObjectSchema) => { return sortAscending(a.language.go!.discriminatorInterface, b.language.go!.discriminatorInterface) }); for (const disc of values(discriminators)) { // this is used to track any sub-hierarchies (SalmonType, SharkType in the test server) const roots = new Array(); roots.push(disc); - if (disc.language.go!.discriminatorEnumNeeded) { - // constant definition - // only generate one set from the root as it contains all possible values - text += 'const (\n'; - // TODO: sort - for (const val of values(disc.discriminator!.all)) { - const objSchema = val; - text += `\t${objSchema.language.go!.discriminatorEnum} = "${objSchema.discriminatorValue!}"\n`; - } - text += ')\n\n'; - } - // add sub-hierarchies for (const val of values(disc.discriminator!.all)) { const objSchema = val; @@ -49,7 +36,7 @@ export async function generatePolymorphicHelpers(session: Session): P // generate unmarshallers for each discriminator for (const root of values(roots)) { - const discName = root.language.go!.discriminator; + const discName = root.language.go!.discriminatorInterface; // scalar unmarshaller text += `func unmarshal${discName}(body []byte) (${discName}, error) {\n`; text += '\tvar m map[string]interface{}\n'; @@ -60,7 +47,7 @@ export async function generatePolymorphicHelpers(session: Session): P text += `\tswitch m["${root.discriminator!.property.serializedName}"] {\n`; for (const val of values(root.discriminator!.all)) { const objSchema = val; - text += `\tcase ${val.language.go!.discriminatorEnum}:\n`; + text += `\tcase ${objSchema.discriminatorValue}:\n`; text += `\t\tb = &${val.language.go!.name}{}\n`; } text += '\tdefault:\n'; diff --git a/src/transform/namer.ts b/src/transform/namer.ts index f8bd26dfb..270fb38eb 100644 --- a/src/transform/namer.ts +++ b/src/transform/namer.ts @@ -57,7 +57,7 @@ export async function namer(session: Session) { details.name = getEscapedReservedName(capitalizeAcronyms(pascalCase(details.name)), 'Model'); if (obj.discriminator) { // if this is a discriminator add the interface name - details.discriminator = createPolymorphicInterfaceName(details.name); + details.discriminatorInterface = createPolymorphicInterfaceName(details.name); } for (const prop of values(obj.properties)) { const details = prop.language.go; diff --git a/src/transform/transform.ts b/src/transform/transform.ts index 96f2be0e5..e924db7bb 100644 --- a/src/transform/transform.ts +++ b/src/transform/transform.ts @@ -5,10 +5,10 @@ import { camelCase, KnownMediaType, pascalCase, serialize } from '@azure-tools/codegen'; import { Host, startSession, Session } from '@azure-tools/autorest-extension-base'; -import { ObjectSchema, ArraySchema, codeModelSchema, ChoiceValue, CodeModel, DateTimeSchema, GroupProperty, HttpHeader, HttpResponse, ImplementationLocation, Language, OperationGroup, SchemaType, NumberSchema, Operation, SchemaResponse, Parameter, Property, Protocols, Schema, DictionarySchema, Protocol, ChoiceSchema, SealedChoiceSchema, ConstantSchema } from '@azure-tools/codemodel'; +import { ObjectSchema, ArraySchema, ChoiceValue, codeModelSchema, CodeModel, DateTimeSchema, GroupProperty, HttpHeader, HttpResponse, ImplementationLocation, Language, OperationGroup, SchemaType, NumberSchema, Operation, SchemaResponse, Parameter, Property, Protocols, Schema, DictionarySchema, Protocol, ChoiceSchema, SealedChoiceSchema, ConstantSchema } from '@azure-tools/codemodel'; import { items, values } from '@azure-tools/linq'; import { aggregateParameters, isPageableOperation, isObjectSchema, isSchemaResponse, PagerInfo, isLROOperation, PollerInfo } from '../common/helpers'; -import { createPolymorphicInterfaceName, namer, removePrefix } from './namer'; +import { namer, removePrefix } from './namer'; // The transformer adds Go-specific information to the code model. export async function transform(host: Host) { @@ -38,14 +38,21 @@ async function process(session: Session) { // fix up struct field types for (const obj of values(session.model.schemas.objects)) { if (obj.discriminator) { - const discriminator = annotateDiscriminatedTypes(obj); - if (discriminator) { - // discriminators will contain the root type of each discriminated type hierarchy - if (!session.model.language.go!.discriminators) { - session.model.language.go!.discriminators = new Array(); + // discriminators will contain the root type of each discriminated type hierarchy + if (!session.model.language.go!.discriminators) { + session.model.language.go!.discriminators = new Array(); + } + const defs = >session.model.language.go!.discriminators; + const rootDiscriminator = getRootDiscriminator(obj); + if (defs.indexOf(rootDiscriminator) < 0) { + rootDiscriminator.language.go!.rootDiscriminator = true; + defs.push(rootDiscriminator); + // fix up discriminator value to use the enum type if available + const discriminatorEnums = getDiscriminatorEnums(rootDiscriminator); + // for each child type in the hierarchy, fix up the discriminator value + for (const child of values(rootDiscriminator.children?.all)) { + (child).discriminatorValue = getEnumForDiscriminatorValue((child).discriminatorValue!, discriminatorEnums); } - const defs = >session.model.language.go!.discriminators; - defs.push(discriminator); } } for (const prop of values(obj.properties)) { @@ -600,26 +607,14 @@ function generateResponseTypeName(schema: Schema): Language { } } -function annotateDiscriminatedTypes(obj: ObjectSchema): ObjectSchema | undefined { - if (obj.language.go!.polymorphicInterfaces !== undefined) { - // this hierarchy of discriminated types has already been processed - return; - } - // we have a type in the hierarchy of polymorphic types, it can be one of three things - // 1. root - no parent types, only child types - // 2. intermediate root - has a parent and also has children (salmon in the test server) - // 3. child - has parent and no children - // - // for cases #1 and #2 we need to generate an interface type, and for - // case #2 the generated interface must also contain the parent interface - // for case #3 all that's required is to generate the marker method on - // the child type(s) for the applicable interface. +function getRootDiscriminator(obj: ObjectSchema): ObjectSchema { + // discriminators can be a root or an "intermediate" root (Salmon in the test server) // walk to the root let root = obj; while (true) { if (!root.parents) { - // simple case, no parent types + // simple case, already at the root break; } for (const parent of values(root.parents?.immediate)) { @@ -627,6 +622,7 @@ function annotateDiscriminatedTypes(obj: ObjectSchema): ObjectSchema | undefined // e.g. if type Foo is in a DictionaryOfFoo, then one of // Foo's parents will be DictionaryOfFoo which we ignore. if (isObjectSchema(parent) && parent.discriminator) { + root.language.go!.discriminatorParent = parent.language.go!.discriminatorInterface; root = parent; } } @@ -638,58 +634,11 @@ function annotateDiscriminatedTypes(obj: ObjectSchema): ObjectSchema | undefined break; } } - // create the interface type name based on the current root - const rootType = root.language.go!.discriminator; - // use pre-defined enum values if available - const choices = getChoices(root); - if (!choices) { - // mark that we need to generate our own enum type - root.language.go!.discriminatorEnumNeeded = true; - } - recursiveAnnotateDiscriminatedTypes(root, rootType, rootType, choices); return root; } -function recursiveAnnotateDiscriminatedTypes(obj: ObjectSchema, rootInterface: string, currentInterface: string, choices: Array | undefined) { - if (!obj.language.go!.polymorphicInterfaces) { - obj.language.go!.polymorphicInterfaces = new Array(); - } - const interfaces = >obj.language.go!.polymorphicInterfaces; - interfaces.push(currentInterface); - // now walk all the children, annotating them with the interface - for (const child of values(obj.discriminator?.immediate)) { - const childSchema = child; - if (!childSchema.language.go!.polymorphicInterfaces) { - // copy parent's interfaces - childSchema.language.go!.polymorphicInterfaces = [...>obj.language.go!.polymorphicInterfaces]; - } - if (childSchema.discriminator && childSchema.discriminator.all) { - // case #2 - intermediate root - childSchema.language.go!.discriminatorParent = currentInterface; - recursiveAnnotateDiscriminatedTypes(childSchema, rootInterface, createPolymorphicInterfaceName(childSchema.language.go!.name), choices); - } - if (choices) { - // find the choice value that matches the current type's discriminator - let found = false; - for (const choice of values(choices)) { - if (choice.value === childSchema.discriminatorValue) { - childSchema.language.go!.discriminatorEnum = choice.language.go!.name; - childSchema.language.go!.discriminatorRealEnum = true; - found = true; - break; - } - } - if (!found) { - throw console.error(`failed to find discriminator choice value for type ${childSchema.language.go!.name}`); - } - } else { - // add the internal enum name for this sub-type - childSchema.language.go!.discriminatorEnum = `${camelCase(rootInterface)}${pascalCase(childSchema.discriminatorValue!)}`; - } - } -} - -function getChoices(obj: ObjectSchema): Array | undefined { +// returns the set of enum values used for discriminators +function getDiscriminatorEnums(obj: ObjectSchema): Array | undefined { if (obj.discriminator?.property.schema.type === SchemaType.Choice) { return (obj.discriminator!.property.schema).choices; } else if (obj.discriminator?.property.schema.type === SchemaType.SealedChoice) { @@ -697,3 +646,17 @@ function getChoices(obj: ObjectSchema): Array | undefined { } return undefined; } + +// returns the enum name for the specified discriminator value +function getEnumForDiscriminatorValue(discValue: string, enums: Array | undefined): string { + if (!enums) { + return `"${discValue}"`; + } + // find the choice value that matches the current type's discriminator + for (const enm of values(enums)) { + if (enm.value === discValue) { + return enm.language.go!.name; + } + } + throw console.error(`failed to find discriminator enum value for ${discValue}`); +} diff --git a/test/autorest/complexgroup/inheritance_test.go b/test/autorest/complexgroup/inheritance_test.go index 0340f7d16..412a1e09b 100644 --- a/test/autorest/complexgroup/inheritance_test.go +++ b/test/autorest/complexgroup/inheritance_test.go @@ -28,44 +28,60 @@ func TestInheritanceGetValid(t *testing.T) { t.Fatalf("GetValid: %v", err) } helpers.DeepEqualOrFatal(t, result.Siamese, &complexgroup.Siamese{ - Breed: to.StringPtr("persian"), - Color: to.StringPtr("green"), - Hates: &[]complexgroup.Dog{ - { - Food: to.StringPtr("tomato"), - ID: to.Int32Ptr(1), - Name: to.StringPtr("Potato"), + Cat: complexgroup.Cat{ + Pet: complexgroup.Pet{ + ID: to.Int32Ptr(2), + Name: to.StringPtr("Siameeee"), }, - { - Food: to.StringPtr("french fries"), - ID: to.Int32Ptr(-1), - Name: to.StringPtr("Tomato"), + Color: to.StringPtr("green"), + Hates: &[]complexgroup.Dog{ + { + Pet: complexgroup.Pet{ + ID: to.Int32Ptr(1), + Name: to.StringPtr("Potato"), + }, + Food: to.StringPtr("tomato"), + }, + { + Pet: complexgroup.Pet{ + ID: to.Int32Ptr(-1), + Name: to.StringPtr("Tomato"), + }, + Food: to.StringPtr("french fries"), + }, }, }, - ID: to.Int32Ptr(2), - Name: to.StringPtr("Siameeee"), + Breed: to.StringPtr("persian"), }) } func TestInheritancePutValid(t *testing.T) { client := getInheritanceOperations(t) result, err := client.PutValid(context.Background(), complexgroup.Siamese{ - Breed: to.StringPtr("persian"), - Color: to.StringPtr("green"), - Hates: &[]complexgroup.Dog{ - { - Food: to.StringPtr("tomato"), - ID: to.Int32Ptr(1), - Name: to.StringPtr("Potato"), + Cat: complexgroup.Cat{ + Pet: complexgroup.Pet{ + ID: to.Int32Ptr(2), + Name: to.StringPtr("Siameeee"), }, - { - Food: to.StringPtr("french fries"), - ID: to.Int32Ptr(-1), - Name: to.StringPtr("Tomato"), + Color: to.StringPtr("green"), + Hates: &[]complexgroup.Dog{ + { + Pet: complexgroup.Pet{ + ID: to.Int32Ptr(1), + Name: to.StringPtr("Potato"), + }, + Food: to.StringPtr("tomato"), + }, + { + Pet: complexgroup.Pet{ + ID: to.Int32Ptr(-1), + Name: to.StringPtr("Tomato"), + }, + Food: to.StringPtr("french fries"), + }, }, }, - ID: to.Int32Ptr(2), - Name: to.StringPtr("Siameeee"), + Breed: to.StringPtr("persian"), }) if err != nil { t.Fatalf("PutValid: %v", err) diff --git a/test/autorest/complexgroup/polymorphicrecursive_test.go b/test/autorest/complexgroup/polymorphicrecursive_test.go index 77f1b5204..527ac8562 100644 --- a/test/autorest/complexgroup/polymorphicrecursive_test.go +++ b/test/autorest/complexgroup/polymorphicrecursive_test.go @@ -32,64 +32,84 @@ func TestGetValid(t *testing.T) { sawBday := time.Date(1900, time.January, 5, 1, 0, 0, 0, time.UTC) sharkBday := time.Date(2012, time.January, 5, 1, 0, 0, 0, time.UTC) helpers.DeepEqualOrFatal(t, result.Fish, &complexgroup.Salmon{ - Fishtype: to.StringPtr("salmon"), - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(1), - Location: to.StringPtr("alaska"), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Fishtype: to.StringPtr("shark"), - Length: to.Float32Ptr(20), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Salmon{ - Fishtype: to.StringPtr("salmon"), - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(2), - Location: to.StringPtr("atlantic"), + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("salmon"), + Length: to.Float32Ptr(1), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("shark"), + Length: to.Float32Ptr(20), Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Fishtype: to.StringPtr("shark"), - Length: to.Float32Ptr(20), - Species: to.StringPtr("predator"), + &complexgroup.Salmon{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("salmon"), + Length: to.Float32Ptr(2), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("shark"), + Length: to.Float32Ptr(20), + Species: to.StringPtr("predator"), + }, + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, + }, + }, + Species: to.StringPtr("coho"), + }, + Iswild: to.BoolPtr(true), + Location: to.StringPtr("atlantic"), }, &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Species: to.StringPtr("dangerous"), + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Siblings: &[]complexgroup.FishClassification{}, + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, }, }, - Species: to.StringPtr("coho"), + Species: to.StringPtr("predator"), }, - &complexgroup.Sawshark{ + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Siblings: &[]complexgroup.FishClassification{}, + Species: to.StringPtr("dangerous"), + }, Age: to.Int32Ptr(105), Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Siblings: &[]complexgroup.FishClassification{}, - Species: to.StringPtr("dangerous"), }, + Picture: &[]byte{255, 255, 255, 255, 254}, }, - Species: to.StringPtr("predator"), - }, - &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Siblings: &[]complexgroup.FishClassification{}, - Species: to.StringPtr("dangerous"), }, + Species: to.StringPtr("king"), }, - Species: to.StringPtr("king"), + Iswild: to.BoolPtr(true), + Location: to.StringPtr("alaska"), }) } @@ -99,64 +119,84 @@ func TestPutValid(t *testing.T) { sawBday := time.Date(1900, time.January, 5, 1, 0, 0, 0, time.UTC) sharkBday := time.Date(2012, time.January, 5, 1, 0, 0, 0, time.UTC) result, err := client.PutValid(context.Background(), &complexgroup.Salmon{ - Fishtype: to.StringPtr("salmon"), - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(1), - Location: to.StringPtr("alaska"), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Fishtype: to.StringPtr("shark"), - Length: to.Float32Ptr(20), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Salmon{ - Fishtype: to.StringPtr("salmon"), - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(2), - Location: to.StringPtr("atlantic"), + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("salmon"), + Length: to.Float32Ptr(1), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("shark"), + Length: to.Float32Ptr(20), Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Fishtype: to.StringPtr("shark"), - Length: to.Float32Ptr(20), - Species: to.StringPtr("predator"), + &complexgroup.Salmon{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("salmon"), + Length: to.Float32Ptr(2), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("shark"), + Length: to.Float32Ptr(20), + Species: to.StringPtr("predator"), + }, + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, + }, + }, + Species: to.StringPtr("coho"), + }, + Iswild: to.BoolPtr(true), + Location: to.StringPtr("atlantic"), }, &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Species: to.StringPtr("dangerous"), + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Siblings: &[]complexgroup.FishClassification{}, + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, }, }, - Species: to.StringPtr("coho"), + Species: to.StringPtr("predator"), }, - &complexgroup.Sawshark{ + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Siblings: &[]complexgroup.FishClassification{}, + Species: to.StringPtr("dangerous"), + }, Age: to.Int32Ptr(105), Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Siblings: &[]complexgroup.FishClassification{}, - Species: to.StringPtr("dangerous"), }, + Picture: &[]byte{255, 255, 255, 255, 254}, }, - Species: to.StringPtr("predator"), - }, - &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Siblings: &[]complexgroup.FishClassification{}, - Species: to.StringPtr("dangerous"), }, + Species: to.StringPtr("king"), }, - Species: to.StringPtr("king"), + Iswild: to.BoolPtr(true), + Location: to.StringPtr("alaska"), }) if err != nil { t.Fatal(err) diff --git a/test/autorest/complexgroup/polymorphism_test.go b/test/autorest/complexgroup/polymorphism_test.go index 3d307ed64..3521fecb8 100644 --- a/test/autorest/complexgroup/polymorphism_test.go +++ b/test/autorest/complexgroup/polymorphism_test.go @@ -36,39 +36,56 @@ func TestPolymorphismGetComplicated(t *testing.T) { goblinBday := time.Date(2015, time.August, 8, 0, 0, 0, 0, time.UTC) sawBday := time.Date(1900, time.January, 5, 1, 0, 0, 0, time.UTC) sharkBday := time.Date(2012, time.January, 5, 1, 0, 0, 0, time.UTC) - helpers.DeepEqualOrFatal(t, salmon, &complexgroup.SmartSalmon{ + expectedFish := complexgroup.Fish{ Fishtype: to.StringPtr("smart_salmon"), - Iswild: to.BoolPtr(true), Length: to.Float32Ptr(1), - Location: to.StringPtr("alaska"), Siblings: &[]complexgroup.FishClassification{ &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("shark"), + Length: to.Float32Ptr(20), + Species: to.StringPtr("predator")}, Age: to.Int32Ptr(6), Birthday: &sharkBday, - Fishtype: to.StringPtr("shark"), - Length: to.Float32Ptr(20), - Species: to.StringPtr("predator"), }, &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Species: to.StringPtr("dangerous"), + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, }, &complexgroup.Goblinshark{ - Age: to.Int32Ptr(1), - Birthday: &goblinBday, - Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), - Fishtype: to.StringPtr("goblin"), - Jawsize: to.Int32Ptr(5), - Length: to.Float32Ptr(30), - Species: to.StringPtr("scary"), + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("goblin"), + Length: to.Float32Ptr(30), + Species: to.StringPtr("scary"), + }, + Age: to.Int32Ptr(1), + Birthday: &goblinBday, + }, + Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), + Jawsize: to.Int32Ptr(5), }, }, Species: to.StringPtr("king"), + } + expectedSalmon := complexgroup.Salmon{ + Fish: expectedFish, + Iswild: to.BoolPtr(true), + Location: to.StringPtr("alaska"), + } + helpers.DeepEqualOrFatal(t, salmon, &complexgroup.SmartSalmon{ + Salmon: expectedSalmon, }) + helpers.DeepEqualOrFatal(t, result.Salmon.GetSalmon(), &expectedSalmon) + helpers.DeepEqualOrFatal(t, result.Salmon.GetFish(), &expectedFish) } // GetComposedWithDiscriminator - Get complex object composing a polymorphic scalar property and array property with polymorphic element type, with discriminator specified. Deserialization must NOT fail and use the discriminator type specified on the wire. @@ -81,43 +98,55 @@ func TestPolymorphismGetComposedWithDiscriminator(t *testing.T) { helpers.DeepEqualOrFatal(t, result.DotFishMarket, &complexgroup.DotFishMarket{ Fishes: &[]complexgroup.DotFishClassification{ &complexgroup.DotSalmon{ - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("australia"), Iswild: to.BoolPtr(false), - Species: to.StringPtr("king"), }, &complexgroup.DotSalmon{ - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("canada"), Iswild: to.BoolPtr(true), - Species: to.StringPtr("king"), }, }, Salmons: &[]complexgroup.DotSalmon{ { - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("sweden"), Iswild: to.BoolPtr(false), - Species: to.StringPtr("king"), }, { - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("atlantic"), Iswild: to.BoolPtr(true), - Species: to.StringPtr("king"), }, }, SampleFish: &complexgroup.DotSalmon{ - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("australia"), Iswild: to.BoolPtr(false), - Species: to.StringPtr("king"), }, SampleSalmon: &complexgroup.DotSalmon{ - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("sweden"), Iswild: to.BoolPtr(false), - Species: to.StringPtr("king"), }, }) } @@ -140,23 +169,29 @@ func TestPolymorphismGetComposedWithoutDiscriminator(t *testing.T) { }, Salmons: &[]complexgroup.DotSalmon{ { + DotFish: complexgroup.DotFish{ + Species: to.StringPtr("king"), + }, Location: to.StringPtr("sweden"), Iswild: to.BoolPtr(false), - Species: to.StringPtr("king"), }, { + DotFish: complexgroup.DotFish{ + Species: to.StringPtr("king"), + }, Location: to.StringPtr("atlantic"), Iswild: to.BoolPtr(true), - Species: to.StringPtr("king"), }, }, SampleFish: &complexgroup.DotFish{ Species: to.StringPtr("king"), }, SampleSalmon: &complexgroup.DotSalmon{ + DotFish: complexgroup.DotFish{ + Species: to.StringPtr("king"), + }, Location: to.StringPtr("sweden"), Iswild: to.BoolPtr(false), - Species: to.StringPtr("king"), }, }) } @@ -169,10 +204,12 @@ func TestPolymorphismGetDotSyntax(t *testing.T) { t.Fatal(err) } helpers.DeepEqualOrFatal(t, result.DotFish, &complexgroup.DotSalmon{ - FishType: to.StringPtr("DotSalmon"), + DotFish: complexgroup.DotFish{ + FishType: to.StringPtr("DotSalmon"), + Species: to.StringPtr("king"), + }, Location: to.StringPtr("sweden"), Iswild: to.BoolPtr(true), - Species: to.StringPtr("king"), }) } @@ -191,37 +228,49 @@ func TestPolymorphismGetValid(t *testing.T) { sawBday := time.Date(1900, time.January, 5, 1, 0, 0, 0, time.UTC) sharkBday := time.Date(2012, time.January, 5, 1, 0, 0, 0, time.UTC) helpers.DeepEqualOrFatal(t, salmon, &complexgroup.Salmon{ - Fishtype: to.StringPtr("salmon"), - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(1), - Location: to.StringPtr("alaska"), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Fishtype: to.StringPtr("shark"), - Length: to.Float32Ptr(20), - Species: to.StringPtr("predator"), - }, - &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Fishtype: to.StringPtr("sawshark"), - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Species: to.StringPtr("dangerous"), - }, - &complexgroup.Goblinshark{ - Age: to.Int32Ptr(1), - Birthday: &goblinBday, - Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), - Fishtype: to.StringPtr("goblin"), - Jawsize: to.Int32Ptr(5), - Length: to.Float32Ptr(30), - Species: to.StringPtr("scary"), + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("salmon"), + Length: to.Float32Ptr(1), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("shark"), + Length: to.Float32Ptr(20), + Species: to.StringPtr("predator"), + }, + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("sawshark"), + Length: to.Float32Ptr(10), + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, + }, + &complexgroup.Goblinshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Fishtype: to.StringPtr("goblin"), + Length: to.Float32Ptr(30), + Species: to.StringPtr("scary"), + }, + Age: to.Int32Ptr(1), + Birthday: &goblinBday, + }, + Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), + Jawsize: to.Int32Ptr(5), + }, }, + Species: to.StringPtr("king"), }, - Species: to.StringPtr("king"), + Iswild: to.BoolPtr(true), + Location: to.StringPtr("alaska"), }) } @@ -237,33 +286,45 @@ func TestPolymorphismPutMissingDiscriminator(t *testing.T) { sawBday := time.Date(1900, time.January, 5, 1, 0, 0, 0, time.UTC) sharkBday := time.Date(2012, time.January, 5, 1, 0, 0, 0, time.UTC) result, err := client.PutMissingDiscriminator(context.Background(), &complexgroup.Salmon{ - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(1), - Location: to.StringPtr("alaska"), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Length: to.Float32Ptr(20), - Species: to.StringPtr("predator"), - }, - &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Species: to.StringPtr("dangerous"), - }, - &complexgroup.Goblinshark{ - Age: to.Int32Ptr(1), - Birthday: &goblinBday, - Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), - Jawsize: to.Int32Ptr(5), - Length: to.Float32Ptr(30), - Species: to.StringPtr("scary"), + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(1), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(20), + Species: to.StringPtr("predator"), + }, + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(10), + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, + }, + &complexgroup.Goblinshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(30), + Species: to.StringPtr("scary"), + }, + Age: to.Int32Ptr(1), + Birthday: &goblinBday, + }, + Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), + Jawsize: to.Int32Ptr(5), + }, }, + Species: to.StringPtr("king"), }, - Species: to.StringPtr("king"), + Iswild: to.BoolPtr(true), + Location: to.StringPtr("alaska"), }) if err != nil { t.Fatal(err) @@ -278,33 +339,45 @@ func TestPolymorphismPutValid(t *testing.T) { sawBday := time.Date(1900, time.January, 5, 1, 0, 0, 0, time.UTC) sharkBday := time.Date(2012, time.January, 5, 1, 0, 0, 0, time.UTC) resp, err := client.PutValid(context.Background(), &complexgroup.Salmon{ - Iswild: to.BoolPtr(true), - Length: to.Float32Ptr(1), - Location: to.StringPtr("alaska"), - Siblings: &[]complexgroup.FishClassification{ - &complexgroup.Shark{ - Age: to.Int32Ptr(6), - Birthday: &sharkBday, - Length: to.Float32Ptr(20), - Species: to.StringPtr("predator"), - }, - &complexgroup.Sawshark{ - Age: to.Int32Ptr(105), - Birthday: &sawBday, - Length: to.Float32Ptr(10), - Picture: &[]byte{255, 255, 255, 255, 254}, - Species: to.StringPtr("dangerous"), - }, - &complexgroup.Goblinshark{ - Age: to.Int32Ptr(1), - Birthday: &goblinBday, - Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), - Jawsize: to.Int32Ptr(5), - Length: to.Float32Ptr(30), - Species: to.StringPtr("scary"), + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(1), + Siblings: &[]complexgroup.FishClassification{ + &complexgroup.Shark{ + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(20), + Species: to.StringPtr("predator"), + }, + Age: to.Int32Ptr(6), + Birthday: &sharkBday, + }, + &complexgroup.Sawshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(10), + Species: to.StringPtr("dangerous"), + }, + Age: to.Int32Ptr(105), + Birthday: &sawBday, + }, + Picture: &[]byte{255, 255, 255, 255, 254}, + }, + &complexgroup.Goblinshark{ + Shark: complexgroup.Shark{ + Fish: complexgroup.Fish{ + Length: to.Float32Ptr(30), + Species: to.StringPtr("scary"), + }, + Age: to.Int32Ptr(1), + Birthday: &goblinBday, + }, + Color: complexgroup.GoblinSharkColor("pinkish-gray").ToPtr(), + Jawsize: to.Int32Ptr(5), + }, }, + Species: to.StringPtr("king"), }, - Species: to.StringPtr("king"), + Iswild: to.BoolPtr(true), + Location: to.StringPtr("alaska"), }) if err != nil { t.Fatal(err) diff --git a/test/autorest/generated/complexgroup/models.go b/test/autorest/generated/complexgroup/models.go index f66b37ba0..6ba36aa12 100644 --- a/test/autorest/generated/complexgroup/models.go +++ b/test/autorest/generated/complexgroup/models.go @@ -68,80 +68,20 @@ type ByteWrapperResponse struct { } type Cat struct { + Pet Color *string `json:"color,omitempty"` Hates *[]Dog `json:"hates,omitempty"` - ID *int32 `json:"id,omitempty"` - Name *string `json:"name,omitempty"` } type Cookiecuttershark struct { - Age *int32 `json:"age,omitempty"` - Birthday *time.Time `json:"birthday,omitempty"` - Fishtype *string `json:"fishtype,omitempty"` - Length *float32 `json:"length,omitempty"` - Siblings *[]FishClassification `json:"siblings,omitempty"` - Species *string `json:"species,omitempty"` + Shark } func (c Cookiecuttershark) MarshalJSON() ([]byte, error) { - c.Fishtype = strptr(fishClassificationCookiecuttershark) - type alias Cookiecuttershark - aux := &struct { - *alias - Birthday *timeRFC3339 `json:"birthday"` - }{ - alias: (*alias)(&c), - Birthday: (*timeRFC3339)(c.Birthday), - } - return json.Marshal(aux) + objectMap := c.Shark.marshalInternal("cookiecuttershark") + return json.Marshal(objectMap) } -func (c *Cookiecuttershark) UnmarshalJSON(data []byte) error { - var rawMsg map[string]*json.RawMessage - if err := json.Unmarshal(data, &rawMsg); err != nil { - return err - } - for k, v := range rawMsg { - var err error - switch k { - case "age": - if v != nil { - err = json.Unmarshal(*v, &c.Age) - } - case "birthday": - if v != nil { - var aux timeRFC3339 - err = json.Unmarshal(*v, &aux) - c.Birthday = (*time.Time)(&aux) - } - case "fishtype": - if v != nil { - err = json.Unmarshal(*v, &c.Fishtype) - } - case "length": - if v != nil { - err = json.Unmarshal(*v, &c.Length) - } - case "siblings": - if v != nil { - c.Siblings, err = unmarshalFishClassificationArray(*v) - } - case "species": - if v != nil { - err = json.Unmarshal(*v, &c.Species) - } - } - if err != nil { - return err - } - } - return nil -} - -func (*Cookiecuttershark) fishClassification() {} - -func (*Cookiecuttershark) sharkClassification() {} - type DateWrapper struct { Field *time.Time `json:"field,omitempty"` Leap *time.Time `json:"leap,omitempty"` @@ -257,14 +197,13 @@ type DictionaryWrapperResponse struct { } type Dog struct { + Pet Food *string `json:"food,omitempty"` - ID *int32 `json:"id,omitempty"` - Name *string `json:"name,omitempty"` } // DotFishClassification provides polymorphic access to related types. type DotFishClassification interface { - dotFishClassification() + GetDotFish() *DotFish } type DotFish struct { @@ -272,7 +211,41 @@ type DotFish struct { Species *string `json:"species,omitempty"` } -func (*DotFish) dotFishClassification() {} +func (d *DotFish) GetDotFish() *DotFish { return d } + +func (d *DotFish) UnmarshalJSON(data []byte) error { + var rawMsg map[string]*json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return err + } + for k, v := range rawMsg { + var err error + switch k { + case "fish.type": + if v != nil { + err = json.Unmarshal(*v, &d.FishType) + } + case "species": + if v != nil { + err = json.Unmarshal(*v, &d.Species) + } + } + if err != nil { + return err + } + } + return nil +} + +func (d DotFish) marshalInternal(discValue string) map[string]interface{} { + objectMap := make(map[string]interface{}) + d.FishType = &discValue + objectMap["fish.type"] = d.FishType + if d.Species != nil { + objectMap["species"] = d.Species + } + return objectMap +} type DotFishMarket struct { Fishes *[]DotFishClassification `json:"fishes,omitempty"` @@ -339,21 +312,20 @@ func (d *DotFishResponse) UnmarshalJSON(data []byte) error { } type DotSalmon struct { - FishType *string `json:"fish.type,omitempty"` + DotFish Iswild *bool `json:"iswild,omitempty"` Location *string `json:"location,omitempty"` - Species *string `json:"species,omitempty"` } func (d DotSalmon) MarshalJSON() ([]byte, error) { - d.FishType = strptr(dotFishClassificationDotSalmon) - type alias DotSalmon - aux := &struct { - *alias - }{ - alias: (*alias)(&d), + objectMap := d.DotFish.marshalInternal("DotSalmon") + if d.Iswild != nil { + objectMap["iswild"] = d.Iswild } - return json.Marshal(aux) + if d.Location != nil { + objectMap["location"] = d.Location + } + return json.Marshal(objectMap) } func (d *DotSalmon) UnmarshalJSON(data []byte) error { @@ -364,10 +336,6 @@ func (d *DotSalmon) UnmarshalJSON(data []byte) error { for k, v := range rawMsg { var err error switch k { - case "fish.type": - if v != nil { - err = json.Unmarshal(*v, &d.FishType) - } case "iswild": if v != nil { err = json.Unmarshal(*v, &d.Iswild) @@ -376,20 +344,14 @@ func (d *DotSalmon) UnmarshalJSON(data []byte) error { if v != nil { err = json.Unmarshal(*v, &d.Location) } - case "species": - if v != nil { - err = json.Unmarshal(*v, &d.Species) - } } if err != nil { return err } } - return nil + return json.Unmarshal(data, &d.DotFish) } -func (*DotSalmon) dotFishClassification() {} - type DoubleWrapper struct { Field1 *float64 `json:"field1,omitempty"` Field56ZerosAfterTheDotAndNegativeZeroBeforeDotAndThisIsALongFieldNameOnPurpose *float64 `json:"field_56_zeros_after_the_dot_and_negative_zero_before_dot_and_this_is_a_long_field_name_on_purpose,omitempty"` @@ -436,7 +398,7 @@ func (e Error) Error() string { // FishClassification provides polymorphic access to related types. type FishClassification interface { - fishClassification() + GetFish() *Fish } type Fish struct { @@ -446,7 +408,55 @@ type Fish struct { Species *string `json:"species,omitempty"` } -func (*Fish) fishClassification() {} +func (f *Fish) GetFish() *Fish { return f } + +func (f *Fish) UnmarshalJSON(data []byte) error { + var rawMsg map[string]*json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return err + } + for k, v := range rawMsg { + var err error + switch k { + case "fishtype": + if v != nil { + err = json.Unmarshal(*v, &f.Fishtype) + } + case "length": + if v != nil { + err = json.Unmarshal(*v, &f.Length) + } + case "siblings": + if v != nil { + f.Siblings, err = unmarshalFishClassificationArray(*v) + } + case "species": + if v != nil { + err = json.Unmarshal(*v, &f.Species) + } + } + if err != nil { + return err + } + } + return nil +} + +func (f Fish) marshalInternal(discValue string) map[string]interface{} { + objectMap := make(map[string]interface{}) + f.Fishtype = &discValue + objectMap["fishtype"] = f.Fishtype + if f.Length != nil { + objectMap["length"] = f.Length + } + if f.Siblings != nil { + objectMap["siblings"] = f.Siblings + } + if f.Species != nil { + objectMap["species"] = f.Species + } + return objectMap +} // FishResponse is the response envelope for operations that return a Fish type. type FishResponse struct { @@ -479,29 +489,21 @@ type FloatWrapperResponse struct { } type Goblinshark struct { - Age *int32 `json:"age,omitempty"` - Birthday *time.Time `json:"birthday,omitempty"` - + Shark // Colors possible - Color *GoblinSharkColor `json:"color,omitempty"` - Fishtype *string `json:"fishtype,omitempty"` - Jawsize *int32 `json:"jawsize,omitempty"` - Length *float32 `json:"length,omitempty"` - Siblings *[]FishClassification `json:"siblings,omitempty"` - Species *string `json:"species,omitempty"` + Color *GoblinSharkColor `json:"color,omitempty"` + Jawsize *int32 `json:"jawsize,omitempty"` } func (g Goblinshark) MarshalJSON() ([]byte, error) { - g.Fishtype = strptr(fishClassificationGoblin) - type alias Goblinshark - aux := &struct { - *alias - Birthday *timeRFC3339 `json:"birthday"` - }{ - alias: (*alias)(&g), - Birthday: (*timeRFC3339)(g.Birthday), + objectMap := g.Shark.marshalInternal("goblin") + if g.Color != nil { + objectMap["color"] = g.Color } - return json.Marshal(aux) + if g.Jawsize != nil { + objectMap["jawsize"] = g.Jawsize + } + return json.Marshal(objectMap) } func (g *Goblinshark) UnmarshalJSON(data []byte) error { @@ -512,52 +514,22 @@ func (g *Goblinshark) UnmarshalJSON(data []byte) error { for k, v := range rawMsg { var err error switch k { - case "age": - if v != nil { - err = json.Unmarshal(*v, &g.Age) - } - case "birthday": - if v != nil { - var aux timeRFC3339 - err = json.Unmarshal(*v, &aux) - g.Birthday = (*time.Time)(&aux) - } case "color": if v != nil { err = json.Unmarshal(*v, &g.Color) } - case "fishtype": - if v != nil { - err = json.Unmarshal(*v, &g.Fishtype) - } case "jawsize": if v != nil { err = json.Unmarshal(*v, &g.Jawsize) } - case "length": - if v != nil { - err = json.Unmarshal(*v, &g.Length) - } - case "siblings": - if v != nil { - g.Siblings, err = unmarshalFishClassificationArray(*v) - } - case "species": - if v != nil { - err = json.Unmarshal(*v, &g.Species) - } } if err != nil { return err } } - return nil + return json.Unmarshal(data, &g.Shark) } -func (*Goblinshark) fishClassification() {} - -func (*Goblinshark) sharkClassification() {} - type IntWrapper struct { Field1 *int32 `json:"field1,omitempty"` Field2 *int32 `json:"field2,omitempty"` @@ -590,7 +562,7 @@ type MyBaseHelperType struct { // MyBaseTypeClassification provides polymorphic access to related types. type MyBaseTypeClassification interface { - myBaseTypeClassification() + GetMyBaseType() *MyBaseType } type MyBaseType struct { @@ -599,7 +571,48 @@ type MyBaseType struct { PropB1 *string `json:"propB1,omitempty"` } -func (*MyBaseType) myBaseTypeClassification() {} +func (m *MyBaseType) GetMyBaseType() *MyBaseType { return m } + +func (m *MyBaseType) UnmarshalJSON(data []byte) error { + var rawMsg map[string]*json.RawMessage + if err := json.Unmarshal(data, &rawMsg); err != nil { + return err + } + for k, v := range rawMsg { + var err error + switch k { + case "helper": + if v != nil { + err = json.Unmarshal(*v, &m.Helper) + } + case "kind": + if v != nil { + err = json.Unmarshal(*v, &m.Kind) + } + case "propB1": + if v != nil { + err = json.Unmarshal(*v, &m.PropB1) + } + } + if err != nil { + return err + } + } + return nil +} + +func (m MyBaseType) marshalInternal(discValue string) map[string]interface{} { + objectMap := make(map[string]interface{}) + if m.Helper != nil { + objectMap["helper"] = m.Helper + } + m.Kind = &discValue + objectMap["kind"] = m.Kind + if m.PropB1 != nil { + objectMap["propB1"] = m.PropB1 + } + return objectMap +} // MyBaseTypeResponse is the response envelope for operations that return a MyBaseType type. type MyBaseTypeResponse struct { @@ -619,21 +632,16 @@ func (m *MyBaseTypeResponse) UnmarshalJSON(data []byte) error { } type MyDerivedType struct { - Helper *MyBaseHelperType `json:"helper,omitempty"` - Kind *string `json:"kind,omitempty"` - PropB1 *string `json:"propB1,omitempty"` - PropD1 *string `json:"propD1,omitempty"` + MyBaseType + PropD1 *string `json:"propD1,omitempty"` } func (m MyDerivedType) MarshalJSON() ([]byte, error) { - m.Kind = strptr(myBaseTypeClassificationKind1) - type alias MyDerivedType - aux := &struct { - *alias - }{ - alias: (*alias)(&m), + objectMap := m.MyBaseType.marshalInternal("Kind1") + if m.PropD1 != nil { + objectMap["propD1"] = m.PropD1 } - return json.Marshal(aux) + return json.Marshal(objectMap) } func (m *MyDerivedType) UnmarshalJSON(data []byte) error { @@ -644,18 +652,6 @@ func (m *MyDerivedType) UnmarshalJSON(data []byte) error { for k, v := range rawMsg { var err error switch k { - case "helper": - if v != nil { - err = json.Unmarshal(*v, &m.Helper) - } - case "kind": - if v != nil { - err = json.Unmarshal(*v, &m.Kind) - } - case "propB1": - if v != nil { - err = json.Unmarshal(*v, &m.PropB1) - } case "propD1": if v != nil { err = json.Unmarshal(*v, &m.PropD1) @@ -665,11 +661,9 @@ func (m *MyDerivedType) UnmarshalJSON(data []byte) error { return err } } - return nil + return json.Unmarshal(data, &m.MyBaseType) } -func (*MyDerivedType) myBaseTypeClassification() {} - type Pet struct { ID *int32 `json:"id,omitempty"` Name *string `json:"name,omitempty"` @@ -690,27 +684,26 @@ type ReadonlyObjResponse struct { // SalmonClassification provides polymorphic access to related types. type SalmonClassification interface { FishClassification - salmonClassification() + GetSalmon() *Salmon } type Salmon struct { - Fishtype *string `json:"fishtype,omitempty"` - Iswild *bool `json:"iswild,omitempty"` - Length *float32 `json:"length,omitempty"` - Location *string `json:"location,omitempty"` - Siblings *[]FishClassification `json:"siblings,omitempty"` - Species *string `json:"species,omitempty"` + Fish + Iswild *bool `json:"iswild,omitempty"` + Location *string `json:"location,omitempty"` } +func (s *Salmon) GetSalmon() *Salmon { return s } + func (s Salmon) MarshalJSON() ([]byte, error) { - s.Fishtype = strptr(fishClassificationSalmon) - type alias Salmon - aux := &struct { - *alias - }{ - alias: (*alias)(&s), + objectMap := s.Fish.marshalInternal("salmon") + if s.Iswild != nil { + objectMap["iswild"] = s.Iswild } - return json.Marshal(aux) + if s.Location != nil { + objectMap["location"] = s.Location + } + return json.Marshal(objectMap) } func (s *Salmon) UnmarshalJSON(data []byte) error { @@ -721,41 +714,32 @@ func (s *Salmon) UnmarshalJSON(data []byte) error { for k, v := range rawMsg { var err error switch k { - case "fishtype": - if v != nil { - err = json.Unmarshal(*v, &s.Fishtype) - } case "iswild": if v != nil { err = json.Unmarshal(*v, &s.Iswild) } - case "length": - if v != nil { - err = json.Unmarshal(*v, &s.Length) - } case "location": if v != nil { err = json.Unmarshal(*v, &s.Location) } - case "siblings": - if v != nil { - s.Siblings, err = unmarshalFishClassificationArray(*v) - } - case "species": - if v != nil { - err = json.Unmarshal(*v, &s.Species) - } } if err != nil { return err } } - return nil + return json.Unmarshal(data, &s.Fish) } -func (*Salmon) fishClassification() {} - -func (*Salmon) salmonClassification() {} +func (s Salmon) marshalInternal(discValue string) map[string]interface{} { + objectMap := s.Fish.marshalInternal(discValue) + if s.Iswild != nil { + objectMap["iswild"] = s.Iswild + } + if s.Location != nil { + objectMap["location"] = s.Location + } + return objectMap +} // SalmonResponse is the response envelope for operations that return a Salmon type. type SalmonResponse struct { @@ -774,26 +758,16 @@ func (s *SalmonResponse) UnmarshalJSON(data []byte) error { } type Sawshark struct { - Age *int32 `json:"age,omitempty"` - Birthday *time.Time `json:"birthday,omitempty"` - Fishtype *string `json:"fishtype,omitempty"` - Length *float32 `json:"length,omitempty"` - Picture *[]byte `json:"picture,omitempty"` - Siblings *[]FishClassification `json:"siblings,omitempty"` - Species *string `json:"species,omitempty"` + Shark + Picture *[]byte `json:"picture,omitempty"` } func (s Sawshark) MarshalJSON() ([]byte, error) { - s.Fishtype = strptr(fishClassificationSawshark) - type alias Sawshark - aux := &struct { - *alias - Birthday *timeRFC3339 `json:"birthday"` - }{ - alias: (*alias)(&s), - Birthday: (*timeRFC3339)(s.Birthday), + objectMap := s.Shark.marshalInternal("sawshark") + if s.Picture != nil { + objectMap["picture"] = s.Picture } - return json.Marshal(aux) + return json.Marshal(objectMap) } func (s *Sawshark) UnmarshalJSON(data []byte) error { @@ -804,74 +778,41 @@ func (s *Sawshark) UnmarshalJSON(data []byte) error { for k, v := range rawMsg { var err error switch k { - case "age": - if v != nil { - err = json.Unmarshal(*v, &s.Age) - } - case "birthday": - if v != nil { - var aux timeRFC3339 - err = json.Unmarshal(*v, &aux) - s.Birthday = (*time.Time)(&aux) - } - case "fishtype": - if v != nil { - err = json.Unmarshal(*v, &s.Fishtype) - } - case "length": - if v != nil { - err = json.Unmarshal(*v, &s.Length) - } case "picture": if v != nil { err = json.Unmarshal(*v, &s.Picture) } - case "siblings": - if v != nil { - s.Siblings, err = unmarshalFishClassificationArray(*v) - } - case "species": - if v != nil { - err = json.Unmarshal(*v, &s.Species) - } } if err != nil { return err } } - return nil + return json.Unmarshal(data, &s.Shark) } -func (*Sawshark) fishClassification() {} - -func (*Sawshark) sharkClassification() {} - // SharkClassification provides polymorphic access to related types. type SharkClassification interface { FishClassification - sharkClassification() + GetShark() *Shark } type Shark struct { - Age *int32 `json:"age,omitempty"` - Birthday *time.Time `json:"birthday,omitempty"` - Fishtype *string `json:"fishtype,omitempty"` - Length *float32 `json:"length,omitempty"` - Siblings *[]FishClassification `json:"siblings,omitempty"` - Species *string `json:"species,omitempty"` + Fish + Age *int32 `json:"age,omitempty"` + Birthday *time.Time `json:"birthday,omitempty"` } +func (s *Shark) GetShark() *Shark { return s } + func (s Shark) MarshalJSON() ([]byte, error) { - s.Fishtype = strptr(fishClassificationShark) - type alias Shark - aux := &struct { - *alias - Birthday *timeRFC3339 `json:"birthday"` - }{ - alias: (*alias)(&s), - Birthday: (*timeRFC3339)(s.Birthday), + objectMap := s.Fish.marshalInternal("shark") + if s.Age != nil { + objectMap["age"] = s.Age } - return json.Marshal(aux) + if s.Birthday != nil { + objectMap["birthday"] = (*timeRFC3339)(s.Birthday) + } + return json.Marshal(objectMap) } func (s *Shark) UnmarshalJSON(data []byte) error { @@ -892,40 +833,28 @@ func (s *Shark) UnmarshalJSON(data []byte) error { err = json.Unmarshal(*v, &aux) s.Birthday = (*time.Time)(&aux) } - case "fishtype": - if v != nil { - err = json.Unmarshal(*v, &s.Fishtype) - } - case "length": - if v != nil { - err = json.Unmarshal(*v, &s.Length) - } - case "siblings": - if v != nil { - s.Siblings, err = unmarshalFishClassificationArray(*v) - } - case "species": - if v != nil { - err = json.Unmarshal(*v, &s.Species) - } } if err != nil { return err } } - return nil + return json.Unmarshal(data, &s.Fish) } -func (*Shark) fishClassification() {} - -func (*Shark) sharkClassification() {} +func (s Shark) marshalInternal(discValue string) map[string]interface{} { + objectMap := s.Fish.marshalInternal(discValue) + if s.Age != nil { + objectMap["age"] = s.Age + } + if s.Birthday != nil { + objectMap["birthday"] = (*timeRFC3339)(s.Birthday) + } + return objectMap +} type Siamese struct { + Cat Breed *string `json:"breed,omitempty"` - Color *string `json:"color,omitempty"` - Hates *[]Dog `json:"hates,omitempty"` - ID *int32 `json:"id,omitempty"` - Name *string `json:"name,omitempty"` } // SiameseResponse is the response envelope for operations that return a Siamese type. @@ -936,24 +865,16 @@ type SiameseResponse struct { } type SmartSalmon struct { - CollegeDegree *string `json:"college_degree,omitempty"` - Fishtype *string `json:"fishtype,omitempty"` - Iswild *bool `json:"iswild,omitempty"` - Length *float32 `json:"length,omitempty"` - Location *string `json:"location,omitempty"` - Siblings *[]FishClassification `json:"siblings,omitempty"` - Species *string `json:"species,omitempty"` + Salmon + CollegeDegree *string `json:"college_degree,omitempty"` } func (s SmartSalmon) MarshalJSON() ([]byte, error) { - s.Fishtype = strptr(fishClassificationSmartSalmon) - type alias SmartSalmon - aux := &struct { - *alias - }{ - alias: (*alias)(&s), + objectMap := s.Salmon.marshalInternal("smart_salmon") + if s.CollegeDegree != nil { + objectMap["college_degree"] = s.CollegeDegree } - return json.Marshal(aux) + return json.Marshal(objectMap) } func (s *SmartSalmon) UnmarshalJSON(data []byte) error { @@ -968,42 +889,14 @@ func (s *SmartSalmon) UnmarshalJSON(data []byte) error { if v != nil { err = json.Unmarshal(*v, &s.CollegeDegree) } - case "fishtype": - if v != nil { - err = json.Unmarshal(*v, &s.Fishtype) - } - case "iswild": - if v != nil { - err = json.Unmarshal(*v, &s.Iswild) - } - case "length": - if v != nil { - err = json.Unmarshal(*v, &s.Length) - } - case "location": - if v != nil { - err = json.Unmarshal(*v, &s.Location) - } - case "siblings": - if v != nil { - s.Siblings, err = unmarshalFishClassificationArray(*v) - } - case "species": - if v != nil { - err = json.Unmarshal(*v, &s.Species) - } } if err != nil { return err } } - return nil + return json.Unmarshal(data, &s.Salmon) } -func (*SmartSalmon) fishClassification() {} - -func (*SmartSalmon) salmonClassification() {} - type StringWrapper struct { Empty *string `json:"empty,omitempty"` Field *string `json:"field,omitempty"` diff --git a/test/autorest/generated/complexgroup/polymorphic_helpers.go b/test/autorest/generated/complexgroup/polymorphic_helpers.go index e0a2f4382..2856811b3 100644 --- a/test/autorest/generated/complexgroup/polymorphic_helpers.go +++ b/test/autorest/generated/complexgroup/polymorphic_helpers.go @@ -7,10 +7,6 @@ package complexgroup import "encoding/json" -const ( - dotFishClassificationDotSalmon = "DotSalmon" -) - func unmarshalDotFishClassification(body []byte) (DotFishClassification, error) { var m map[string]interface{} if err := json.Unmarshal(body, &m); err != nil { @@ -18,7 +14,7 @@ func unmarshalDotFishClassification(body []byte) (DotFishClassification, error) } var b DotFishClassification switch m["fish.type"] { - case dotFishClassificationDotSalmon: + case "DotSalmon": b = &DotSalmon{} default: b = &DotFish{} @@ -42,15 +38,6 @@ func unmarshalDotFishClassificationArray(body []byte) (*[]DotFishClassification, return &fArray, nil } -const ( - fishClassificationCookiecuttershark = "cookiecuttershark" - fishClassificationGoblin = "goblin" - fishClassificationSalmon = "salmon" - fishClassificationSawshark = "sawshark" - fishClassificationShark = "shark" - fishClassificationSmartSalmon = "smart_salmon" -) - func unmarshalFishClassification(body []byte) (FishClassification, error) { var m map[string]interface{} if err := json.Unmarshal(body, &m); err != nil { @@ -58,17 +45,17 @@ func unmarshalFishClassification(body []byte) (FishClassification, error) { } var b FishClassification switch m["fishtype"] { - case fishClassificationCookiecuttershark: + case "cookiecuttershark": b = &Cookiecuttershark{} - case fishClassificationGoblin: + case "goblin": b = &Goblinshark{} - case fishClassificationSalmon: + case "salmon": b = &Salmon{} - case fishClassificationSawshark: + case "sawshark": b = &Sawshark{} - case fishClassificationShark: + case "shark": b = &Shark{} - case fishClassificationSmartSalmon: + case "smart_salmon": b = &SmartSalmon{} default: b = &Fish{} @@ -99,7 +86,7 @@ func unmarshalSalmonClassification(body []byte) (SalmonClassification, error) { } var b SalmonClassification switch m["fishtype"] { - case fishClassificationSmartSalmon: + case "smart_salmon": b = &SmartSalmon{} default: b = &Salmon{} @@ -130,11 +117,11 @@ func unmarshalSharkClassification(body []byte) (SharkClassification, error) { } var b SharkClassification switch m["fishtype"] { - case fishClassificationCookiecuttershark: + case "cookiecuttershark": b = &Cookiecuttershark{} - case fishClassificationGoblin: + case "goblin": b = &Goblinshark{} - case fishClassificationSawshark: + case "sawshark": b = &Sawshark{} default: b = &Shark{} @@ -158,10 +145,6 @@ func unmarshalSharkClassificationArray(body []byte) (*[]SharkClassification, err return &fArray, nil } -const ( - myBaseTypeClassificationKind1 = "Kind1" -) - func unmarshalMyBaseTypeClassification(body []byte) (MyBaseTypeClassification, error) { var m map[string]interface{} if err := json.Unmarshal(body, &m); err != nil { @@ -169,7 +152,7 @@ func unmarshalMyBaseTypeClassification(body []byte) (MyBaseTypeClassification, e } var b MyBaseTypeClassification switch m["kind"] { - case myBaseTypeClassificationKind1: + case "Kind1": b = &MyDerivedType{} default: b = &MyBaseType{} diff --git a/test/autorest/generated/httpinfrastructuregroup/models.go b/test/autorest/generated/httpinfrastructuregroup/models.go index d54686f19..898ac0858 100644 --- a/test/autorest/generated/httpinfrastructuregroup/models.go +++ b/test/autorest/generated/httpinfrastructuregroup/models.go @@ -11,7 +11,7 @@ import ( ) type B struct { - StatusCode *string `json:"statusCode,omitempty"` + MyException TextStatusCode *string `json:"textStatusCode,omitempty"` } diff --git a/test/autorest/generated/lrogroup/models.go b/test/autorest/generated/lrogroup/models.go index 6e3832db9..400a00349 100644 --- a/test/autorest/generated/lrogroup/models.go +++ b/test/autorest/generated/lrogroup/models.go @@ -759,21 +759,8 @@ type OperationResultError struct { } type Product struct { - // Resource Id - ID *string `json:"id,omitempty"` - - // Resource Location - Location *string `json:"location,omitempty"` - - // Resource Name - Name *string `json:"name,omitempty"` + Resource Properties *ProductProperties `json:"properties,omitempty"` - - // Dictionary of - Tags *map[string]string `json:"tags,omitempty"` - - // Resource Type - Type *string `json:"type,omitempty"` } type ProductProperties struct { @@ -819,8 +806,7 @@ type SkuResponse struct { } type SubProduct struct { - // Sub Resource Id - ID *string `json:"id,omitempty"` + SubResource Properties *SubProductProperties `json:"properties,omitempty"` }