Skip to content

Commit

Permalink
Merge pull request #253 from actiontech/issue-256-1
Browse files Browse the repository at this point in the history
记录下发sql的审计日志
  • Loading branch information
sjjian authored May 24, 2024
2 parents 95ced52 + fdbbd4c commit a5b1e6a
Show file tree
Hide file tree
Showing 15 changed files with 495 additions and 49 deletions.
10 changes: 5 additions & 5 deletions api/dms/service/v1/cb_operation_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

// swagger:parameters ListCBOperationLogs
type ListCBOperationLogsReq struct {
cbOperationLogsReq
CbOperationLogsReq
// the maximum count of member to be returned
// in:query
// Required: true
Expand Down Expand Up @@ -68,7 +68,7 @@ type UidWithDBServiceName struct {

// swagger:parameters ExportCBOperationLogs
type ExportCBOperationLogsReq struct {
cbOperationLogsReq
CbOperationLogsReq
}

// swagger:response ExportCBOperationLogsReply
Expand All @@ -78,7 +78,7 @@ type ExportCBOperationLogsReply struct {
File []byte
}

type cbOperationLogsReq struct {
type CbOperationLogsReq struct {
// project id
// Required: true
// in:path
Expand Down Expand Up @@ -110,9 +110,9 @@ type GetCBOperationLogTipsReq struct {
type GetCBOperationLogTipsReply struct {
// Generic reply
base.GenericResp
Data *cBOperationLogTips `json:"data"`
Data *CBOperationLogTips `json:"data"`
}

type cBOperationLogTips struct {
type CBOperationLogTips struct {
ExecResult []string `json:"exec_result"`
}
36 changes: 34 additions & 2 deletions internal/apiserver/service/dms_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2559,7 +2559,23 @@ func (d *DMSController) ListMaskingRules(c echo.Context) error {
// 200: body:ListCBOperationLogsReply
// default: body:GenericResp
func (d *DMSController) ListCBOperationLogs(c echo.Context) error {
return nil
req := &aV1.ListCBOperationLogsReq{}
err := bindAndValidateReq(c, req)
if nil != err {
return NewErrResp(c, err, apiError.BadRequestErr)
}

currentUserUid, err := jwt.GetUserUidStrFromContext(c)
if err != nil {
return NewErrResp(c, err, apiError.DMSServiceErr)
}

reply, err := d.DMS.ListCBOperationLogs(c.Request().Context(), req, currentUserUid)
if err != nil {
return NewErrResp(c, err, apiError.APIServerErr)
}

return NewOkRespWithReply(c, reply)
}

// swagger:route GET /v1/dms/projects/{project_uid}/cb_operation_logs/export dms ExportCBOperationLogs
Expand All @@ -2581,5 +2597,21 @@ func (d *DMSController) ExportCBOperationLogs(c echo.Context) error {
// 200: GetCBOperationLogTipsReply
// default: body:GenericResp
func (a *DMSController) GetCBOperationLogTips(c echo.Context) error {
return nil
req := &aV1.GetCBOperationLogTipsReq{}
err := bindAndValidateReq(c, req)
if nil != err {
return NewErrResp(c, err, apiError.BadRequestErr)
}

currentUserUid, err := jwt.GetUserUidStrFromContext(c)
if err != nil {
return NewErrResp(c, err, apiError.DMSServiceErr)
}

reply, err := a.DMS.GetCBOperationLogTips(c.Request().Context(), req, currentUserUid)
if err != nil {
return NewErrResp(c, err, apiError.APIServerErr)
}

return NewOkRespWithReply(c, reply)
}
82 changes: 82 additions & 0 deletions internal/dms/biz/cb_operation_log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package biz

import (
"context"
"time"

"github.com/actiontech/dms/internal/dms/pkg/constant"
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"
)

type CbOperationLogType string

const (
CbOperationLogTypeSql CbOperationLogType = "SQL"
)

// CbOperationLogRepo 定义操作日志的存储接口
type CbOperationLogRepo interface {
GetCbOperationLogByID(ctx context.Context, uid string) (*CbOperationLog, error)
SaveCbOperationLog(ctx context.Context, log *CbOperationLog) error
UpdateCbOperationLog(ctx context.Context, log *CbOperationLog) error
ListCbOperationLogs(ctx context.Context, opt *ListCbOperationLogOption) ([]*CbOperationLog, int64, error)
}

// CbOperationLog 代表操作日志记录
type CbOperationLog struct {
UID string
OpPersonUID string
OpTime *time.Time
DBServiceUID string
OpType CbOperationLogType
OpDetail string
OpSessionID *string
OpHost string
ProjectID string
AuditResults []*AuditResult
IsAuditPass *bool
ExecResult string
ExecTotalSec int64
ResultSetRowCount int64

User *User
DbService *DBService
}

func (c CbOperationLog) GetOpTime() time.Time {
if c.OpTime != nil {
return *c.OpTime
}
return time.Time{}
}

func (c CbOperationLog) GetSessionID() string {
if c.OpSessionID != nil {
return *c.OpSessionID
}
return ""
}

// ListCbOperationLogOption 用于查询操作日志的选项
type ListCbOperationLogOption struct {
PageNumber uint32
LimitPerPage uint32
OrderBy string
FilterBy []constant.FilterCondition
}

// CbOperationLogUsecase 定义操作日志的业务逻辑
type CbOperationLogUsecase struct {
opPermissionVerifyUsecase *OpPermissionVerifyUsecase
repo CbOperationLogRepo
log *utilLog.Helper
}

// NewCbOperationLogUsecase 创建一个新的操作日志业务逻辑实例
func NewCbOperationLogUsecase(logger utilLog.Logger, repo CbOperationLogRepo, opPermissionVerifyUsecase *OpPermissionVerifyUsecase) *CbOperationLogUsecase {
return &CbOperationLogUsecase{
repo: repo,
log: utilLog.NewHelper(logger, utilLog.WithMessageKey("biz.cbOperationLog")),
opPermissionVerifyUsecase: opPermissionVerifyUsecase,
}
}
26 changes: 26 additions & 0 deletions internal/dms/biz/cb_operation_log_ce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//go:build !enterprise

package biz

import (
"context"
"errors"
)

var errNotSupportCbOperationLog = errors.New("cb operation log related functions are enterprise version functions")

func (cu *CbOperationLogUsecase) GetCbOperationLogByID(ctx context.Context, uid string) (*CbOperationLog, error) {
return nil, errNotSupportCbOperationLog
}

func (u *CbOperationLogUsecase) SaveCbOperationLog(ctx context.Context, log *CbOperationLog) error {
return errNotSupportCbOperationLog
}

func (u *CbOperationLogUsecase) UpdateCbOperationLog(ctx context.Context, log *CbOperationLog) error {
return errNotSupportCbOperationLog
}

func (u *CbOperationLogUsecase) ListCbOperationLog(ctx context.Context, option *ListCbOperationLogOption, currentUid string, filterPersonID string, projectUid string) ([]*CbOperationLog, int64, error) {
return nil, 0, errNotSupportCbOperationLog
}
85 changes: 76 additions & 9 deletions internal/dms/biz/cloudbeaver.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,13 @@ type CloudbeaverUsecase struct {
opPermissionVerifyUsecase *OpPermissionVerifyUsecase
dmsConfigUseCase *DMSConfigUseCase
dataMaskingUseCase *DataMaskingUsecase
cbOperationLogUsecase *CbOperationLogUsecase
projectUsecase *ProjectUsecase
repo CloudbeaverRepo
proxyTargetRepo ProxyTargetRepo
}

func NewCloudbeaverUsecase(log utilLog.Logger, cfg *CloudbeaverCfg,
userUsecase *UserUsecase,
dbServiceUsecase *DBServiceUsecase,
opPermissionVerifyUsecase *OpPermissionVerifyUsecase,
dmsConfigUseCase *DMSConfigUseCase,
dataMaskingUseCase *DataMaskingUsecase,
cloudbeaverRepo CloudbeaverRepo,
proxyTargetRepo ProxyTargetRepo) (cu *CloudbeaverUsecase) {
func NewCloudbeaverUsecase(log utilLog.Logger, cfg *CloudbeaverCfg, userUsecase *UserUsecase, dbServiceUsecase *DBServiceUsecase, opPermissionVerifyUsecase *OpPermissionVerifyUsecase, dmsConfigUseCase *DMSConfigUseCase, dataMaskingUseCase *DataMaskingUsecase, cloudbeaverRepo CloudbeaverRepo, proxyTargetRepo ProxyTargetRepo, cbOperationUseDase *CbOperationLogUsecase, projectUsecase *ProjectUsecase) (cu *CloudbeaverUsecase) {
cu = &CloudbeaverUsecase{
repo: cloudbeaverRepo,
proxyTargetRepo: proxyTargetRepo,
Expand All @@ -101,6 +96,8 @@ func NewCloudbeaverUsecase(log utilLog.Logger, cfg *CloudbeaverCfg,
opPermissionVerifyUsecase: opPermissionVerifyUsecase,
dmsConfigUseCase: dmsConfigUseCase,
dataMaskingUseCase: dataMaskingUseCase,
cbOperationLogUsecase: cbOperationUseDase,
projectUsecase: projectUsecase,
cloudbeaverCfg: cfg,
log: utilLog.NewHelper(log, utilLog.WithMessageKey("biz.cloudbeaver")),
}
Expand Down Expand Up @@ -144,6 +141,8 @@ func (cu *CloudbeaverUsecase) getGraphQLServerURI() string {
return fmt.Sprintf("%v://%v:%v%v%v", protocol, cu.cloudbeaverCfg.Host, cu.cloudbeaverCfg.Port, CbRootUri, CbGqlApi)
}

const dmsUserIdKey = "dmsToken"

func (cu *CloudbeaverUsecase) Login() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
Expand All @@ -164,7 +163,8 @@ func (cu *CloudbeaverUsecase) Login() echo.MiddlewareFunc {
cu.log.Errorf("GetUserUidStrFromContext err: %v", err)
return errors.New("get user name from token failed")
}

// set dmsUserId to context for save ob operation log
c.Set(dmsUserIdKey, dmsUserId)
if err = cu.initialGraphQL(); err != nil {
return err
}
Expand Down Expand Up @@ -268,6 +268,7 @@ type TaskInfo struct {
} `json:"data"`
}

var taskIDAssocUid sync.Map
var taskIdAssocMasking sync.Map

func (cu *CloudbeaverUsecase) buildTaskIdAssocDataMasking(raw []byte, enableMasking bool) error {
Expand Down Expand Up @@ -341,6 +342,11 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {
return err
}

err = cu.SaveCbOpLog(c, dbService, params, next)
if err != nil {
return err
}

if !cu.isEnableSQLAudit(dbService) {
cloudbeaverResBuf := new(bytes.Buffer)
mw := io.MultiWriter(c.Response().Writer, cloudbeaverResBuf)
Expand Down Expand Up @@ -373,6 +379,20 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {
}

if params.OperationName == "getSqlExecuteTaskResults" {
cloudbeaverResBuf := new(bytes.Buffer)
mw := io.MultiWriter(c.Response().Writer, cloudbeaverResBuf)
writer := &cloudbeaverResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer

if err = next(c); err != nil {
return err
}

err = cu.UpdateCbOpResult(c, cloudbeaverResBuf, params, ctx)
if err != nil {
return err
}

taskIdAssocMaskingVal, exist := taskIdAssocMasking.LoadAndDelete(params.Variables["taskId"])
if !exist {
return next(c)
Expand All @@ -393,8 +413,17 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {

var cloudbeaverNext cloudbeaver.Next
var resWrite *responseProcessWriter
var resp cloudbeaver.AuditResults
if !cloudbeaverHandle.NeedModifyRemoteRes {
cloudbeaverNext = func(c echo.Context) ([]byte, error) {
resp, ok = c.Get(cloudbeaver.AuditResultKey).(cloudbeaver.AuditResults)
if ok {
cu.UpdateCbOp(params, ctx, resp)
if !resp.IsSuccess {
return nil, c.JSON(http.StatusOK, convertToResp(resp))
}
}

cloudbeaverResBuf := new(bytes.Buffer)
if params.OperationName == "asyncSqlExecuteQuery" {
mw := io.MultiWriter(c.Response().Writer, cloudbeaverResBuf)
Expand All @@ -416,6 +445,12 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {
}
} else {
cloudbeaverNext = func(c echo.Context) ([]byte, error) {
resp, ok = c.Get(cloudbeaver.AuditResultKey).(cloudbeaver.AuditResults)
cu.UpdateCbOp(params, ctx, resp)
if !resp.IsSuccess {
return nil, c.JSON(http.StatusOK, convertToResp(resp))
}

resWrite = &responseProcessWriter{tmp: &bytes.Buffer{}, ResponseWriter: c.Response().Writer}
c.Response().Writer = resWrite

Expand Down Expand Up @@ -461,6 +496,38 @@ func (cu *CloudbeaverUsecase) GraphQLDistributor() echo.MiddlewareFunc {
}
}

func convertToResp(resp cloudbeaver.AuditResults) interface{} {
var messages []string
for _, sqlResult := range resp.Results {
for _, audit := range sqlResult.AuditResult {
messages = append(messages, audit.Message)
}
}

messageStr := strings.Join(messages, ",")
name := "SQL Audit Failed"

return struct {
Data struct {
TaskInfo model.AsyncTaskInfo `json:"taskInfo"`
} `json:"data"`
}{
struct {
TaskInfo model.AsyncTaskInfo `json:"taskInfo"`
}{
TaskInfo: model.AsyncTaskInfo{
Name: &name,
Running: false,
Status: &resp.SQL,
Error: &model.ServerError{
Message: &messageStr,
StackTrace: &messageStr,
},
},
},
}
}

type cloudbeaverResponseWriter struct {
io.Writer
http.ResponseWriter
Expand Down
17 changes: 17 additions & 0 deletions internal/dms/biz/cloudbeaver_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,26 @@
package biz

import (
"bytes"
"context"

"github.com/99designs/gqlgen/graphql"
"github.com/actiontech/dms/internal/pkg/cloudbeaver"
"github.com/labstack/echo/v4"
)

func (cu *CloudbeaverUsecase) ResetDbServiceByAuth(ctx context.Context, activeDBServices []*DBService, userId string) ([]*DBService, error) {
return activeDBServices, nil
}

func (cu *CloudbeaverUsecase) UpdateCbOp(params *graphql.RawParams, ctx context.Context, resp cloudbeaver.AuditResults) {
return
}

func (cu *CloudbeaverUsecase) UpdateCbOpResult(c echo.Context, cloudbeaverResBuf *bytes.Buffer, params *graphql.RawParams, ctx context.Context) error {
return nil
}

func (cu *CloudbeaverUsecase) SaveCbOpLog(c echo.Context, dbService *DBService, params *graphql.RawParams, next echo.HandlerFunc) error {
return nil
}
Loading

0 comments on commit a5b1e6a

Please sign in to comment.