From 61175fedde1f545ddb769df0309a6c6d917c3cd4 Mon Sep 17 00:00:00 2001 From: Rueian Date: Sun, 26 May 2024 10:03:24 +0800 Subject: [PATCH] feat: impl ForceWithContext in rueidislock (#547) Signed-off-by: Rueian --- rueidislock/lock.go | 63 +++++++++++--- rueidislock/lock_test.go | 175 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 11 deletions(-) diff --git a/rueidislock/lock.go b/rueidislock/lock.go index 19284f5a..035e1761 100644 --- a/rueidislock/lock.go +++ b/rueidislock/lock.go @@ -50,6 +50,8 @@ type Locker interface { WithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) // TryWithContext tries to acquire a distributed redis lock by name without waiting. It may return ErrNotLocked. TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) + // ForceWithContext takes over a distributed redis lock by canceling the original holder. It may return ErrNotLocked. + ForceWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) // Client exports the underlying rueidis.Client Client() rueidis.Client // Close closes the underlying rueidis.Client @@ -159,13 +161,21 @@ func keyname(prefix, name string, i int32) string { return sb.String() } -func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time) (err error) { +func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time, force bool) (err error) { ctx, cancel := context.WithTimeout(ctx, m.timeout) var resp rueidis.RedisResult - if m.setpx { - resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)}) + if force { + if m.setpx { + resp = fcqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)}) + } else { + resp = fcqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)}) + } } else { - resp = acqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)}) + if m.setpx { + resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)}) + } else { + resp = acqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)}) + } } cancel() if err = resp.Error(); rueidis.IsRedisNil(err) { @@ -228,6 +238,19 @@ func (m *locker) trygate(name string) (g *gate) { return g } +func (m *locker) forcegate(name string) (g *gate) { + m.mu.Lock() + if g = m.gates[name]; g == nil && m.gates != nil { + g = makegate(m.totalcnt) + m.gates[name] = g + } + if g != nil { + g.w++ + } + m.mu.Unlock() + return g +} + func (m *locker) onInvalidations(messages []rueidis.RedisMessage) { if messages == nil { m.mu.RLock() @@ -258,7 +281,7 @@ func (m *locker) onInvalidations(messages []rueidis.RedisMessage) { } } -func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string, g *gate) context.CancelFunc { +func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string, g *gate, force bool) context.CancelFunc { var err error val := random() @@ -313,13 +336,18 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string } } - acquire := func(err error, key string, ch chan struct{}) error { + acquire := func(err error, key string, ch chan struct{}, force bool) error { select { case <-ch: default: } if err != ErrNotLocked { - err = m.acquire(ctx, key, val, deadline) + if err = m.acquire(ctx, key, val, deadline, force); force && err == nil { + select { + case ch <- struct{}{}: + default: + } + } } go monitoring(err, key, deadline, ch) return err @@ -327,7 +355,7 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string var i, acquired, failures int32 for ; acquired < m.majority && failures < m.majority; i++ { - if err = acquire(err, keyname(m.prefix, name, i), g.csc[i]); err == nil { + if err = acquire(err, keyname(m.prefix, name, i), g.csc[i], force); err == nil { acquired++ } else { failures++ @@ -336,7 +364,7 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string if i < m.totalcnt { go func(i int32, err error) { for ; i < m.totalcnt; i++ { - err = acquire(err, keyname(m.prefix, name, i), g.csc[i]) + err = acquire(err, keyname(m.prefix, name, i), g.csc[i], force) } }(i, err) } @@ -349,10 +377,21 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string return nil } +func (m *locker) ForceWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) { + ctx, cancel := context.WithCancel(ctx) + if g := m.forcegate(name); g != nil { + if cancel := m.try(ctx, cancel, name, g, true); cancel != nil { + return ctx, cancel, nil + } + } + cancel() + return ctx, cancel, ErrNotLocked +} + func (m *locker) TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) { ctx, cancel := context.WithCancel(ctx) if g := m.trygate(name); g != nil { - if cancel := m.try(ctx, cancel, name, g); cancel != nil { + if cancel := m.try(ctx, cancel, name, g, false); cancel != nil { return ctx, cancel, nil } } @@ -365,7 +404,7 @@ func (m *locker) WithContext(ctx context.Context, name string) (context.Context, ctx, cancel := context.WithCancel(ctx) g, err := m.waitgate(ctx, name) if g != nil { - if cancel := m.try(ctx, cancel, name, g); cancel != nil { + if cancel := m.try(ctx, cancel, name, g, false); cancel != nil { return ctx, cancel, nil } } @@ -394,6 +433,8 @@ var ( extend = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then local r = redis.call("PEXPIREAT",KEYS[1],ARGV[2]);redis.call("GET",KEYS[1]);return r end;return 0`) acqms = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"NX","PX",ARGV[2]);redis.call("GET",KEYS[1]);return r`) acqat = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"NX","PXAT",ARGV[2]);redis.call("GET",KEYS[1]);return r`) + fcqms = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"PX",ARGV[2]);redis.call("GET",KEYS[1]);return r`) + fcqat = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"PXAT",ARGV[2]);redis.call("GET",KEYS[1]);return r`) ) // ErrNotLocked is returned from the Locker.TryWithContext when it fails diff --git a/rueidislock/lock_test.go b/rueidislock/lock_test.go index 0504060a..8c4e29e1 100644 --- a/rueidislock/lock_test.go +++ b/rueidislock/lock_test.go @@ -162,6 +162,152 @@ func TestLocker_WithContext_UnlockByClientSideCaching(t *testing.T) { }) } +func TestLocker_WithContext_UnlockBySelfForceWithContext(t *testing.T) { + test := func(t *testing.T, noLoop, setpx bool) { + locker := newLocker(t, noLoop, setpx, false) + locker.timeout = time.Second + defer locker.Close() + lck := strconv.Itoa(rand.Int()) + ctx, cancel, err := locker.WithContext(context.Background(), lck) + if err != nil { + t.Fatal(err) + } + go func() { + _, cancel2, err2 := locker.ForceWithContext(context.Background(), lck) + if err2 != nil { + t.Errorf("unexpected err %v", err2) + return + } + cancel2() + }() + <-ctx.Done() + cancel() + if !errors.Is(ctx.Err(), context.Canceled) { + t.Fatalf("unexpected err %v", err) + } + } + t.Run("Tracking Loop", func(t *testing.T) { + test(t, false, false) + }) + t.Run("Tracking NoLoop", func(t *testing.T) { + test(t, true, false) + }) + t.Run("SET PX", func(t *testing.T) { + test(t, true, true) + }) +} + +func TestLocker_WithContext_UnlockByOtherForceWithContext(t *testing.T) { + test := func(t *testing.T, noLoop, setpx bool) { + locker := newLocker(t, noLoop, setpx, false) + locker.timeout = time.Second + defer locker.Close() + lck := strconv.Itoa(rand.Int()) + ctx, cancel, err := locker.WithContext(context.Background(), lck) + if err != nil { + t.Fatal(err) + } + go func() { + locker2 := newLocker(t, noLoop, setpx, false) + locker2.timeout = time.Second + defer locker2.Close() + _, cancel2, err2 := locker2.ForceWithContext(context.Background(), lck) + if err2 != nil { + t.Errorf("unexpected err %v", err2) + return + } + cancel2() + }() + <-ctx.Done() + cancel() + if !errors.Is(ctx.Err(), context.Canceled) { + t.Fatalf("unexpected err %v", err) + } + } + t.Run("Tracking Loop", func(t *testing.T) { + test(t, false, false) + }) + t.Run("Tracking NoLoop", func(t *testing.T) { + test(t, true, false) + }) + t.Run("SET PX", func(t *testing.T) { + test(t, true, true) + }) +} + +func TestLocker_ForceWithContext_UnlockBySelfForceWithContext(t *testing.T) { + test := func(t *testing.T, noLoop, setpx bool) { + locker := newLocker(t, noLoop, setpx, false) + locker.timeout = time.Second + defer locker.Close() + lck := strconv.Itoa(rand.Int()) + ctx, cancel, err := locker.ForceWithContext(context.Background(), lck) + if err != nil { + t.Fatal(err) + } + go func() { + _, cancel2, err2 := locker.ForceWithContext(context.Background(), lck) + if err2 != nil { + t.Errorf("unexpected err %v", err2) + return + } + cancel2() + }() + <-ctx.Done() + cancel() + if !errors.Is(ctx.Err(), context.Canceled) { + t.Fatalf("unexpected err %v", err) + } + } + t.Run("Tracking Loop", func(t *testing.T) { + test(t, false, false) + }) + t.Run("Tracking NoLoop", func(t *testing.T) { + test(t, true, false) + }) + t.Run("SET PX", func(t *testing.T) { + test(t, true, true) + }) +} + +func TestLocker_ForceWithContext_UnlockByOtherForceWithContext(t *testing.T) { + test := func(t *testing.T, noLoop, setpx bool) { + locker := newLocker(t, noLoop, setpx, false) + locker.timeout = time.Second + defer locker.Close() + lck := strconv.Itoa(rand.Int()) + ctx, cancel, err := locker.ForceWithContext(context.Background(), lck) + if err != nil { + t.Fatal(err) + } + go func() { + locker2 := newLocker(t, noLoop, setpx, false) + locker2.timeout = time.Second + defer locker2.Close() + _, cancel2, err2 := locker2.ForceWithContext(context.Background(), lck) + if err2 != nil { + t.Errorf("unexpected err %v", err2) + return + } + cancel2() + }() + <-ctx.Done() + cancel() + if !errors.Is(ctx.Err(), context.Canceled) { + t.Fatalf("unexpected err %v", err) + } + } + t.Run("Tracking Loop", func(t *testing.T) { + test(t, false, false) + }) + t.Run("Tracking NoLoop", func(t *testing.T) { + test(t, true, false) + }) + t.Run("SET PX", func(t *testing.T) { + test(t, true, true) + }) +} + func TestLocker_WithContext_ExtendByClientSideCaching(t *testing.T) { test := func(t *testing.T, noLoop, setpx bool) { locker := newLocker(t, noLoop, setpx, false) @@ -329,6 +475,35 @@ func TestLocker_TryWithContext(t *testing.T) { } } +func TestLocker_ForceWithContextThenTryWithContext(t *testing.T) { + test := func(t *testing.T, noLoop, setpx, nocsc bool) { + locker := newLocker(t, noLoop, setpx, nocsc) + locker.timeout = time.Second + defer locker.Close() + + lck := strconv.Itoa(rand.Int()) + ctx, cancel, err := locker.ForceWithContext(context.Background(), lck) + if err != nil { + t.Fatal(err) + } + if _, _, err := locker.TryWithContext(ctx, lck); err != ErrNotLocked { + t.Fatal(err) + } + cancel() + } + for _, nocsc := range []bool{false, true} { + t.Run("Tracking Loop", func(t *testing.T) { + test(t, false, false, nocsc) + }) + t.Run("Tracking NoLoop", func(t *testing.T) { + test(t, true, false, nocsc) + }) + t.Run("SET PX", func(t *testing.T) { + test(t, true, true, nocsc) + }) + } +} + func TestLocker_TryWithContext_MultipleLocker(t *testing.T) { test := func(t *testing.T, noLoop, setpx, nocsc bool) { lockers := make([]*locker, 10)