From acdc0e85d9cfb43d4f2758d9559a7d5f6591efa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?okhowang=28=E7=8E=8B=E6=B2=9B=E6=96=87=29?= Date: Fri, 8 Nov 2024 11:59:00 +0800 Subject: [PATCH] fix(loadable): cache value in setChannel Fixes #251 --- lib/cache/loadable.go | 6 ++++++ lib/cache/loadable_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/lib/cache/loadable.go b/lib/cache/loadable.go index e2e642eb..efd81e39 100644 --- a/lib/cache/loadable.go +++ b/lib/cache/loadable.go @@ -28,6 +28,7 @@ type LoadableCache[T any] struct { loadFunc LoadFunction[T] cache CacheInterface[T] setChannel chan *loadableKeyValue[T] + setCache sync.Map setterWg *sync.WaitGroup } @@ -55,6 +56,7 @@ func (c *LoadableCache[T]) setter() { cacheKey := c.getCacheKey(item.key) c.singleFlight.Forget(cacheKey) + c.setCache.Delete(cacheKey) } } @@ -69,6 +71,9 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) { // Unable to find in cache, try to load it from load function cacheKey := c.getCacheKey(key) + if v, ok := c.setCache.Load(cacheKey); ok { + return v.(T), nil + } zero := *new(T) loadedResult, err, _ := c.singleFlight.Do( @@ -89,6 +94,7 @@ func (c *LoadableCache[T]) Get(ctx context.Context, key any) (T, error) { } // Then, put it back in cache + c.setCache.Store(cacheKey, object) c.setChannel <- &loadableKeyValue[T]{key, object} return object, err diff --git a/lib/cache/loadable_test.go b/lib/cache/loadable_test.go index 26bd0d6b..a909e106 100644 --- a/lib/cache/loadable_test.go +++ b/lib/cache/loadable_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) @@ -299,3 +300,29 @@ func TestLoadableGetType(t *testing.T) { // When - Then assert.Equal(t, LoadableType, cache.GetType()) } + +func TestLoadableGetTwice(t *testing.T) { + // Given + ctrl := gomock.NewController(t) + + cache1 := NewMockSetterCacheInterface[any](ctrl) + + var counter atomic.Uint64 + loadFunc := func(_ context.Context, key any) (any, error) { + return counter.Add(1), nil + } + + cache := NewLoadable[any](loadFunc, cache1) + + key := 1 + cache1.EXPECT().Get(context.Background(), key).Return(nil, store.NotFound{}).Times(2) + cache1.EXPECT().Set(context.Background(), key, uint64(1)).Times(1) + v1, err1 := cache.Get(context.Background(), key) + v2, err2 := cache.Get(context.Background(), key) // setter may not be called now because it's done by another goroutine + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.Equal(t, uint64(1), v1) + assert.Equal(t, uint64(1), v2) + assert.Equal(t, uint64(1), counter.Load()) + _ = cache.Close() // wait for setter +}