Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sms: init #48

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,6 @@ github.com/urfave/cli/v2 v2.23.5 h1:xbrU7tAYviSpqeR3X4nEFWUdB/uDZ6DE+HxmRU7Xtyw=
github.com/urfave/cli/v2 v2.23.5/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
github.com/wuhan005/gadget v0.0.0-20221206194113-7619e407f1a0 h1:zOXiOJRG/FOohTliJiykpwIaCPtUTIh+G0jw2bOJkA8=
github.com/wuhan005/gadget v0.0.0-20221206194113-7619e407f1a0/go.mod h1:vmC2IdgzTpIRwn1ZpuV/I3k9AIbRJ7oqTHFenq/qwkE=
github.com/wuhan005/govalid v0.0.0-20220315191209-043a899c3c7a h1:9vhVeLzwzrFm/pGinLXh2zCSDRO7ElnawEG4p527itQ=
github.com/wuhan005/govalid v0.0.0-20220315191209-043a899c3c7a/go.mod h1:zRrIdMbJM3Xe4lmXyrUi2xF9CE0+D4Y0OpQIMpjC0Vo=
github.com/wuhan005/govalid v0.0.0-20230216091828-820aa255fd21 h1:EHaQ4hLfjckhbI+AEleDHeb0cEN1bIAfvWJEhCe7e2Y=
github.com/wuhan005/govalid v0.0.0-20230216091828-820aa255fd21/go.mod h1:zRrIdMbJM3Xe4lmXyrUi2xF9CE0+D4Y0OpQIMpjC0Vo=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ func runWeb(ctx *cli.Context) error {
logrus.WarnLevel,
)))

_, err := db.Init()
db, err := db.Init()
if err != nil {
return errors.Wrap(err, "connect to database")
}

logrus.WithContext(ctx.Context).WithField("external_url", conf.App.ExternalURL).Info("Starting web server")
r := route.New()
r := route.New(db)
r.Use(tracing.Middleware("NekoBox"))
r.Run(conf.Server.Port)

Expand Down
4 changes: 4 additions & 0 deletions internal/conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,9 @@ func Init() error {
return errors.Wrap(err, "map 'mail'")
}

if err := File.Section("sms").MapTo(&SMS); err != nil {
return errors.Wrap(err, "map 'sms'")
}

return nil
}
8 changes: 8 additions & 0 deletions internal/conf/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,12 @@ var (
Port int `ini:"port"`
SMTP string `ini:"smtp"`
}

SMS struct {
AliyunRegion string `ini:"aliyun_region"`
AliyunAccessKey string `ini:"aliyun_access_key"`
AliyunAccessKeySecret string `ini:"aliyun_access_key_secret"`
AliyunSignName string `ini:"aliyun_sign_name"`
AliyunTemplateCode string `ini:"aliyun_template_code"`
}
)
20 changes: 19 additions & 1 deletion internal/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ import (
"github.com/unknwon/com"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"

"github.com/NekoWheel/NekoBox/internal/conf"
"github.com/NekoWheel/NekoBox/internal/db"
"github.com/NekoWheel/NekoBox/internal/dbutil"
"github.com/NekoWheel/NekoBox/internal/security/sms"
templatepkg "github.com/NekoWheel/NekoBox/internal/template"
)

Expand Down Expand Up @@ -149,7 +152,7 @@ func (c *Context) JSONError(errorCode int, message string) error {
}

// Contexter initializes a classic context for a request.
func Contexter() flamego.Handler {
func Contexter(gormDB *gorm.DB) flamego.Handler {
return func(ctx flamego.Context, data template.Data, session session.Session, x csrf.CSRF, t template.Template, flash session.Flash) {
c := Context{
Context: ctx,
Expand Down Expand Up @@ -218,6 +221,21 @@ func Contexter() flamego.Handler {
c.ResponseWriter().Header().Set("X-Content-Type-Options", "nosniff")
c.ResponseWriter().Header().Set("X-Frame-Options", "DENY")

var smsModule sms.SMS
if conf.SMS.AliyunSignName != "" && conf.SMS.AliyunTemplateCode != "" {
smsModule = sms.NewAliyunSMS(sms.NewAliyunSMSOptions{
Region: conf.SMS.AliyunRegion,
AccessKey: conf.SMS.AliyunAccessKey,
AccessKeySecret: conf.SMS.AliyunAccessKeySecret,
SignName: conf.SMS.AliyunSignName,
TemplateCode: conf.SMS.AliyunTemplateCode,
})
} else {
smsModule = sms.NewDummySMS()
}
ctx.Map(smsModule)

c.MapTo(gormDB, (*dbutil.Transactor)(nil))
ctx.Map(c)
ctx.Map(EndpointWeb)
}
Expand Down
50 changes: 50 additions & 0 deletions internal/db/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"gorm.io/gorm"

"github.com/NekoWheel/NekoBox/internal/conf"
"github.com/NekoWheel/NekoBox/internal/dbutil"
)

var Users UsersStore
Expand All @@ -23,8 +24,10 @@ type UsersStore interface {
GetByID(ctx context.Context, id uint) (*User, error)
GetByEmail(ctx context.Context, email string) (*User, error)
GetByDomain(ctx context.Context, domain string) (*User, error)
GetByPhone(ctx context.Context, phone string) (*User, error)
Update(ctx context.Context, id uint, opts UpdateUserOptions) error
UpdateHarassmentSetting(ctx context.Context, id uint, typ HarassmentSettingType) error
UpdateVerifyType(ctx context.Context, id uint, verifyType VerifyType) error
Authenticate(ctx context.Context, email, password string) (*User, error)
ChangePassword(ctx context.Context, id uint, oldPassword, newPassword string) error
UpdatePassword(ctx context.Context, id uint, newPassword string) error
Expand All @@ -39,16 +42,33 @@ type users struct {
*gorm.DB
}

type VerifyType uint

func (v VerifyType) IsValid() bool {
return v >= VerifyTypeUnverified && v <= VerifyTypeVerified
}

func (v VerifyType) IsUnverified() bool {
return v == VerifyTypeUnverified
}

const (
VerifyTypeUnverified VerifyType = iota
VerifyTypeVerified
)

type User struct {
gorm.Model `json:"-"`
Name string `json:"name"`
Password string `json:"-"`
Email string `json:"email"`
Phone string `json:"-" gorm:"size:50;uniqueIndex:user_phone_unique_idx; default: NULL"`
Avatar string `json:"avatar"`
Domain string `json:"domain"`
Background string `json:"background"`
Intro string `json:"intro"`
Notify NotifyType `json:"notify"`
VerifyType VerifyType `json:"-"`
HarassmentSetting HarassmentSettingType `json:"harassment_setting"`
}

Expand Down Expand Up @@ -79,17 +99,20 @@ type CreateUserOptions struct {
Name string
Password string
Email string
Phone string
Avatar string
Domain string
Background string
Intro string
VerifyType VerifyType
}

var (
ErrUserNotExists = errors.New("账号不存在")
ErrBadCredential = errors.New("邮箱或密码错误")
ErrDuplicateEmail = errors.New("这个邮箱已经注册过账号了!")
ErrDuplicateDomain = errors.New("个性域名重复了,换一个吧~")
ErrDuplicatePhone = errors.New("这个手机号已经注册过账号了!")
)

func (db *users) Create(ctx context.Context, opts CreateUserOptions) error {
Expand All @@ -101,15 +124,20 @@ func (db *users) Create(ctx context.Context, opts CreateUserOptions) error {
Name: opts.Name,
Password: opts.Password,
Email: opts.Email,
Phone: opts.Phone,
Avatar: opts.Avatar,
Domain: opts.Domain,
Background: opts.Background,
Intro: opts.Intro,
VerifyType: opts.VerifyType,
Notify: NotifyTypeEmail,
}
newUser.EncodePassword()

if err := db.WithContext(ctx).Create(newUser).Error; err != nil {
if dbutil.IsUniqueViolation(err, "users.user_phone_unique_idx") {
return ErrDuplicatePhone
}
return errors.Wrap(err, "create user")
}
return nil
Expand Down Expand Up @@ -138,8 +166,13 @@ func (db *users) GetByDomain(ctx context.Context, domain string) (*User, error)
return db.getBy(ctx, "domain = ?", domain)
}

func (db *users) GetByPhone(ctx context.Context, phone string) (*User, error) {
return db.getBy(ctx, "phone = ?", phone)
}

type UpdateUserOptions struct {
Name string
Phone string
Avatar string
Background string
Intro string
Expand All @@ -163,8 +196,12 @@ func (db *users) Update(ctx context.Context, id uint, opts UpdateUserOptions) er
Avatar: opts.Avatar,
Background: opts.Background,
Intro: opts.Intro,
Phone: opts.Phone,
Notify: opts.Notify,
}).Error; err != nil {
if dbutil.IsUniqueViolation(err, "users.user_phone_unique_idx") {
return ErrDuplicatePhone
}
return errors.Wrap(err, "update user")
}
return nil
Expand All @@ -185,6 +222,19 @@ func (db *users) UpdateHarassmentSetting(ctx context.Context, id uint, typ Haras
return nil
}

func (db *users) UpdateVerifyType(ctx context.Context, id uint, verifyType VerifyType) error {
if !verifyType.IsValid() {
return errors.Errorf("unexpected verify type: %q", verifyType)
}

if err := db.WithContext(ctx).Where("id = ?", id).Updates(&User{
VerifyType: verifyType,
}).Error; err != nil {
return errors.Wrap(err, "update user")
}
return nil
}

func (db *users) Authenticate(ctx context.Context, email, password string) (*User, error) {
u, err := db.GetByEmail(ctx, email)
if err != nil {
Expand Down
73 changes: 73 additions & 0 deletions internal/db/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ func TestUsers(t *testing.T) {
{"GetByID", testUsersGetByID},
{"GetByEmail", testUsersGetByEmail},
{"GetByDomain", testUsersGetByDomain},
{"GetByPhone", testUsersGetByPhone},
{"Update", testUsersUpdate},
{"UpdateHarassmentSetting", testUsersUpdateHarassmentSetting},
{"UpdateVerifyType", testUsersUpdateVerifyType},
{"Authenticate", testUsersAuthenticate},
{"ChangePassword", testUsersChangePassword},
{"UpdatePassword", testUsersUpdatePassword},
Expand Down Expand Up @@ -214,6 +216,50 @@ func testUsersGetByDomain(t *testing.T, ctx context.Context, db *users) {
})
}

func testUsersGetByPhone(t *testing.T, ctx context.Context, db *users) {
err := db.Create(ctx, CreateUserOptions{
Name: "E99p1ant",
Password: "super_secret",
Email: "[email protected]",
Phone: "13800138000",
Avatar: "avater.png",
Domain: "e99",
Background: "background.png",
Intro: "Be cool, but also be warm.",
})
require.Nil(t, err)

t.Run("normal", func(t *testing.T) {
got, err := db.GetByPhone(ctx, "13800138000")
require.Nil(t, err)

got.CreatedAt = time.Time{}
got.UpdatedAt = time.Time{}

want := &User{
Model: gorm.Model{
ID: 1,
},
Name: "E99p1ant",
Password: "super_secret",
Email: "[email protected]",
Phone: "13800138000",
Avatar: "avater.png",
Domain: "e99",
Background: "background.png",
Intro: "Be cool, but also be warm.",
Notify: NotifyTypeEmail,
}
want.EncodePassword()
require.Equal(t, want, got)
})

t.Run("not found", func(t *testing.T) {
_, err := db.GetByPhone(ctx, "404")
require.Equal(t, ErrUserNotExists, err)
})
}

func testUsersUpdate(t *testing.T, ctx context.Context, db *users) {
err := db.Create(ctx, CreateUserOptions{
Name: "E99p1ant",
Expand Down Expand Up @@ -283,6 +329,33 @@ func testUsersUpdateHarassmentSetting(t *testing.T, ctx context.Context, db *use
})
}

func testUsersUpdateVerifyType(t *testing.T, ctx context.Context, db *users) {
err := db.Create(ctx, CreateUserOptions{
Name: "E99p1ant",
Password: "super_secret",
Email: "[email protected]",
Avatar: "avater.png",
Domain: "e99",
Background: "background.png",
Intro: "Be cool, but also be warm.",
})
require.Nil(t, err)

t.Run("normal", func(t *testing.T) {
err := db.UpdateVerifyType(ctx, 1, VerifyTypeVerified)
require.Nil(t, err)

got, err := db.GetByID(ctx, 1)
require.Nil(t, err)
require.Equal(t, VerifyTypeVerified, got.VerifyType)
})

t.Run("unexpected verify type", func(t *testing.T) {
err := db.UpdateVerifyType(ctx, 1, 404)
require.NotNil(t, err)
})
}

func testUsersAuthenticate(t *testing.T, ctx context.Context, db *users) {
err := db.Create(ctx, CreateUserOptions{
Name: "E99p1ant",
Expand Down
11 changes: 11 additions & 0 deletions internal/dbutil/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dbutil

import (
"database/sql"

"gorm.io/gorm"
)

type Transactor interface {
Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) (err error)
}
38 changes: 38 additions & 0 deletions internal/dbutil/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package dbutil

import (
"strings"

"github.com/go-sql-driver/mysql"
)

func IsUniqueViolation(err error, constraint string) bool {
if err == nil {
return false
}

sqlErr, ok := err.(*mysql.MySQLError)
if !ok {
return false
}

if sqlErr.Number == 1062 {
// MySQL error code 1062 is ER_DUP_ENTRY, indicating a duplicate entry error
// Extract the conflicting index name from the error message
// The error message format is like this:
// "Duplicate entry '{value}' for key '{index_name}'"
msg := sqlErr.Message
i := strings.Index(msg, "for key '")
if i == -1 {
return false
}
j := strings.Index(msg[i+len("for key '"):], "'")
if j == -1 {
return false
}
indexName := msg[i+len("for key '") : i+len("for key '")+j]
return indexName == constraint
}

return false
}
Loading