-
Notifications
You must be signed in to change notification settings - Fork 54
/
sqlstruct.go
281 lines (237 loc) · 7.53 KB
/
sqlstruct.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// Copyright 2012 Kamil Kisiel. All rights reserved.
// Use of this source code is governed by the MIT
// license which can be found in the LICENSE file.
/*
Package sqlstruct provides some convenience functions for using structs with
the Go standard library's database/sql package.

The package matches struct field names to SQL query column names. A field can
also specify its matching column with an "sql" tag when the column name differs
from the field name. Unexported fields and fields tagged `sql:"-"` are ignored,
just as with the encoding/json package.

For example:

	type T struct {
		F1 string
		F2 string `sql:"field2"`
		F3 string `sql:"-"`
	}

	rows, err := db.Query(fmt.Sprintf("SELECT %s FROM tablename", sqlstruct.Columns(T{})))
	...

	for rows.Next() {
		var t T
		err = sqlstruct.Scan(&t, rows)
		...
	}

	err = rows.Err() // get any errors encountered during iteration

Aliased tables in a SQL statement may be scanned into the structure identified
by the same alias, using the ColumnsAliased and ScanAliased functions:

	type User struct {
		Id          int      `sql:"id"`
		Username    string   `sql:"username"`
		Email       string   `sql:"address"`
		Name        string   `sql:"name"`
		HomeAddress *Address `sql:"-"`
	}

	type Address struct {
		Id     int    `sql:"id"`
		City   string `sql:"city"`
		Street string `sql:"address"`
	}

	...

	var user User
	var address Address
	sql := `
	SELECT %s, %s FROM users AS u
	INNER JOIN address AS a ON a.id = u.address_id
	WHERE u.username = ?
	`
	sql = fmt.Sprintf(sql, sqlstruct.ColumnsAliased(user, "u"), sqlstruct.ColumnsAliased(address, "a"))
	rows, err := db.Query(sql, "gedi")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	if rows.Next() {
		err = sqlstruct.ScanAliased(&user, rows, "u")
		if err != nil {
			log.Fatal(err)
		}
		err = sqlstruct.ScanAliased(&address, rows, "a")
		if err != nil {
			log.Fatal(err)
		}
		user.HomeAddress = &address
	}
	fmt.Printf("%+v", user)
	// output: "{Id:1 Username:gedi Email:gedi@example.com Name:Gedas HomeAddress:0xc21001f570}"
	fmt.Printf("%+v", *user.HomeAddress)
	// output: "{Id:2 City:Vilnius Street:Plento 34}"
*/
package sqlstruct
import (
"bytes"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"sync"
)
// NameMapper converts a struct field's column name into the key used to match
// result-set columns. It is applied to every included field: to the Go field name
// of untagged fields AND to the value of an explicit "sql" tag.
//
// The default mapper lower-cases the name. To derive snake_case column names from
// Go field names instead, assign ToSnakeCase:
//
//	sqlstruct.NameMapper = sqlstruct.ToSnakeCase
//
// Any func(string) string may be used for a custom mapping. Scan and ScanAliased
// lower-case result-set column names before looking them up, so a custom mapper
// should produce lower-case names.
//
// Set NameMapper before the package is first used: the mapping for each struct
// type is computed once and cached, so later changes do not affect types that
// have already been seen.
var NameMapper func(string) string = strings.ToLower
// A cache of fieldInfos to save reflecting every time. Inspried by encoding/xml
var finfos map[reflect.Type]fieldInfo
var finfoLock sync.RWMutex
// TagName is the name of the tag to use on struct fields
var TagName = "sql"
// fieldInfo is a mapping of field tag values to their indices
type fieldInfo map[string][]int
func init() {
finfos = make(map[reflect.Type]fieldInfo)
}
// Rows is the minimal result-set abstraction that Scan and ScanAliased need:
// the names of the result columns and a way to scan the current row into one
// destination per column. *sql.Rows from database/sql satisfies it.
type Rows interface {
	Columns() ([]string, error)
	Scan(dest ...interface{}) error
}
// getFieldInfo returns the column-name -> field-index mapping for the struct type
// typ. The mapping is computed on first use and cached in finfos.
//
// Every exported field is included unless its TagName tag is "-". A field's column
// name is its tag value if present, otherwise its Go field name; either way the
// name is passed through NameMapper. Exported embedded structs are flattened: their
// fields get index paths that start with the embedded field's index (an embedded
// struct's own tag is ignored unless it is "-").
//
// When the same column name occurs at different nesting depths, the shallower
// field wins, mirroring Go's field-promotion rules. At equal depth the field
// processed later wins.
//
// The cache is keyed by type only, so changes to NameMapper or TagName made after
// a type was first seen do not affect that type's cached mapping.
func getFieldInfo(typ reflect.Type) fieldInfo {
	finfoLock.RLock()
	finfo, ok := finfos[typ]
	finfoLock.RUnlock()
	if ok {
		return finfo
	}

	finfo = make(fieldInfo)

	// set records idx under name unless a strictly shallower field already owns
	// that name. Without this, a promoted field could silently shadow an outer field.
	set := func(name string, idx []int) {
		if prev, exists := finfo[name]; exists && len(prev) < len(idx) {
			return
		}
		finfo[name] = idx
	}

	for i, n := 0, typ.NumField(); i < n; i++ {
		f := typ.Field(i)
		tag := f.Tag.Get(TagName)

		// Skip unexported fields and fields explicitly excluded with "-".
		if f.PkgPath != "" || tag == "-" {
			continue
		}

		// Flatten embedded structs. Their paths are prefixed with i so that
		// reflect.Value.FieldByIndex can reach them from the outer struct.
		if f.Anonymous && f.Type.Kind() == reflect.Struct {
			for name, idx := range getFieldInfo(f.Type) {
				set(name, append([]int{i}, idx...))
			}
			continue
		}

		// Untagged fields fall back to the Go field name.
		if tag == "" {
			tag = f.Name
		}
		set(NameMapper(tag), []int{i})
	}

	finfoLock.Lock()
	finfos[typ] = finfo
	finfoLock.Unlock()
	return finfo
}
// Scan reads the current row of rows into the struct pointed to by dest.
//
// Each result column is lower-cased and matched against the column names derived
// from dest's exported fields. Columns that match no field are read and discarded.
// Fields that match no column keep their existing values. Scan panics if dest is
// not a pointer to a struct, and returns any error from rows.Columns or rows.Scan.
func Scan(dest interface{}, rows Rows) error {
	const noAlias = ""
	return doScan(dest, rows, noAlias)
}
// ScanAliased behaves like Scan, but expects the columns that belong to dest to be
// prefixed with alias followed by an underscore, the naming produced by
// ColumnsAliased. The prefix is removed before the rest of the name is matched
// against dest's fields.
//
// For example, with alias "user" the field mapped to "name" is read from the
// result column "user_name".
func ScanAliased(dest interface{}, rows Rows, alias string) error {
	err := doScan(dest, rows, alias)
	return err
}
// Columns returns the column names for the struct value s, sorted and joined with
// ", ", ready to be used as a SELECT list. Every exported field of s not tagged "-"
// contributes one name: its tag value, or its Go field name when untagged, passed
// through NameMapper.
func Columns(s interface{}) string {
	names := cols(s)
	const sep = ", "
	return strings.Join(names, sep)
}
// ColumnsAliased behaves like Columns, but qualifies every column with alias and
// renames it so the row can be read back with ScanAliased. Each column c becomes
//
//	alias.c AS alias_c
//
// and the entries are joined with ", " in sorted column order.
func ColumnsAliased(s interface{}, alias string) string {
	qualifier := alias + "."
	renamed := " AS " + alias + "_"

	names := cols(s)
	parts := make([]string, len(names))
	for i := range names {
		parts[i] = qualifier + names[i] + renamed + names[i]
	}
	return strings.Join(parts, ", ")
}
// cols returns the column names of s in sorted order. s may be a struct value or
// a pointer (possibly nil, possibly multi-level) to a struct. Pointers are followed
// to the struct type, because only the type is inspected. It panics if s is a nil
// interface or does not resolve to a struct type.
func cols(s interface{}) []string {
	typ := reflect.TypeOf(s)
	if typ == nil {
		panic("sqlstruct: cannot determine columns of a nil value")
	}
	for typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}

	fields := getFieldInfo(typ)
	names := make([]string, 0, len(fields))
	for name := range fields {
		names = append(names, name)
	}
	// Map iteration order is random; sort so the generated SQL is deterministic.
	sort.Strings(names)
	return names
}
// doScan implements Scan and ScanAliased. It reads the column names of the current
// row, maps each one to a field of the struct pointed to by dest, and reads the row
// with a single rows.Scan call.
//
// Column names are matched case-insensitively. Each name is lower-cased and, when
// alias is non-empty, a leading "<alias>_" prefix (also compared lower-cased) is
// removed before the lookup. Columns without a matching field are scanned into a
// throwaway sql.RawBytes, because rows.Scan needs one destination per column.
//
// It panics if dest is not a non-nil pointer to a struct, and returns any error
// from rows.Columns or rows.Scan unchanged.
func doScan(dest interface{}, rows Rows, alias string) error {
	destv := reflect.ValueOf(dest)
	// Kind() is safe on the zero Value produced by a nil dest; it yields Invalid.
	if destv.Kind() != reflect.Ptr || destv.Type().Elem().Kind() != reflect.Struct {
		// Report the caller's type (%T of dest), not reflect.Value's.
		panic(fmt.Errorf("dest must be pointer to struct; got %T", dest))
	}
	if destv.IsNil() {
		panic(fmt.Errorf("dest must be non-nil pointer to struct; got nil %T", dest))
	}

	fields := getFieldInfo(destv.Type().Elem())
	elem := destv.Elem()

	columns, err := rows.Columns()
	if err != nil {
		return err
	}

	// Only strip the alias when it is a true prefix; an alias+"_" occurring in the
	// middle of an unrelated column name must not be removed.
	prefix := ""
	if alias != "" {
		prefix = strings.ToLower(alias) + "_"
	}

	values := make([]interface{}, 0, len(columns))
	for _, column := range columns {
		name := strings.TrimPrefix(strings.ToLower(column), prefix)
		if idx, ok := fields[name]; ok {
			values = append(values, elem.FieldByIndex(idx).Addr().Interface())
		} else {
			// No field maps to this column, so its value is read and discarded.
			values = append(values, new(sql.RawBytes))
		}
	}
	return rows.Scan(values...)
}
// ToSnakeCase converts a CamelCase identifier to lower-case snake_case. An
// underscore is inserted before each ASCII upper-case letter that is not the first
// rune and does not directly follow another upper-case letter. Runs of capitals
// therefore stay together: "UserID" becomes "user_id" and "HTTPServer" becomes
// "httpserver". The result is finally lower-cased with strings.ToLower.
//
// It is intended to be assigned to NameMapper so that struct field names map to
// snake_case column names.
func ToSnakeCase(src string) string {
	var out bytes.Buffer
	inUpperRun := false
	for pos, r := range src {
		isUpper := 'A' <= r && r <= 'Z'
		if isUpper && !inUpperRun && pos != 0 {
			out.WriteByte('_')
		}
		out.WriteRune(r)
		inUpperRun = isUpper
	}
	return strings.ToLower(out.String())
}