From 6fefe923257e6fa8d3ad3e5f0a49e8e7906282a4 Mon Sep 17 00:00:00 2001 From: mopemoepe Date: Wed, 9 Jul 2014 21:11:54 +0900 Subject: [PATCH] Add cache middleware --- .gitignore | 2 + cache/Godeps/Godeps.json | 25 +++++ cache/cache.go | 154 +++++++++++++++++++++++++++++ cache/cache_test.go | 203 +++++++++++++++++++++++++++++++++++++++ cache/example/example.go | 25 +++++ cache/inmemory.go | 78 +++++++++++++++ cache/inmemory_test.go | 37 +++++++ cache/memcached.go | 91 ++++++++++++++++++ cache/memcached_test.go | 46 +++++++++ cache/redis.go | 183 +++++++++++++++++++++++++++++++++++ cache/redis_test.go | 48 +++++++++ cache/serializer.go | 67 +++++++++++++ 12 files changed, 959 insertions(+) create mode 100644 .gitignore create mode 100644 cache/Godeps/Godeps.json create mode 100644 cache/cache.go create mode 100644 cache/cache_test.go create mode 100644 cache/example/example.go create mode 100644 cache/inmemory.go create mode 100644 cache/inmemory_test.go create mode 100644 cache/memcached.go create mode 100644 cache/memcached_test.go create mode 100644 cache/redis.go create mode 100644 cache/redis_test.go create mode 100644 cache/serializer.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e165012 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*/Godeps/* +!*/Godeps/Godeps.json diff --git a/cache/Godeps/Godeps.json b/cache/Godeps/Godeps.json new file mode 100644 index 0000000..c2dd80b --- /dev/null +++ b/cache/Godeps/Godeps.json @@ -0,0 +1,25 @@ +{ + "ImportPath": "github.com/gin-gonic/contrib/cache", + "GoVersion": "go1.3", + "Deps": [ + { + "ImportPath": "github.com/bradfitz/gomemcache/memcache", + "Comment": "release.r60-36-g4faecad", + "Rev": "4faecadd4f695d18a912ba110120fcfd460aca98" + }, + { + "ImportPath": "github.com/garyburd/redigo/redis", + "Rev": "1c7841955920910958c71b1cc1a03f628267d468" + }, + { + "ImportPath": "github.com/gin-gonic/gin", + "Comment": "v0.2b-6-gc224bf8", + "Rev": "c224bf82111883dbe354edf9376642f615b7248e" + }, + { + "ImportPath": "github.com/robfig/go-cache", + "Comment": "go.r60.3-61-g9fc39e0", + "Rev": "9fc39e0dbf62c034ec4e45e6120fc69433a3ec51" + } + ] +} diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..e11b032 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,154 @@ +package cache + +import ( + "bytes" + "crypto/sha1" + "errors" + "github.com/gin-gonic/gin" + "net/url" + "net/http" + "io" + "time" +) + +const ( + DEFAULT = time.Duration(0) + FOREVER = time.Duration(-1) + CACHE_MIDDLEWARE_KEY = "gincontrib.cache" +) + +var ( + PageCachePrefix = "gincontrib.page.cahe" + ErrCacheMiss = errors.New("cache: key not found.") + ErrNotStored = errors.New("cache: not stored.") + ErrNotSupport = errors.New("cache: not support.") +) + +type CacheStore interface { + Get(key string, value interface{}) error + Set(key string, value interface{}, expire time.Duration) error + Add(key string, value interface{}, expire time.Duration) error + Replace(key string, data interface{}, expire time.Duration) error + Delete(key string) error + Increment(key string, data uint64) (uint64, error) + Decrement(key string, data uint64) (uint64, error) + Flush() error +} + +type responseCache struct { + status int + header http.Header + data []byte +} + +type cachedWriter struct { + gin.ResponseWriter + status int + written bool + store CacheStore + expire time.Duration + key string +} + +func urlEscape(prefix string, u string) string { + key := url.QueryEscape(u) + if len(key) > 200 { + h := sha1.New() + io.WriteString(h, u) + key = string(h.Sum(nil)) + } + var buffer bytes.Buffer + buffer.WriteString(prefix) + buffer.WriteString(":") + buffer.WriteString(key) + return buffer.String() +} + +func newCachedWriter(store CacheStore, expire time.Duration, writer gin.ResponseWriter, key string) *cachedWriter { + return &cachedWriter{writer, 0, false, store, expire, key} +} + +func (w *cachedWriter) WriteHeader(code int) { + w.status = code + w.written = true + w.ResponseWriter.WriteHeader(code) +} + +func (w *cachedWriter) Status() int { + return w.status +} + +func (w *cachedWriter) Written() bool { + return w.written +} + +func (w *cachedWriter) Write(data []byte) (int, error) { + ret, err := w.ResponseWriter.Write(data) + if err == nil { + //cache response + store := w.store + val := responseCache{ + w.status, + w.Header(), + data, + } + err = store.Set(w.key, val, w.expire) + if err != nil { + // need logger + } + } + return ret, err +} + +// Cache Middleware +func Cache(store *CacheStore) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(CACHE_MIDDLEWARE_KEY, store) + c.Next() + } +} + +func SiteCache(store CacheStore, expire time.Duration) gin.HandlerFunc { + + return func(c *gin.Context) { + var cache responseCache + url := c.Req.URL + key := urlEscape(PageCachePrefix, url.RequestURI()) + if err := store.Get(key, &cache); err != nil { + c.Next() + } else { + c.Writer.WriteHeader(cache.status) + for k, vals := range cache.header { + for _, v := range vals { + c.Writer.Header().Add(k, v) + } + } + c.Writer.Write(cache.data) + } + } +} + +// Cache Decorator +func CachePage(store CacheStore, expire time.Duration, handle gin.HandlerFunc) gin.HandlerFunc { + + return func(c *gin.Context) { + var cache responseCache + url := c.Req.URL + key := urlEscape(PageCachePrefix, url.RequestURI()) + if err := store.Get(key, &cache); err != nil { + // replace writer + writer := newCachedWriter(store, expire, c.Writer, key) + c.Writer = writer + handle(c) + } else { + c.Writer.WriteHeader(cache.status) + for k, vals := range cache.header { + for _, v := range vals { + c.Writer.Header().Add(k, v) + } + } + c.Writer.Write(cache.data) + } + } +} + diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..dc8d47b --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,203 @@ +package cache + +import ( + "math" + "testing" + "time" +) + +type cacheFactory func(*testing.T, time.Duration) CacheStore + +// Test typical cache interactions +func typicalGetSet(t *testing.T, newCache cacheFactory) { + var err error + cache := newCache(t, time.Hour) + + value := "foo" + if err = cache.Set("value", value, DEFAULT); err != nil { + t.Errorf("Error setting a value: %s", err) + } + + value = "" + err = cache.Get("value", &value) + if err != nil { + t.Errorf("Error getting a value: %s", err) + } + if value != "foo" { + t.Errorf("Expected to get foo back, got %s", value) + } +} + +// Test the increment-decrement cases +func incrDecr(t *testing.T, newCache cacheFactory) { + var err error + cache := newCache(t, time.Hour) + + // Normal increment / decrement operation. + if err = cache.Set("int", 10, DEFAULT); err != nil { + t.Errorf("Error setting int: %s", err) + } + newValue, err := cache.Increment("int", 50) + if err != nil { + t.Errorf("Error incrementing int: %s", err) + } + if newValue != 60 { + t.Errorf("Expected 60, was %d", newValue) + } + + if newValue, err = cache.Decrement("int", 50); err != nil { + t.Errorf("Error decrementing: %s", err) + } + if newValue != 10 { + t.Errorf("Expected 10, was %d", newValue) + } + + // Increment wraparound + newValue, err = cache.Increment("int", math.MaxUint64-5) + if err != nil { + t.Errorf("Error wrapping around: %s", err) + } + if newValue != 4 { + t.Errorf("Expected wraparound 4, got %d", newValue) + } + + // Decrement capped at 0 + newValue, err = cache.Decrement("int", 25) + if err != nil { + t.Errorf("Error decrementing below 0: %s", err) + } + if newValue != 0 { + t.Errorf("Expected capped at 0, got %d", newValue) + } +} + +func expiration(t *testing.T, newCache cacheFactory) { + // memcached does not support expiration times less than 1 second. + var err error + cache := newCache(t, time.Second) + // Test Set w/ DEFAULT + value := 10 + cache.Set("int", value, DEFAULT) + time.Sleep(2 * time.Second) + err = cache.Get("int", &value) + if err != ErrCacheMiss { + t.Errorf("Expected CacheMiss, but got: %s", err) + } + + // Test Set w/ short time + cache.Set("int", value, time.Second) + time.Sleep(2 * time.Second) + err = cache.Get("int", &value) + if err != ErrCacheMiss { + t.Errorf("Expected CacheMiss, but got: %s", err) + } + + // Test Set w/ longer time. + cache.Set("int", value, time.Hour) + time.Sleep(2 * time.Second) + err = cache.Get("int", &value) + if err != nil { + t.Errorf("Expected to get the value, but got: %s", err) + } + + // Test Set w/ forever. + cache.Set("int", value, FOREVER) + time.Sleep(2 * time.Second) + err = cache.Get("int", &value) + if err != nil { + t.Errorf("Expected to get the value, but got: %s", err) + } +} + +func emptyCache(t *testing.T, newCache cacheFactory) { + var err error + cache := newCache(t, time.Hour) + + err = cache.Get("notexist", 0) + if err == nil { + t.Errorf("Error expected for non-existent key") + } + if err != ErrCacheMiss { + t.Errorf("Expected ErrCacheMiss for non-existent key: %s", err) + } + + err = cache.Delete("notexist") + if err != ErrCacheMiss { + t.Errorf("Expected ErrCacheMiss for non-existent key: %s", err) + } + + _, err = cache.Increment("notexist", 1) + if err != ErrCacheMiss { + t.Errorf("Expected cache miss incrementing non-existent key: %s", err) + } + + _, err = cache.Decrement("notexist", 1) + if err != ErrCacheMiss { + t.Errorf("Expected cache miss decrementing non-existent key: %s", err) + } +} + +func testReplace(t *testing.T, newCache cacheFactory) { + var err error + cache := newCache(t, time.Hour) + + // Replace in an empty cache. + if err = cache.Replace("notexist", 1, FOREVER); err != ErrNotStored { + t.Errorf("Replace in empty cache: expected ErrNotStored, got: %s", err) + } + + // Set a value of 1, and replace it with 2 + if err = cache.Set("int", 1, time.Second); err != nil { + t.Errorf("Unexpected error: %s", err) + } + + if err = cache.Replace("int", 2, time.Second); err != nil { + t.Errorf("Unexpected error: %s", err) + } + var i int + if err = cache.Get("int", &i); err != nil { + t.Errorf("Unexpected error getting a replaced item: %s", err) + } + if i != 2 { + t.Errorf("Expected 2, got %d", i) + } + + // Wait for it to expire and replace with 3 (unsuccessfully). + time.Sleep(2 * time.Second) + if err = cache.Replace("int", 3, time.Second); err != ErrNotStored { + t.Errorf("Expected ErrNotStored, got: %s", err) + } + if err = cache.Get("int", &i); err != ErrCacheMiss { + t.Errorf("Expected cache miss, got: %s", err) + } +} + +func testAdd(t *testing.T, newCache cacheFactory) { + var err error + cache := newCache(t, time.Hour) + // Add to an empty cache. + if err = cache.Add("int", 1, time.Second); err != nil { + t.Errorf("Unexpected error adding to empty cache: %s", err) + } + + // Try to add again. (fail) + if err = cache.Add("int", 2, time.Second); err != ErrNotStored { + t.Errorf("Expected ErrNotStored adding dupe to cache: %s", err) + } + + // Wait for it to expire, and add again. + time.Sleep(2 * time.Second) + if err = cache.Add("int", 3, time.Second); err != nil { + t.Errorf("Unexpected error adding to cache: %s", err) + } + + // Get and verify the value. + var i int + if err = cache.Get("int", &i); err != nil { + t.Errorf("Unexpected error: %s", err) + } + if i != 3 { + t.Errorf("Expected 3, got: %d", i) + } +} + diff --git a/cache/example/example.go b/cache/example/example.go new file mode 100644 index 0000000..0944f53 --- /dev/null +++ b/cache/example/example.go @@ -0,0 +1,25 @@ +package main + +import ( + "github.com/gin-gonic/gin" + "github.com/gin-gonic/contrib/cache" + "time" +) + + +func main() { + r := gin.Default() + + store := cache.NewInMemoryStore(time.Second) + // Cached Page + r.GET("/ping", func(c *gin.Context) { + c.String(200, "pong") + }) + + r.GET("/cache_ping", cache.CachePage(store, time.Minute, func(c *gin.Context) { + c.String(200, "pong") + })) + + // Listen and Server in 0.0.0.0:8080 + r.Run(":8080") +} diff --git a/cache/inmemory.go b/cache/inmemory.go new file mode 100644 index 0000000..ede5ac0 --- /dev/null +++ b/cache/inmemory.go @@ -0,0 +1,78 @@ +package cache + +import ( + "github.com/robfig/go-cache" + "reflect" + "time" +) + +type InMemoryStore struct { + cache.Cache +} + +func NewInMemoryStore(defaultExpiration time.Duration) *InMemoryStore { + return &InMemoryStore{*cache.New(defaultExpiration, time.Minute)} +} + +func (c *InMemoryStore) Get(key string, value interface{}) error { + val, found := c.Cache.Get(key) + if !found { + return ErrCacheMiss + } + + v := reflect.ValueOf(value) + if v.Type().Kind() == reflect.Ptr && v.Elem().CanSet() { + v.Elem().Set(reflect.ValueOf(val)) + return nil + } + return ErrNotStored +} + +func (c *InMemoryStore) Set(key string, value interface{}, expires time.Duration) error { + // NOTE: go-cache understands the values of DEFAULT and FOREVER + c.Cache.Set(key, value, expires) + return nil +} + +func (c *InMemoryStore) Add(key string, value interface{}, expires time.Duration) error { + err := c.Cache.Add(key, value, expires) + if err == cache.ErrKeyExists { + return ErrNotStored + } + return err +} + +func (c *InMemoryStore) Replace(key string, value interface{}, expires time.Duration) error { + if err := c.Cache.Replace(key, value, expires); err != nil { + return ErrNotStored + } + return nil +} + +func (c *InMemoryStore) Delete(key string) error { + if found := c.Cache.Delete(key); !found { + return ErrCacheMiss + } + return nil +} + +func (c *InMemoryStore) Increment(key string, n uint64) (uint64, error) { + newValue, err := c.Cache.Increment(key, n) + if err == cache.ErrCacheMiss { + return 0, ErrCacheMiss + } + return newValue, err +} + +func (c *InMemoryStore) Decrement(key string, n uint64) (uint64, error) { + newValue, err := c.Cache.Decrement(key, n) + if err == cache.ErrCacheMiss { + return 0, ErrCacheMiss + } + return newValue, err +} + +func (c *InMemoryStore) Flush() error { + c.Cache.Flush() + return nil +} diff --git a/cache/inmemory_test.go b/cache/inmemory_test.go new file mode 100644 index 0000000..a590d13 --- /dev/null +++ b/cache/inmemory_test.go @@ -0,0 +1,37 @@ +package cache + +import ( + "testing" + "time" +) + +var newInMemoryStore = func(_ *testing.T, defaultExpiration time.Duration) CacheStore { + return NewInMemoryStore(defaultExpiration) +} + +// Test typical cache interactions +func TestInMemoryCache_TypicalGetSet(t *testing.T) { + typicalGetSet(t, newInMemoryStore) +} + + +func TestInMemoryCache_IncrDecr(t *testing.T) { + incrDecr(t, newInMemoryStore) +} + +func TestInMemoryCache_Expiration(t *testing.T) { + expiration(t, newInMemoryStore) +} + +func TestInMemoryCache_EmptyCache(t *testing.T) { + emptyCache(t, newInMemoryStore) +} + +func TestInMemoryCache_Replace(t *testing.T) { + testReplace(t, newInMemoryStore) +} + +func TestInMemoryCache_Add(t *testing.T) { + testAdd(t, newInMemoryStore) +} + diff --git a/cache/memcached.go b/cache/memcached.go new file mode 100644 index 0000000..696626d --- /dev/null +++ b/cache/memcached.go @@ -0,0 +1,91 @@ +package cache + +import ( + "github.com/bradfitz/gomemcache/memcache" + "time" +) + +type MemcachedStore struct { + *memcache.Client + defaultExpiration time.Duration +} + +func NewMemcachedStore(hostList []string, defaultExpiration time.Duration) *MemcachedStore { + return &MemcachedStore{memcache.New(hostList...), defaultExpiration} +} + +func (c *MemcachedStore) Set(key string, value interface{}, expires time.Duration) error { + return c.invoke((*memcache.Client).Set, key, value, expires) +} + +func (c *MemcachedStore) Add(key string, value interface{}, expires time.Duration) error { + return c.invoke((*memcache.Client).Add, key, value, expires) +} + +func (c *MemcachedStore) Replace(key string, value interface{}, expires time.Duration) error { + return c.invoke((*memcache.Client).Replace, key, value, expires) +} + +func (c *MemcachedStore) Get(key string, value interface{}) error { + item, err := c.Client.Get(key) + if err != nil { + return convertMemcacheError(err) + } + return deserialize(item.Value, value) +} + +func (c *MemcachedStore) Delete(key string) error { + return convertMemcacheError(c.Client.Delete(key)) +} + +func (c *MemcachedStore) Increment(key string, delta uint64) (uint64, error) { + newValue, err := c.Client.Increment(key, delta) + return newValue, convertMemcacheError(err) +} + +func (c *MemcachedStore) Decrement(key string, delta uint64) (uint64, error) { + newValue, err := c.Client.Decrement(key, delta) + return newValue, convertMemcacheError(err) +} + +func (c *MemcachedStore) Flush() error { + return ErrNotSupport +} + +func (c *MemcachedStore) invoke(storeFn func(*memcache.Client, *memcache.Item) error, + key string, value interface{}, expire time.Duration) error { + + switch expire { + case DEFAULT: + expire = c.defaultExpiration + case FOREVER: + expire = time.Duration(0) + } + + b, err := serialize(value) + if err != nil { + return err + } + return convertMemcacheError(storeFn(c.Client, &memcache.Item{ + Key: key, + Value: b, + Expiration: int32(expire / time.Second), + })) +} + +func convertMemcacheError(err error) error { + switch err { + case nil: + return nil + case memcache.ErrCacheMiss: + return ErrCacheMiss + case memcache.ErrNotStored: + return ErrNotStored + } + + return err +} + + + + diff --git a/cache/memcached_test.go b/cache/memcached_test.go new file mode 100644 index 0000000..018bd3e --- /dev/null +++ b/cache/memcached_test.go @@ -0,0 +1,46 @@ +package cache + +import ( + "net" + "testing" + "time" +) + +// These tests require memcached running on localhost:11211 (the default) +const testServer = "localhost:11211" + +var newMemcachedStore = func(t *testing.T, defaultExpiration time.Duration) CacheStore { + c, err := net.Dial("tcp", testServer) + if err == nil { + c.Write([]byte("flush_all\r\n")) + c.Close() + return NewMemcachedStore([]string{testServer}, defaultExpiration) + } + t.Errorf("couldn't connect to memcached on %s", testServer) + t.FailNow() + panic("") +} + +func TestMemcachedCache_TypicalGetSet(t *testing.T) { + typicalGetSet(t, newMemcachedStore) +} + +func TestMemcachedCache_IncrDecr(t *testing.T) { + incrDecr(t, newMemcachedStore) +} + +func TestMemcachedCache_Expiration(t *testing.T) { + expiration(t, newMemcachedStore) +} + +func TestMemcachedCache_EmptyCache(t *testing.T) { + emptyCache(t, newMemcachedStore) +} + +func TestMemcachedCache_Replace(t *testing.T) { + testReplace(t, newMemcachedStore) +} + +func TestMemcachedCache_Add(t *testing.T) { + testAdd(t, newMemcachedStore) +} diff --git a/cache/redis.go b/cache/redis.go new file mode 100644 index 0000000..8b44078 --- /dev/null +++ b/cache/redis.go @@ -0,0 +1,183 @@ +package cache + +import ( + "github.com/garyburd/redigo/redis" + "time" +) + +// Wraps the Redis client to meet the Cache interface. +type RedisStore struct { + pool *redis.Pool + defaultExpiration time.Duration +} + +// until redigo supports sharding/clustering, only one host will be in hostList +func NewRedisCache(host string, password string, defaultExpiration time.Duration) *RedisStore { + var pool = &redis.Pool{ + MaxIdle: 5, + IdleTimeout: 240 * time.Second, + Dial: func() (redis.Conn, error) { + // the redis protocol should probably be made sett-able + c, err := redis.Dial("tcp", host) + if err != nil { + return nil, err + } + if len(password) > 0 { + if _, err := c.Do("AUTH", password); err != nil { + c.Close() + return nil, err + } + } else { + // check with PING + if _, err := c.Do("PING"); err != nil { + c.Close() + return nil, err + } + } + return c, err + }, + // custom connection test method + TestOnBorrow: func(c redis.Conn, t time.Time) error { + if _, err := c.Do("PING"); err != nil { + return err + } + return nil + }, + } + return &RedisStore{pool, defaultExpiration} +} + +func (c *RedisStore) Set(key string, value interface{}, expires time.Duration) error { + return c.invoke(c.pool.Get().Do, key, value, expires) +} + +func (c *RedisStore) Add(key string, value interface{}, expires time.Duration) error { + conn := c.pool.Get() + if exists(conn, key) { + return ErrNotStored + } + return c.invoke(conn.Do, key, value, expires) +} + +func (c *RedisStore) Replace(key string, value interface{}, expires time.Duration) error { + conn := c.pool.Get() + if !exists(conn, key) { + return ErrNotStored + } + err := c.invoke(conn.Do, key, value, expires) + if value == nil { + return ErrNotStored + } else { + return err + } +} + +func (c *RedisStore) Get(key string, ptrValue interface{}) error { + conn := c.pool.Get() + defer conn.Close() + raw, err := conn.Do("GET", key) + if raw == nil { + return ErrCacheMiss + } + item, err := redis.Bytes(raw, err) + if err != nil { + return err + } + return deserialize(item, ptrValue) +} + + +func exists(conn redis.Conn, key string) bool { + retval, _ := redis.Bool(conn.Do("EXISTS", key)) + return retval +} + +func (c *RedisStore) Delete(key string) error { + conn := c.pool.Get() + defer conn.Close() + if !exists(conn, key) { + return ErrCacheMiss + } + _, err := conn.Do("DEL", key) + return err +} + +func (c *RedisStore) Increment(key string, delta uint64) (uint64, error) { + conn := c.pool.Get() + defer conn.Close() + // Check for existance *before* increment as per the cache contract. + // redis will auto create the key, and we don't want that. Since we need to do increment + // ourselves instead of natively via INCRBY (redis doesn't support wrapping), we get the value + // and do the exists check this way to minimize calls to Redis + val, err := conn.Do("GET", key) + if val == nil { + return 0, ErrCacheMiss + } + if err == nil { + currentVal, err := redis.Int64(val, nil) + if err != nil { + return 0, err + } + var sum int64 = currentVal + int64(delta) + _, err = conn.Do("SET", key, sum) + if err != nil { + return 0, err + } + return uint64(sum), nil + } else { + return 0, err + } +} + +func (c *RedisStore) Decrement(key string, delta uint64) (newValue uint64, err error) { + conn := c.pool.Get() + defer conn.Close() + // Check for existance *before* increment as per the cache contract. + // redis will auto create the key, and we don't want that, hence the exists call + if !exists(conn, key) { + return 0, ErrCacheMiss + } + // Decrement contract says you can only go to 0 + // so we go fetch the value and if the delta is greater than the amount, + // 0 out the value + currentVal, err := redis.Int64(conn.Do("GET", key)) + if err == nil && delta > uint64(currentVal) { + tempint, err := redis.Int64(conn.Do("DECRBY", key, currentVal)) + return uint64(tempint), err + } + tempint, err := redis.Int64(conn.Do("DECRBY", key, delta)) + return uint64(tempint), err +} + +func (c *RedisStore) Flush() error { + conn := c.pool.Get() + defer conn.Close() + _, err := conn.Do("FLUSHALL") + return err +} + +func (c *RedisStore) invoke(f func(string, ...interface{}) (interface{}, error), + key string, value interface{}, expires time.Duration) error { + + switch expires { + case DEFAULT: + expires = c.defaultExpiration + case FOREVER: + expires = time.Duration(0) + } + + b, err := serialize(value) + if err != nil { + return err + } + conn := c.pool.Get() + defer conn.Close() + if expires > 0 { + _, err := f("SETEX", key, int32(expires/time.Second), b) + return err + } else { + _, err := f("SET", key, b) + return err + } +} + diff --git a/cache/redis_test.go b/cache/redis_test.go new file mode 100644 index 0000000..a01974c --- /dev/null +++ b/cache/redis_test.go @@ -0,0 +1,48 @@ +package cache + +import ( + "net" + "testing" + "time" +) + +// These tests require redis server running on localhost:6379 (the default) +const redisTestServer = "localhost:6379" + +var newRedisStore = func(t *testing.T, defaultExpiration time.Duration) CacheStore { + c, err := net.Dial("tcp", redisTestServer) + if err == nil { + c.Write([]byte("flush_all\r\n")) + c.Close() + redisCache := NewRedisCache(redisTestServer, "", defaultExpiration) + redisCache.Flush() + return redisCache + } + t.Errorf("couldn't connect to redis on %s", redisTestServer) + t.FailNow() + panic("") +} + +func TestRedisCache_TypicalGetSet(t *testing.T) { + typicalGetSet(t, newRedisStore) +} + +func TestRedisCache_IncrDecr(t *testing.T) { + incrDecr(t, newRedisStore) +} + +func TestRedisCache_Expiration(t *testing.T) { + expiration(t, newRedisStore) +} + +func TestRedisCache_EmptyCache(t *testing.T) { + emptyCache(t, newRedisStore) +} + +func TestRedisCache_Replace(t *testing.T) { + testReplace(t, newRedisStore) +} + +func TestRedisCache_Add(t *testing.T) { + testAdd(t, newRedisStore) +} diff --git a/cache/serializer.go b/cache/serializer.go new file mode 100644 index 0000000..958845b --- /dev/null +++ b/cache/serializer.go @@ -0,0 +1,67 @@ +package cache + +import ( + "bytes" + "encoding/gob" + "reflect" + "strconv" +) + + +func serialize(value interface{}) ([]byte, error) { + if bytes, ok := value.([]byte); ok { + return bytes, nil + } + + switch v := reflect.ValueOf(value); v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return []byte(strconv.FormatInt(v.Int(), 10)), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return []byte(strconv.FormatUint(v.Uint(), 10)), nil + } + + var b bytes.Buffer + encoder := gob.NewEncoder(&b) + if err := encoder.Encode(value); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func deserialize(byt []byte, ptr interface{}) (err error) { + if bytes, ok := ptr.(*[]byte); ok { + *bytes = byt + return nil + } + + if v := reflect.ValueOf(ptr); v.Kind() == reflect.Ptr { + switch p := v.Elem(); p.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var i int64 + i, err = strconv.ParseInt(string(byt), 10, 64) + if err != nil { + return err + } else { + p.SetInt(i) + } + return nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var i uint64 + i, err = strconv.ParseUint(string(byt), 10, 64) + if err != nil { + return err + } else { + p.SetUint(i) + } + return nil + } + } + + b := bytes.NewBuffer(byt) + decoder := gob.NewDecoder(b) + if err = decoder.Decode(ptr); err != nil { + return err + } + return nil +}