diff --git a/_example/sql_str_int.go b/_example/sql_str_int.go index e7753b9..70a2bae 100644 --- a/_example/sql_str_int.go +++ b/_example/sql_str_int.go @@ -4,3 +4,6 @@ package example // ENUM(_,zeus, apollo, athena=20, ares) type GreekGod string + +// ENUM(_,zeus, apollo, _=19, athena="20", ares) +type GreekGodCustom string diff --git a/_example/sql_str_int_enum.go b/_example/sql_str_int_enum.go index 627d062..43ee773 100644 --- a/_example/sql_str_int_enum.go +++ b/_example/sql_str_int_enum.go @@ -206,3 +206,198 @@ func (x NullGreekGod) Value() (driver.Value, error) { // driver.Value accepts int64 for int values. return string(x.GreekGod), nil } + +const ( + // Skipped value. + _ GreekGodCustom = "_" + // GreekGodCustomZeus is a GreekGodCustom of type zeus. + GreekGodCustomZeus GreekGodCustom = "zeus" + // GreekGodCustomApollo is a GreekGodCustom of type apollo. + GreekGodCustomApollo GreekGodCustom = "apollo" + // Skipped value. + _ GreekGodCustom = "_" + // GreekGodCustomAthena is a GreekGodCustom of type athena. + GreekGodCustomAthena GreekGodCustom = "20" + // GreekGodCustomAres is a GreekGodCustom of type ares. + GreekGodCustomAres GreekGodCustom = "ares" +) + +var ErrInvalidGreekGodCustom = fmt.Errorf("not a valid GreekGodCustom, try [%s]", strings.Join(_GreekGodCustomNames, ", ")) + +var _GreekGodCustomNames = []string{ + string(GreekGodCustomZeus), + string(GreekGodCustomApollo), + string(GreekGodCustomAthena), + string(GreekGodCustomAres), +} + +// GreekGodCustomNames returns a list of possible string values of GreekGodCustom. +func GreekGodCustomNames() []string { + tmp := make([]string, len(_GreekGodCustomNames)) + copy(tmp, _GreekGodCustomNames) + return tmp +} + +// String implements the Stringer interface. +func (x GreekGodCustom) String() string { + return string(x) +} + +// String implements the Stringer interface. +func (x GreekGodCustom) IsValid() bool { + _, err := ParseGreekGodCustom(string(x)) + return err == nil +} + +var _GreekGodCustomValue = map[string]GreekGodCustom{ + "zeus": GreekGodCustomZeus, + "apollo": GreekGodCustomApollo, + "20": GreekGodCustomAthena, + "ares": GreekGodCustomAres, +} + +// ParseGreekGodCustom attempts to convert a string to a GreekGodCustom. +func ParseGreekGodCustom(name string) (GreekGodCustom, error) { + if x, ok := _GreekGodCustomValue[name]; ok { + return x, nil + } + return GreekGodCustom(""), fmt.Errorf("%s is %w", name, ErrInvalidGreekGodCustom) +} + +var errGreekGodCustomNilPtr = errors.New("value pointer is nil") // one per type for package clashes + +var sqlIntGreekGodCustomMap = map[int64]GreekGodCustom{ + 1: GreekGodCustomZeus, + 2: GreekGodCustomApollo, + 20: GreekGodCustomAthena, + 21: GreekGodCustomAres, +} + +var sqlIntGreekGodCustomValue = map[GreekGodCustom]int64{ + GreekGodCustomZeus: 1, + GreekGodCustomApollo: 2, + GreekGodCustomAthena: 20, + GreekGodCustomAres: 21, +} + +func lookupSqlIntGreekGodCustom(val int64) (GreekGodCustom, error) { + x, ok := sqlIntGreekGodCustomMap[val] + if !ok { + return x, fmt.Errorf("%v is not %w", val, ErrInvalidGreekGodCustom) + } + return x, nil +} + +// Scan implements the Scanner interface. +func (x *GreekGodCustom) Scan(value interface{}) (err error) { + if value == nil { + *x = GreekGodCustom("") + return + } + + // A wider range of scannable types. + // driver.Value values at the top of the list for expediency + switch v := value.(type) { + case int64: + *x, err = lookupSqlIntGreekGodCustom(v) + case string: + *x, err = ParseGreekGodCustom(v) + case []byte: + if val, verr := strconv.ParseInt(string(v), 10, 64); verr == nil { + *x, err = lookupSqlIntGreekGodCustom(val) + } else { + // try parsing the value as a string + *x, err = ParseGreekGodCustom(string(v)) + } + case GreekGodCustom: + *x = v + case int: + *x, err = lookupSqlIntGreekGodCustom(int64(v)) + case *GreekGodCustom: + if v == nil { + return errGreekGodCustomNilPtr + } + *x = *v + case uint: + *x, err = lookupSqlIntGreekGodCustom(int64(v)) + case uint64: + *x, err = lookupSqlIntGreekGodCustom(int64(v)) + case *int: + if v == nil { + return errGreekGodCustomNilPtr + } + *x, err = lookupSqlIntGreekGodCustom(int64(*v)) + case *int64: + if v == nil { + return errGreekGodCustomNilPtr + } + *x, err = lookupSqlIntGreekGodCustom(int64(*v)) + case float64: // json marshals everything as a float64 if it's a number + *x, err = lookupSqlIntGreekGodCustom(int64(v)) + case *float64: // json marshals everything as a float64 if it's a number + if v == nil { + return errGreekGodCustomNilPtr + } + *x, err = lookupSqlIntGreekGodCustom(int64(*v)) + case *uint: + if v == nil { + return errGreekGodCustomNilPtr + } + *x, err = lookupSqlIntGreekGodCustom(int64(*v)) + case *uint64: + if v == nil { + return errGreekGodCustomNilPtr + } + *x, err = lookupSqlIntGreekGodCustom(int64(*v)) + case *string: + if v == nil { + return errGreekGodCustomNilPtr + } + *x, err = ParseGreekGodCustom(*v) + default: + return errors.New("invalid type for GreekGodCustom") + } + + return +} + +// Value implements the driver Valuer interface. +func (x GreekGodCustom) Value() (driver.Value, error) { + val, ok := sqlIntGreekGodCustomValue[x] + if !ok { + return nil, ErrInvalidGreekGodCustom + } + return int64(val), nil +} + +type NullGreekGodCustom struct { + GreekGodCustom GreekGodCustom + Valid bool +} + +func NewNullGreekGodCustom(val interface{}) (x NullGreekGodCustom) { + err := x.Scan(val) // yes, we ignore this error, it will just be an invalid value. + _ = err // make any errcheck linters happy + return +} + +// Scan implements the Scanner interface. +func (x *NullGreekGodCustom) Scan(value interface{}) (err error) { + if value == nil { + x.GreekGodCustom, x.Valid = GreekGodCustom(""), false + return + } + + err = x.GreekGodCustom.Scan(value) + x.Valid = (err == nil) + return +} + +// Value implements the driver Valuer interface. +func (x NullGreekGodCustom) Value() (driver.Value, error) { + if !x.Valid { + return nil, nil + } + // driver.Value accepts int64 for int values. + return string(x.GreekGodCustom), nil +} diff --git a/_example/strings_only.go b/_example/strings_only.go index a1472c2..e296a71 100644 --- a/_example/strings_only.go +++ b/_example/strings_only.go @@ -2,5 +2,5 @@ package example -// ENUM(pending, running, completed, failed) +// ENUM(pending, running, completed, failed=error) type StrState string diff --git a/_example/strings_only_enum.go b/_example/strings_only_enum.go index f181ee7..2cf8984 100644 --- a/_example/strings_only_enum.go +++ b/_example/strings_only_enum.go @@ -18,7 +18,7 @@ const ( StrStatePending StrState = "pending" StrStateRunning StrState = "running" StrStateCompleted StrState = "completed" - StrStateFailed StrState = "failed" + StrStateFailed StrState = "error" ) var ErrInvalidStrState = fmt.Errorf("not a valid StrState, try [%s]", strings.Join(_StrStateNames, ", ")) @@ -62,7 +62,7 @@ var _StrStateValue = map[string]StrState{ "pending": StrStatePending, "running": StrStateRunning, "completed": StrStateCompleted, - "failed": StrStateFailed, + "error": StrStateFailed, } // ParseStrState attempts to convert a string to a StrState. diff --git a/_example/strings_only_test.go b/_example/strings_only_test.go index 2155b6c..bdc0b34 100644 --- a/_example/strings_only_test.go +++ b/_example/strings_only_test.go @@ -20,7 +20,7 @@ func TestStrState(t *testing.T) { func TestStrStateMustParse(t *testing.T) { x := `avocado` - assert.PanicsWithError(t, x+" is not a valid StrState, try [pending, running, completed, failed]", func() { MustParseStrState(x) }) + assert.PanicsWithError(t, x+" is not a valid StrState, try [pending, running, completed, error]", func() { MustParseStrState(x) }) assert.NotPanics(t, func() { MustParseStrState(StrStateFailed.String()) }) } @@ -78,14 +78,14 @@ func TestStrStateUnmarshal(t *testing.T) { }, { name: "failed", - input: `{"state":"Failed"}`, + input: `{"state":"Error"}`, output: &testData{StrStateX: StrStateFailed}, errorExpected: false, err: nil, }, { name: "failedlower", - input: `{"state":"failed"}`, + input: `{"state":"error"}`, output: &testData{StrStateX: StrStateFailed}, errorExpected: false, err: nil, @@ -141,7 +141,7 @@ func TestStrStateMarshal(t *testing.T) { }, { name: "green", - output: `{"state":"failed"}`, + output: `{"state":"error"}`, input: &testData{StrStateX: StrStateFailed}, errorExpected: false, err: nil, diff --git a/generator/enum_string.tmpl b/generator/enum_string.tmpl index 3c7e65d..299fc0e 100644 --- a/generator/enum_string.tmpl +++ b/generator/enum_string.tmpl @@ -10,7 +10,7 @@ const ( {{- if $value.Comment}} // {{$value.Comment}} {{- end}} - {{$value.PrefixedName}} {{$enumName}} = "{{$value.RawName}}" + {{$value.PrefixedName}} {{$enumName}} = "{{$value.ValueStr}}" {{- end}} ) {{if .names -}} @@ -100,7 +100,7 @@ func (x *{{.enum.Name}}) UnmarshalText(text []byte) error { } {{end}} -{{ if or .sql .sqlnullstr .sqlint .sqlnullint }} +{{ if .anySQLEnabled }} var err{{.enum.Name}}NilPtr = errors.New("value pointer is nil") // one per type for package clashes {{ end }} @@ -136,8 +136,8 @@ func (x *{{.enum.Name}}) Scan(value interface{}) (err error) { default: return errors.New("invalid type for {{.enum.Name}}") } - - return + + return } // Value implements the driver Valuer interface. @@ -149,12 +149,12 @@ func (x {{.enum.Name}}) Value() (driver.Value, error) { {{/* SQL stored as an integer value */}} {{ if or .sqlint .sqlnullint }} var sqlInt{{.enum.Name}}Map = map[int64]{{.enum.Name}}{ {{ range $rIndex, $value := .enum.Values }}{{ if ne $value.Name "_"}} -{{ $value.Value }}: {{ $value.PrefixedName }},{{end}} +{{ $value.ValueInt }}: {{ $value.PrefixedName }},{{end}} {{- end}} } var sqlInt{{.enum.Name}}Value = map[{{.enum.Name}}]int64{ {{ range $rIndex, $value := .enum.Values }}{{ if ne $value.Name "_"}} - {{ $value.PrefixedName }}: {{ $value.Value }},{{end}} + {{ $value.PrefixedName }}: {{ $value.ValueInt }},{{end}} {{- end}} } @@ -170,7 +170,7 @@ func lookupSqlInt{{.enum.Name}}(val int64) ({{.enum.Name}}, error){ func (x *{{.enum.Name}}) Scan(value interface{}) (err error) { if value == nil { *x = {{.enum.Name}}("") - return + return } // A wider range of scannable types. @@ -235,8 +235,8 @@ func (x *{{.enum.Name}}) Scan(value interface{}) (err error) { default: return errors.New("invalid type for {{.enum.Name}}") } - - return + + return } // Value implements the driver Valuer interface. diff --git a/generator/generator.go b/generator/generator.go index 7845fcd..975d1e0 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -69,7 +69,8 @@ type EnumValue struct { RawName string Name string PrefixedName string - Value interface{} + ValueStr string + ValueInt interface{} Comment string } @@ -208,6 +209,10 @@ func (g *Generator) WithNoComments() *Generator { return g } +func (g *Generator) anySQLEnabled() bool { + return g.sql || g.sqlNullStr || g.sqlint || g.sqlNullInt +} + // ParseAliases is used to add aliases to replace during name sanitization. func ParseAliases(aliases []string) error { aliasMap := map[string]string{} @@ -292,22 +297,23 @@ func (g *Generator) Generate(f *ast.File) ([]byte, error) { created++ data := map[string]interface{}{ - "enum": enum, - "name": name, - "lowercase": g.lowercaseLookup, - "nocase": g.caseInsensitive, - "nocomments": g.noComments, - "marshal": g.marshal, - "sql": g.sql, - "sqlint": g.sqlint, - "flag": g.flag, - "names": g.names, - "values": g.values, - "ptr": g.ptr, - "sqlnullint": g.sqlNullInt, - "sqlnullstr": g.sqlNullStr, - "mustparse": g.mustParse, - "forcelower": g.forceLower, + "enum": enum, + "name": name, + "lowercase": g.lowercaseLookup, + "nocase": g.caseInsensitive, + "nocomments": g.noComments, + "marshal": g.marshal, + "sql": g.sql, + "sqlint": g.sqlint, + "flag": g.flag, + "names": g.names, + "ptr": g.ptr, + "values": g.values, + "anySQLEnabled": g.anySQLEnabled(), + "sqlnullint": g.sqlNullInt, + "sqlnullstr": g.sqlNullStr, + "mustparse": g.mustParse, + "forcelower": g.forceLower, } templateName := "enum" @@ -401,12 +407,25 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { // Make sure to leave out any empty parts if value != "" { + rawName := value + valueStr := value + if strings.Contains(value, `=`) { // Get the value specified and set the data to that value. equalIndex := strings.Index(value, `=`) dataVal := strings.TrimSpace(value[equalIndex+1:]) if dataVal != "" { - if unsigned { + valueStr = dataVal + rawName = value[:equalIndex] + if enum.Type == "string" { + if parsed, err := strconv.ParseInt(dataVal, 10, 64); err == nil { + data = parsed + valueStr = rawName + } + if isQuoted(dataVal) { + valueStr = trimQuotes(dataVal) + } + } else if unsigned { newData, err := strconv.ParseUint(dataVal, 10, 64) if err != nil { err = errors.Wrapf(err, "failed parsing the data part of enum value '%s'", value) @@ -423,13 +442,13 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { } data = newData } - value = value[:equalIndex] } else { - value = strings.TrimSuffix(value, `=`) - fmt.Printf("Ignoring enum with '=' but no value after: %s\n", value) + rawName = strings.TrimSuffix(rawName, `=`) + fmt.Printf("Ignoring enum with '=' but no value after: %s\n", rawName) } } - rawName := strings.TrimSpace(value) + rawName = strings.TrimSpace(rawName) + valueStr = strings.TrimSpace(valueStr) name := cases.Title(language.Und, cases.NoLower).String(rawName) prefixedName := name if name != skipHolder { @@ -440,7 +459,7 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { } } - ev := EnumValue{Name: name, RawName: rawName, PrefixedName: prefixedName, Value: data, Comment: comment} + ev := EnumValue{Name: name, RawName: rawName, PrefixedName: prefixedName, ValueStr: valueStr, ValueInt: data, Comment: comment} enum.Values = append(enum.Values, ev) data = increment(data) } @@ -451,6 +470,20 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { return enum, nil } +func isQuoted(s string) bool { + s = strings.TrimSpace(s) + return (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) || (strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) +} + +func trimQuotes(s string) string { + s = strings.TrimSpace(s) + for _, quote := range []string{`"`, `'`} { + s = strings.TrimPrefix(s, quote) + s = strings.TrimSuffix(s, quote) + } + return s +} + func increment(d interface{}) interface{} { switch v := d.(type) { case uint64: diff --git a/generator/generator_test.go b/generator/generator_test.go index 0bcf6eb..50e3b00 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -341,3 +341,28 @@ func TestParenthesesParsing(t *testing.T) { fmt.Println(string(output)) } } + +// TestQuotedStrings +func TestQuotedStrings(t *testing.T) { + input := `package test + // This is a pre-enum comment that needs (to be handled properly) + // ENUM( + // abc (x), + // ghi = "20", + //). This is an extra string comment (With parentheses of it's own) + // And (another line) with Parentheses + type Animal string + ` + g := NewGenerator() + f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) + assert.Nil(t, err, "Error parsing no struct input") + + output, err := g.Generate(f) + assert.Nil(t, err, "Error generating formatted code") + assert.Contains(t, string(output), "// AnimalAbcX is a Animal of type abc (x).") + assert.Contains(t, string(output), "AnimalGhi Animal = \"20\"") + assert.NotContains(t, string(output), "// AnimalAnd") + if false { // Debugging statement + fmt.Println(string(output)) + } +} diff --git a/generator/template_funcs.go b/generator/template_funcs.go index 493862c..ebacd2b 100644 --- a/generator/template_funcs.go +++ b/generator/template_funcs.go @@ -67,12 +67,12 @@ func UnmapifyStringEnum(e Enum, lowercase bool) (ret string, err error) { } for _, val := range e.Values { if val.Name != skipHolder { - _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", val.RawName, val.PrefixedName)) + _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", val.ValueStr, val.PrefixedName)) if err != nil { return } - if lowercase && strings.ToLower(val.RawName) != val.RawName { - _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", strings.ToLower(val.RawName), val.PrefixedName)) + if lowercase && strings.ToLower(val.ValueStr) != val.ValueStr { + _, err = builder.WriteString(fmt.Sprintf("%q:%s,\n", strings.ToLower(val.ValueStr), val.PrefixedName)) if err != nil { return } @@ -118,9 +118,9 @@ func namifyStringEnum(e Enum) (ret string, err error) { func Offset(index int, enumType string, val EnumValue) (strResult string) { if strings.HasPrefix(enumType, "u") { // Unsigned - return strconv.FormatUint(val.Value.(uint64)-uint64(index), 10) + return strconv.FormatUint(val.ValueInt.(uint64)-uint64(index), 10) } else { // Signed - return strconv.FormatInt(val.Value.(int64)-int64(index), 10) + return strconv.FormatInt(val.ValueInt.(int64)-int64(index), 10) } }