Skip to content

Commit

Permalink
feat: impl ForceWithContext in rueidislock (redis#547)
Browse files Browse the repository at this point in the history
Signed-off-by: Rueian <[email protected]>
  • Loading branch information
rueian authored May 26, 2024
1 parent c009c42 commit 61175fe
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 11 deletions.
63 changes: 52 additions & 11 deletions rueidislock/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type Locker interface {
WithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error)
// TryWithContext tries to acquire a distributed redis lock by name without waiting. It may return ErrNotLocked.
TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error)
// ForceWithContext takes over a distributed redis lock by canceling the original holder. It may return ErrNotLocked.
ForceWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error)
// Client exports the underlying rueidis.Client
Client() rueidis.Client
// Close closes the underlying rueidis.Client
Expand Down Expand Up @@ -159,13 +161,21 @@ func keyname(prefix, name string, i int32) string {
return sb.String()
}

func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time) (err error) {
func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time, force bool) (err error) {
ctx, cancel := context.WithTimeout(ctx, m.timeout)
var resp rueidis.RedisResult
if m.setpx {
resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)})
if force {
if m.setpx {
resp = fcqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)})
} else {
resp = fcqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)})
}
} else {
resp = acqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)})
if m.setpx {
resp = acqms.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(m.validity.Milliseconds(), 10)})
} else {
resp = acqat.Exec(ctx, m.client, []string{key}, []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)})
}
}
cancel()
if err = resp.Error(); rueidis.IsRedisNil(err) {
Expand Down Expand Up @@ -228,6 +238,19 @@ func (m *locker) trygate(name string) (g *gate) {
return g
}

func (m *locker) forcegate(name string) (g *gate) {
m.mu.Lock()
if g = m.gates[name]; g == nil && m.gates != nil {
g = makegate(m.totalcnt)
m.gates[name] = g
}
if g != nil {
g.w++
}
m.mu.Unlock()
return g
}

func (m *locker) onInvalidations(messages []rueidis.RedisMessage) {
if messages == nil {
m.mu.RLock()
Expand Down Expand Up @@ -258,7 +281,7 @@ func (m *locker) onInvalidations(messages []rueidis.RedisMessage) {
}
}

func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string, g *gate) context.CancelFunc {
func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string, g *gate, force bool) context.CancelFunc {
var err error

val := random()
Expand Down Expand Up @@ -313,21 +336,26 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
}
}

acquire := func(err error, key string, ch chan struct{}) error {
acquire := func(err error, key string, ch chan struct{}, force bool) error {
select {
case <-ch:
default:
}
if err != ErrNotLocked {
err = m.acquire(ctx, key, val, deadline)
if err = m.acquire(ctx, key, val, deadline, force); force && err == nil {
select {
case ch <- struct{}{}:
default:
}
}
}
go monitoring(err, key, deadline, ch)
return err
}

var i, acquired, failures int32
for ; acquired < m.majority && failures < m.majority; i++ {
if err = acquire(err, keyname(m.prefix, name, i), g.csc[i]); err == nil {
if err = acquire(err, keyname(m.prefix, name, i), g.csc[i], force); err == nil {
acquired++
} else {
failures++
Expand All @@ -336,7 +364,7 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
if i < m.totalcnt {
go func(i int32, err error) {
for ; i < m.totalcnt; i++ {
err = acquire(err, keyname(m.prefix, name, i), g.csc[i])
err = acquire(err, keyname(m.prefix, name, i), g.csc[i], force)
}
}(i, err)
}
Expand All @@ -349,10 +377,21 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
return nil
}

func (m *locker) ForceWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
ctx, cancel := context.WithCancel(ctx)
if g := m.forcegate(name); g != nil {
if cancel := m.try(ctx, cancel, name, g, true); cancel != nil {
return ctx, cancel, nil
}
}
cancel()
return ctx, cancel, ErrNotLocked
}

func (m *locker) TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
ctx, cancel := context.WithCancel(ctx)
if g := m.trygate(name); g != nil {
if cancel := m.try(ctx, cancel, name, g); cancel != nil {
if cancel := m.try(ctx, cancel, name, g, false); cancel != nil {
return ctx, cancel, nil
}
}
Expand All @@ -365,7 +404,7 @@ func (m *locker) WithContext(ctx context.Context, name string) (context.Context,
ctx, cancel := context.WithCancel(ctx)
g, err := m.waitgate(ctx, name)
if g != nil {
if cancel := m.try(ctx, cancel, name, g); cancel != nil {
if cancel := m.try(ctx, cancel, name, g, false); cancel != nil {
return ctx, cancel, nil
}
}
Expand Down Expand Up @@ -394,6 +433,8 @@ var (
extend = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then local r = redis.call("PEXPIREAT",KEYS[1],ARGV[2]);redis.call("GET",KEYS[1]);return r end;return 0`)
acqms = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"NX","PX",ARGV[2]);redis.call("GET",KEYS[1]);return r`)
acqat = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"NX","PXAT",ARGV[2]);redis.call("GET",KEYS[1]);return r`)
fcqms = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"PX",ARGV[2]);redis.call("GET",KEYS[1]);return r`)
fcqat = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"PXAT",ARGV[2]);redis.call("GET",KEYS[1]);return r`)
)

// ErrNotLocked is returned from the Locker.TryWithContext when it fails
Expand Down
175 changes: 175 additions & 0 deletions rueidislock/lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,152 @@ func TestLocker_WithContext_UnlockByClientSideCaching(t *testing.T) {
})
}

func TestLocker_WithContext_UnlockBySelfForceWithContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx bool) {
locker := newLocker(t, noLoop, setpx, false)
locker.timeout = time.Second
defer locker.Close()
lck := strconv.Itoa(rand.Int())
ctx, cancel, err := locker.WithContext(context.Background(), lck)
if err != nil {
t.Fatal(err)
}
go func() {
_, cancel2, err2 := locker.ForceWithContext(context.Background(), lck)
if err2 != nil {
t.Errorf("unexpected err %v", err2)
return
}
cancel2()
}()
<-ctx.Done()
cancel()
if !errors.Is(ctx.Err(), context.Canceled) {
t.Fatalf("unexpected err %v", err)
}
}
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true)
})
}

func TestLocker_WithContext_UnlockByOtherForceWithContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx bool) {
locker := newLocker(t, noLoop, setpx, false)
locker.timeout = time.Second
defer locker.Close()
lck := strconv.Itoa(rand.Int())
ctx, cancel, err := locker.WithContext(context.Background(), lck)
if err != nil {
t.Fatal(err)
}
go func() {
locker2 := newLocker(t, noLoop, setpx, false)
locker2.timeout = time.Second
defer locker2.Close()
_, cancel2, err2 := locker2.ForceWithContext(context.Background(), lck)
if err2 != nil {
t.Errorf("unexpected err %v", err2)
return
}
cancel2()
}()
<-ctx.Done()
cancel()
if !errors.Is(ctx.Err(), context.Canceled) {
t.Fatalf("unexpected err %v", err)
}
}
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true)
})
}

func TestLocker_ForceWithContext_UnlockBySelfForceWithContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx bool) {
locker := newLocker(t, noLoop, setpx, false)
locker.timeout = time.Second
defer locker.Close()
lck := strconv.Itoa(rand.Int())
ctx, cancel, err := locker.ForceWithContext(context.Background(), lck)
if err != nil {
t.Fatal(err)
}
go func() {
_, cancel2, err2 := locker.ForceWithContext(context.Background(), lck)
if err2 != nil {
t.Errorf("unexpected err %v", err2)
return
}
cancel2()
}()
<-ctx.Done()
cancel()
if !errors.Is(ctx.Err(), context.Canceled) {
t.Fatalf("unexpected err %v", err)
}
}
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true)
})
}

func TestLocker_ForceWithContext_UnlockByOtherForceWithContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx bool) {
locker := newLocker(t, noLoop, setpx, false)
locker.timeout = time.Second
defer locker.Close()
lck := strconv.Itoa(rand.Int())
ctx, cancel, err := locker.ForceWithContext(context.Background(), lck)
if err != nil {
t.Fatal(err)
}
go func() {
locker2 := newLocker(t, noLoop, setpx, false)
locker2.timeout = time.Second
defer locker2.Close()
_, cancel2, err2 := locker2.ForceWithContext(context.Background(), lck)
if err2 != nil {
t.Errorf("unexpected err %v", err2)
return
}
cancel2()
}()
<-ctx.Done()
cancel()
if !errors.Is(ctx.Err(), context.Canceled) {
t.Fatalf("unexpected err %v", err)
}
}
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true)
})
}

func TestLocker_WithContext_ExtendByClientSideCaching(t *testing.T) {
test := func(t *testing.T, noLoop, setpx bool) {
locker := newLocker(t, noLoop, setpx, false)
Expand Down Expand Up @@ -329,6 +475,35 @@ func TestLocker_TryWithContext(t *testing.T) {
}
}

func TestLocker_ForceWithContextThenTryWithContext(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
locker := newLocker(t, noLoop, setpx, nocsc)
locker.timeout = time.Second
defer locker.Close()

lck := strconv.Itoa(rand.Int())
ctx, cancel, err := locker.ForceWithContext(context.Background(), lck)
if err != nil {
t.Fatal(err)
}
if _, _, err := locker.TryWithContext(ctx, lck); err != ErrNotLocked {
t.Fatal(err)
}
cancel()
}
for _, nocsc := range []bool{false, true} {
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false, nocsc)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false, nocsc)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true, nocsc)
})
}
}

func TestLocker_TryWithContext_MultipleLocker(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
lockers := make([]*locker, 10)
Expand Down

0 comments on commit 61175fe

Please sign in to comment.