-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.go
113 lines (104 loc) · 2.45 KB
/
helper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package gosl
import (
	"context"
	"errors"
	"fmt"
	"net"
	"reflect"
	"strings"
	"time"
	"unsafe"

	_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"
)
// Key is the type of the context keys defined by this package. A
// package-specific key type keeps these keys distinct from keys of other types.
type Key int

// STACK is the context key under which RunInTransaction records how deeply it
// is nested; the value stored under it is an int.
var STACK = Key(101)
// ConnectToDB opens a MySQL database handle through sqlx and configures its
// connection pool.
//
// The DSN is "user:password@(host:port)/name?parseTime=true". host may be a
// hostname, an IPv4 address, or an IPv6 literal with or without brackets;
// IPv6 literals are bracketed as the DSN address syntax requires.
//
// maxOpen, maxIdle, maxLifetime and maxIdleLifetime are passed unchanged to
// SetMaxOpenConns, SetMaxIdleConns, SetConnMaxLifetime and SetConnMaxIdleTime.
//
// It panics (via sqlx.MustConnect) when the connection cannot be established,
// so it is best suited to program start-up.
func ConnectToDB(user, password, host, port, name string, maxOpen, maxIdle int, maxLifetime, maxIdleLifetime time.Duration) *sqlx.DB {
	// A bare "::1" would yield the ambiguous address "::1:3306". JoinHostPort
	// produces "[::1]:3306" for IPv6 literals and leaves hostnames and IPv4
	// addresses as "host:port". Trimming first keeps callers that already pass
	// "[::1]" working.
	addr := net.JoinHostPort(strings.Trim(host, "[]"), port)
	db := sqlx.MustConnect("mysql", fmt.Sprintf("%s:%s@(%s)/%s?parseTime=true", user, password, addr, name))
	db.SetMaxOpenConns(maxOpen)
	db.SetMaxIdleConns(maxIdle)
	db.SetConnMaxLifetime(maxLifetime)
	db.SetConnMaxIdleTime(maxIdleLifetime)
	return db
}
// RunInTransaction runs fn with database transactions bound into ctx.
//
// Outermost call (ctx has no int stored under STACK):
//   - It walks ctx with reflection, looking for *Queryable values.
//   - For each *Queryable that has a db but no tx, it begins a transaction with
//     Beginx.
//   - It then stores a transaction-backed Queryable, built by NewQueryable, in
//     ctx under the context key found next to that value.
//
// Nested calls (ctx already carries STACK) skip this step and reuse what the
// outer call set up. In both cases fn receives ctx with the STACK depth
// incremented by one.
//
// Error handling:
//   - If injection fails, fn returns an error, or fn panics, every transaction
//     begun by this call is rolled back. The error is returned; a panic keeps
//     propagating.
//   - Only the outermost call commits.
//   - If a commit fails, all transactions begun by this call are rolled back
//     (ones that already committed stay committed). The commit error is
//     returned wrapped with %w.
//
// NOTE(review): the walk reads unexported fields of the concrete context type
// (fields named "key" and "Context"). It does not descend into the embedded
// parent "Context", so it depends on the standard library's internal context
// layout and only sees values reachable from the outermost context node.
func RunInTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
	var err error
	callCount, ok := ctx.Value(STACK).(int)
	var trxs []*sqlx.Tx

	// Any exit other than full success rolls back every transaction this call
	// began, so none is left open holding a connection. That covers an
	// injection error, an error from fn, a failed commit, or a panic.
	// Rollback errors are ignored on purpose: the caller needs the error that
	// caused the rollback, and rolling back an already-committed transaction
	// can only report that it is done.
	succeeded := false
	defer func() {
		if succeeded {
			return
		}
		for _, trx := range trxs {
			_ = trx.Rollback()
		}
	}()

	// visitKey identifies a pointer already walked. The type is part of the key
	// because a struct and its first field share an address.
	type visitKey struct {
		addr uintptr
		typ  reflect.Type
	}
	visited := make(map[visitKey]bool)

	var currKey interface{}
	var injectTx func(curr interface{})
	injectTx = func(curr interface{}) {
		ptr := reflect.ValueOf(curr)
		// Only non-nil pointers can be walked. Elem on a non-pointer panics
		// (since Go 1.21 context.Background() is a struct value), and a typed
		// nil pointer has no fields to read.
		if ptr.Kind() != reflect.Pointer || ptr.IsNil() {
			return
		}
		// A cycle in the pointer graph would otherwise recurse until the stack
		// overflows.
		vk := visitKey{addr: ptr.Pointer(), typ: ptr.Type()}
		if visited[vk] {
			return
		}
		visited[vk] = true

		values := ptr.Elem()
		keys := values.Type()
		if keys.Kind() != reflect.Struct {
			return
		}
		for i := 0; i < values.NumField(); i++ {
			// Interface() refuses unexported fields, so build a view over the
			// same memory that can be read.
			value := values.Field(i)
			value = reflect.NewAt(value.Type(), unsafe.Pointer(value.UnsafeAddr())).Elem()
			field := keys.Field(i)
			switch field.Name {
			case "key":
				currKey = value.Interface()
				continue
			case "Context":
				// The embedded parent context is intentionally not followed.
				continue
			}

			q, isQueryable := value.Interface().(*Queryable)
			if !isQueryable {
				if tmp := value.Interface(); tmp != nil && reflect.TypeOf(tmp).Kind() == reflect.Pointer {
					injectTx(tmp)
					// Stop at the first failure instead of letting a later
					// successful Beginx overwrite err.
					if err != nil {
						return
					}
				}
				continue
			}
			if q == nil {
				// A nil *Queryable has nothing to bind a transaction to.
				continue
			}
			if q.db == nil {
				err = errors.New("no active db con")
				return
			}
			if q.tx != nil {
				// Already bound to a transaction; reuse it.
				continue
			}
			var tx *sqlx.Tx
			tx, err = q.db.Beginx()
			if err != nil {
				return
			}
			trxs = append(trxs, tx)
			con := map[string]interface{}{
				"db": q.db,
				"tx": tx,
			}
			ctx = context.WithValue(ctx, currKey, NewQueryable(con))
		}
	}

	if !ok {
		injectTx(ctx)
		if err != nil {
			return err
		}
	}

	ctx = context.WithValue(ctx, STACK, callCount+1)
	if err = fn(ctx); err != nil {
		return err
	}

	// Only the outermost call commits; nested calls leave that to it.
	if callCount == 0 {
		for _, trx := range trxs {
			if cerr := trx.Commit(); cerr != nil {
				return fmt.Errorf("error when committing transaction: %w", cerr)
			}
		}
	}
	succeeded = true
	return nil
}