diff --git a/api/v1/oauth_server.go b/api/v1/oauth_server.go index e42ec17..e6ce34f 100644 --- a/api/v1/oauth_server.go +++ b/api/v1/oauth_server.go @@ -28,9 +28,8 @@ import ( var ( srv *server.Server - pgxConn, _ = pgx.Connect(context.TODO(), config.Config.Sub("oauth.server").GetString("db_uri")) + pgxConn, _ = pgx.Connect(context.Background(), config.Config.Sub("oauth.server").GetString("db_uri")) adapter = pgx4adapter.NewConn(pgxConn) - clientStore, _ = pg.NewClientStore(adapter) ) func init() { @@ -41,6 +40,7 @@ func InitServer() { // use PostgreSQL token store with pgx.Connection adapter tokenStore, _ := pg.NewTokenStore(adapter, pg.WithTokenStoreGCInterval(time.Minute)) defer tokenStore.Close() + clientStore, _ := pg.NewClientStore(adapter) mg := manage.NewDefaultManager() mg.MapTokenStorage(tokenStore) @@ -66,7 +66,6 @@ func InitServer() { srv.SetResponseErrorHandler(func(re *errors.Response) { log.Println("Response Error:", re.Error.Error()) }) - } // Create client @@ -77,6 +76,13 @@ func CreateClient(c *gin.Context) { return } + token := c.GetHeader("TOKEN") + uid, err := util.GetUsername(token, model.LOGIN_TOKEN_SUB) + if err != nil || uid == "" { + c.JSON(http.StatusOK, result.Failed(result.TokenError)) + return + } + clientID := util.GenerateUUID() secret, err := util.GenerateRandomString(32) if err != nil { @@ -84,11 +90,14 @@ func CreateClient(c *gin.Context) { return } + clientStore, _ := pg.NewClientStore(adapter) cErr := clientStore.Create(&models.Client{ ID: clientID, Secret: secret, Domain: redirectURI, + UserID: uid, }) + if cErr != nil { c.JSON(http.StatusBadRequest, result.Failed(result.InternalErr)) return diff --git a/api/v1/user.go b/api/v1/user.go index 00ccd99..ff12107 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -242,7 +242,7 @@ func ChangePassword(ctx *gin.Context) { token := ctx.GetHeader("TOKEN") uid, err := util.GetUsername(token, model.LOGIN_TOKEN_SUB) if err != nil || uid == "" { - ctx.JSON(http.StatusOK, result.Failed(result.TicketNotCorrect)) + ctx.JSON(http.StatusOK, result.Failed(result.TokenError)) return } // Get password from form diff --git a/example/client.go b/example/client.go index 9fb6038..853f29a 100644 --- a/example/client.go +++ b/example/client.go @@ -45,11 +45,11 @@ func main() { http.HandleFunc("/api/auth/callback/sastlink", func(w http.ResponseWriter, r *http.Request) { r.ParseForm() println(r.URL.RawQuery) - state := r.Form.Get("state") - if state != "xyz" { - http.Error(w, "State invalid", http.StatusBadRequest) - return - } + // state := r.Form.Get("state") + // if state != "xyz" { + // http.Error(w, "State invalid", http.StatusBadRequest) + // return + // } code := r.Form.Get("code") if code == "" { http.Error(w, "Code not found", http.StatusBadRequest) @@ -61,11 +61,12 @@ func main() { http.HandleFunc("/oauth2", func(w http.ResponseWriter, r *http.Request) { r.ParseForm() println(r.URL.RawQuery) - state := r.Form.Get("state") - if state != "xyz" { - http.Error(w, "State invalid", http.StatusBadRequest) - return - } + verifier := oauth2.GenerateVerifier() + // state := r.Form.Get("state") + // if state != "xyz" { + // http.Error(w, "State invalid", http.StatusBadRequest) + // return + // } code := r.Form.Get("code") if code == "" { http.Error(w, "Code not found", http.StatusBadRequest)