[DOC]: Correct Redis Rate Limiting Lua Script and Improve Documentation #1736

Open
avevotsira opened this issue Oct 31, 2024 · 2 comments

avevotsira commented Oct 31, 2024

Script from https://lucia-auth.com/rate-limit/token-bucket:

-- Returns 1 if allowed, 0 if not
local key                   = KEYS[1]
local max                   = tonumber(ARGV[1])
local refillIntervalSeconds = tonumber(ARGV[2])
local cost                  = tonumber(ARGV[3])
local now                   = tonumber(ARGV[4]) -- Current unix time in seconds

local fields = redis.call("HGETALL", key)
if #fields == 0 then
    redis.call("HSET", key, "count", max - cost, "refilled_at", now)
    return {1}
end
local count = 0
local refilledAt = 0
for i = 1, #fields, 2 do
    if fields[i] == "count" then
        count = tonumber(fields[i+1])
    elseif fields[i] == "refilled_at" then
        refilledAt = tonumber(fields[i+1])
    end
end
local refill = math.floor((now - refilledAt) / refillIntervalSeconds)
count = math.min(count + refill, max)
refilledAt = now
if count < cost then
    return {0}
end
count = count - cost
redis.call("HSET", key, "count", count, "refilled_at", now)
return {1}

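refilledAt should advance only by the whole intervals that were actually refilled, so the leftover time carries over to the next call (and the HSET should write that value instead of now):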

refilledAt = refilledAt + refill * refillIntervalSeconds

and not

refilledAt = now

As I understand it, this is what happens when we use now (refill interval of 60 seconds):

t=0: refilledAt=0
t=90: should refill 1 interval and set refilledAt=60 (but the code sets it to 90)
t=150: should refill 1 interval and set refilledAt=120 (but the code sets it to 150)

Each call throws away the time left over from a partial interval (30 seconds here), so the bucket refills more slowly than one token per interval.
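To make the drift concrete, here is a small TypeScript simulation of only the refill bookkeeping, comparing refilledAt = now with advancing refilledAt by whole intervals. It is a sketch, not Lucia code: refillsGranted and its parameters are hypothetical, it assumes every request is allowed (so the script writes refilled_at on each call), and it ignores cost and the cap at max.

```typescript
// Hypothetical simulation of the two refilled_at update rules discussed above.
// Assumes every request is allowed, so refilled_at is written on every call.
type Mode = "now" | "advance";

function refillsGranted(requestTimes: number[], intervalSeconds: number, mode: Mode): number {
  const [first, ...rest] = requestTimes;
  if (first === undefined) return 0;
  let refilledAt = first; // bucket is created at the first request
  let total = 0;
  for (const now of rest) {
    const refill = Math.floor((now - refilledAt) / intervalSeconds);
    total += refill;
    // "now": original script; "advance": proposed fix that keeps the remainder
    refilledAt = mode === "now" ? now : refilledAt + refill * intervalSeconds;
  }
  return total;
}
```

With requests at t=0, 90, 150 and 200 and a 60-second interval, the advancing version grants 3 refills and the now version only 2. With a request every 30 seconds, the now version never accumulates a refill as long as those calls keep being allowed, because each one resets refilled_at before a full interval has passed.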

Also, shouldn't we set an expiration on the keys (EXPIRE) so that unused buckets don't stay in Redis forever?

I think the final script should look something like this:

-- Returns 1 if allowed, 0 if not
local key                   = KEYS[1]
local max                   = tonumber(ARGV[1])
local refillIntervalSeconds = tonumber(ARGV[2])
local cost                  = tonumber(ARGV[3])
local now                   = tonumber(ARGV[4]) -- Current unix time in seconds
local ttlSeconds            = tonumber(ARGV[5]) -- Key TTL in seconds, refreshed on every allowed request

local fields = redis.call("HGETALL", key)
if #fields == 0 then
    redis.call("HSET", key, "count", max - cost, "refilled_at", now)
    redis.call("EXPIRE", key, ttlSeconds) 
    return {1}
end

local count = 0
local refilledAt = 0
for i = 1, #fields, 2 do
    if fields[i] == "count" then
        count = tonumber(fields[i+1])
    elseif fields[i] == "refilled_at" then
        refilledAt = tonumber(fields[i+1])
    end
end

local refill = math.floor((now - refilledAt) / refillIntervalSeconds)
count = math.min(count + refill, max)
refilledAt = refilledAt + (refill * refillIntervalSeconds)

if count < cost then
    return {0}
end

count = count - cost
redis.call("HSET", key, "count", count, "refilled_at", refilledAt)
redis.call("EXPIRE", key, ttlSeconds) 
return {1}
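For anyone who wants to check the corrected logic without a Redis server, the following is a rough in-memory TypeScript port of the script above. The class name InMemoryTokenBucketStore and the expiresAt field are hypothetical; the Map stands in for the Redis hash, and expiresAt only approximates EXPIRE (an expired entry is treated as missing).

```typescript
// Hypothetical in-memory port of the corrected token-bucket script (times in unix seconds).
interface Bucket {
  count: number;
  refilledAt: number;
  expiresAt: number; // approximates EXPIRE key ttlSeconds
}

class InMemoryTokenBucketStore {
  private buckets = new Map<string, Bucket>();

  consume(
    key: string,
    max: number,
    refillIntervalSeconds: number,
    cost: number,
    now: number,
    ttlSeconds: number,
  ): boolean {
    let bucket = this.buckets.get(key);
    if (bucket !== undefined && now >= bucket.expiresAt) {
      // Treat an expired entry like a key Redis has already deleted
      this.buckets.delete(key);
      bucket = undefined;
    }
    if (bucket === undefined) {
      // Same as the #fields == 0 branch: create the bucket and spend `cost`
      this.buckets.set(key, { count: max - cost, refilledAt: now, expiresAt: now + ttlSeconds });
      return true;
    }
    const refill = Math.floor((now - bucket.refilledAt) / refillIntervalSeconds);
    const count = Math.min(bucket.count + refill, max);
    // Advance by whole intervals only, keeping the leftover time
    const refilledAt = bucket.refilledAt + refill * refillIntervalSeconds;
    if (count < cost) {
      return false; // like the script, nothing is written on a denied request
    }
    this.buckets.set(key, { count: count - cost, refilledAt, expiresAt: now + ttlSeconds });
    return true;
  }
}
```

With max 2 and a 60-second interval, requests at t=0, 10, 20, 90 and 130 give allowed, allowed, denied, allowed, allowed: refilledAt advances to 60 and then 120, whereas with refilledAt = now the t=130 request would be denied.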

Lastly, while this isn't Redis documentation, I think more instructions for it (initializing the client, loading the script) would help newcomers like me, since there is currently no Redis rate-limiting example to follow.

For example, here is how I do mine:

// redis.ts
import { env } from "@/env";

import { createClient } from "@redis/client";

const client = createClient({
  url: env.REDIS_URL,
});

client.on("error", (err) => {
  console.error("❌ Redis Client Error:", err);
});

client.on("connect", () => {
  console.log("🔄 Redis client attempting to connect...");
});

process.on("SIGINT", async () => {
  console.log("🛑 Shutting down Redis connection...");
  await client.quit();
  process.exit();
});

client
  .connect()
  .then(() => {
    console.log("✅ Redis client connected");
  })
  .catch((err) => {
    console.error("❌ Redis client failed to connect:", err);
  });

export { client };

and script :

//script.ts
import { client } from "./redis";

export const tokenBucketScript = `
-- Returns 1 if allowed, 0 if not
local key                   = KEYS[1]
local max                   = tonumber(ARGV[1])
local refillIntervalSeconds = tonumber(ARGV[2])
local cost                  = tonumber(ARGV[3])
local now                   = tonumber(ARGV[4]) 
local ttlSeconds           = tonumber(ARGV[5])

local fields = redis.call("HGETALL", key)
if #fields == 0 then
    redis.call("HSET", key, "count", max - cost, "refilled_at", now)
    redis.call("EXPIRE", key, ttlSeconds) 
    return {1}
end

local count = 0
local refilledAt = 0
for i = 1, #fields, 2 do
    if fields[i] == "count" then
        count = tonumber(fields[i+1])
    elseif fields[i] == "refilled_at" then
        refilledAt = tonumber(fields[i+1])
    end
end

local refill = math.floor((now - refilledAt) / refillIntervalSeconds)
count = math.min(count + refill, max)
refilledAt = refilledAt + (refill * refillIntervalSeconds)

if count < cost then
    return {0}
end

count = count - cost
redis.call("HSET", key, "count", count, "refilled_at", refilledAt)
redis.call("EXPIRE", key, ttlSeconds) 
return {1}
`;

// Top-level await: requires an ES module setup that supports it.
export const TOKEN_BUCKET_SHA: Readonly<string> =
  await client.scriptLoad(tokenBucketScript);

which I then call from a TokenBucket class:

//token-bucket.ts
import { client } from "./redis";
import { TOKEN_BUCKET_SHA } from "./script";

export class TokenBucket {
  private storageKey: string;
  public max: number;
  public refillIntervalSeconds: number;
  private ttlSeconds: number;

  constructor(
    storageKey: string,
    max: number,
    refillIntervalSeconds: number,
    ttlSeconds = 86400,
  ) {
    this.storageKey = storageKey;
    this.max = max;
    this.refillIntervalSeconds = refillIntervalSeconds;
    this.ttlSeconds = ttlSeconds;
  }

  public async consume(key: string, cost: number): Promise<boolean> {
    const result = (await client.evalSha(TOKEN_BUCKET_SHA, {
      keys: [`${this.storageKey}:${key}`],
      arguments: [
        this.max.toString(),
        this.refillIntervalSeconds.toString(),
        cost.toString(),
        Math.floor(Date.now() / 1000).toString(),
        this.ttlSeconds.toString(),
      ],
    })) as number[];

    return Boolean(result[0]);
  }
}
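One thing worth covering in such docs: EVALSHA fails with a NOSCRIPT error when Redis no longer has the script cached (for example after a restart or SCRIPT FLUSH), while TOKEN_BUCKET_SHA is only loaded once at startup. A possible fix is to reload the script and retry once. The sketch below is written against a minimal, hypothetical ScriptClient interface covering the two node-redis calls used above (evalSha with { keys, arguments } and scriptLoad), so it can be exercised with a fake; passing the real client may need a cast.

```typescript
// Hypothetical minimal subset of the node-redis client used in this thread.
interface ScriptOptions {
  keys: string[];
  arguments: string[];
}

interface ScriptClient {
  evalSha(sha1: string, options: ScriptOptions): Promise<unknown>;
  scriptLoad(script: string): Promise<string>;
}

// Hypothetical helper: runs a cached script by SHA and, if Redis replies with an
// error whose message starts with NOSCRIPT, loads the script again and retries once.
async function evalShaWithReload(
  client: ScriptClient,
  script: string,
  sha: string,
  options: ScriptOptions,
): Promise<unknown> {
  try {
    return await client.evalSha(sha, options);
  } catch (err) {
    if (err instanceof Error && err.message.startsWith("NOSCRIPT")) {
      const reloadedSha = await client.scriptLoad(script);
      return client.evalSha(reloadedSha, options);
    }
    throw err; // any other error is passed through unchanged
  }
}
```

In consume(), the call to client.evalSha(TOKEN_BUCKET_SHA, {...}) would become evalShaWithReload(client, tokenBucketScript, TOKEN_BUCKET_SHA, {...}).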

I'm willing to open a PR for this.

@pilcrowonpaper
Member

With expiration, I believe it should look like this?

local refill = math.floor((now - refilledAt) / refillIntervalSeconds)
count = math.min(count + refill, max)
if count < cost then
    return {0}
end

count = count - cost
refilledAt = refilledAt + (refill * refillIntervalSeconds)
local expiresAt = refilledAt + (max - count) * refillIntervalSeconds
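The expiresAt line can be checked with a little arithmetic: after consuming, the bucket holds count tokens and gains one per interval measured from refilledAt, so it is full again (max tokens) after max - count more intervals. Once full, deleting the key loses nothing, since a missing key is recreated with max - cost tokens. The snippet stops before using expiresAt; presumably it would be passed to EXPIREAT, which takes a Unix timestamp in seconds. A TypeScript sketch with hypothetical helper names:

```typescript
// Hypothetical helpers for the arithmetic in the snippet above (times in unix seconds).

// Tokens the bucket would hold at `now`, using the same refill rule as the script.
function tokensAt(now: number, refilledAt: number, count: number, max: number, refillIntervalSeconds: number): number {
  const refill = Math.floor((now - refilledAt) / refillIntervalSeconds);
  return Math.min(count + refill, max);
}

// Mirrors: local expiresAt = refilledAt + (max - count) * refillIntervalSeconds
function bucketExpiresAt(refilledAt: number, count: number, max: number, refillIntervalSeconds: number): number {
  return refilledAt + (max - count) * refillIntervalSeconds;
}

// Example: max = 5, 60-second interval; after a request, refilledAt = 120 with 2 tokens left.
const expiresAt = bucketExpiresAt(120, 2, 5, 60); // 120 + 3 * 60 = 300
console.log(expiresAt, tokensAt(expiresAt - 1, 120, 2, 5, 60), tokensAt(expiresAt, 120, 2, 5, 60)); // prints 300 4 5
```

One second before expiresAt the bucket is still one token short, and at expiresAt it reaches max, so that is the earliest moment the key can be dropped without changing any result.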

@avevotsira
Author

Yes, that would also work if we want the bucket to be deleted once it would be full again.

It's a bit more complicated to me, but much more accurate, if we want that.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants