Skip to content

Commit c79df94

Browse files
committed
fix: rate limit overlimit count
1 parent b94ffa6 commit c79df94

File tree

2 files changed

+59
-17
lines changed

2 files changed

+59
-17
lines changed

service/aiproxy/common/rpmlimit/rate-limit.go

+58-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ package rpmlimit
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"strconv"
8+
"strings"
69
"time"
710

811
"github.com/labring/sealos/service/aiproxy/common"
@@ -13,33 +16,54 @@ var inMemoryRateLimiter InMemoryRateLimiter
1316

1417
// Redis key templates used by the RPM limiter. Both are formatted with the
// group first and the model second.
const (
1518
// groupModelRPMKey names the sorted set passed to pushRequestScript as KEYS[1];
// the script adds a request here while the window count is below the limit.
groupModelRPMKey = "group_model_rpm:%s:%s"
19+
// overLimitRPMKey names the sorted set passed to pushRequestScript as KEYS[2];
// the script adds a request here once the main set has reached the limit.
overLimitRPMKey = "over_limit_rpm:%s:%s"
1620
)
1721

1822
// pushRequestScript atomically records one request and reports the counters
// for the current window.
//
// KEYS[1] is the sorted set of admitted requests and KEYS[2] the sorted set of
// over-limit requests. ARGV[1] is the window length and ARGV[2] the current
// time, both in microseconds as sent by PushRequest; ARGV[3] is the maximum
// number of requests admitted per window.
//
// The reply is the string "<count>:<over_limit_count>". <count> is the ordinal
// of this request relative to the admitted set: it is at most max_requests when
// the request was admitted and strictly greater when it was rejected, so a
// caller can decide with `count <= max`. Reporting the bare admitted count
// would never exceed max_requests and would make that comparison always true.
var pushRequestScript = `
local key = KEYS[1]
local over_limit_key = KEYS[2]
local window = tonumber(ARGV[1])
local current_time = tonumber(ARGV[2])
local max_requests = tonumber(ARGV[3])
local cutoff = current_time - window
-- PEXPIRE needs an integer number of milliseconds; the window is in
-- microseconds, so round up to avoid a fractional argument.
local ttl_ms = math.ceil(window / 1000)

redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
redis.call('ZREMRANGEBYSCORE', over_limit_key, '-inf', cutoff)
local count = redis.call('ZCOUNT', key, cutoff, current_time)
local over_limit_count = redis.call('ZCOUNT', over_limit_key, cutoff, current_time)

if count < max_requests then
    redis.call('ZADD', key, current_time, current_time)
    redis.call('PEXPIRE', key, ttl_ms)
else
    redis.call('ZADD', over_limit_key, current_time, current_time)
    redis.call('PEXPIRE', over_limit_key, ttl_ms)
    over_limit_count = over_limit_count + 1
end

-- Report this request's ordinal in both branches so that a rejected request
-- is visible to the caller as a count above max_requests.
count = count + 1

return string.format("%d:%d", count, over_limit_count)
`
2947

3048
var getRequestCountScript = `
31-
local pattern = ARGV[1]
32-
local window = tonumber(ARGV[2])
33-
local current_time = tonumber(ARGV[3])
49+
local pattern = KEYS[1]
50+
local over_limit_pattern = KEYS[2]
51+
local window = tonumber(ARGV[1])
52+
local current_time = tonumber(ARGV[2])
3453
local cutoff = current_time - window
3554
36-
local keys = redis.call('KEYS', pattern)
3755
local total = 0
3856
57+
local keys = redis.call('KEYS', pattern)
3958
for _, key in ipairs(keys) do
4059
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
41-
local count = redis.call('ZCOUNT', key, cutoff, current_time)
42-
total = total + count
60+
total = total + redis.call('ZCOUNT', key, cutoff, current_time)
61+
end
62+
63+
local over_limit_keys = redis.call('KEYS', over_limit_pattern)
64+
for _, key in ipairs(over_limit_keys) do
65+
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
66+
total = total + redis.call('ZCOUNT', key, cutoff, current_time)
4367
end
4468
4569
return total
@@ -51,22 +75,26 @@ func GetRPM(ctx context.Context, group, model string) (int64, error) {
5175
}
5276

5377
var pattern string
78+
var overLimitPattern string
5479
if group == "" && model == "" {
5580
pattern = "group_model_rpm:*:*"
81+
overLimitPattern = "over_limit_rpm:*:*"
5682
} else if group == "" {
5783
pattern = "group_model_rpm:*:" + model
84+
overLimitPattern = "over_limit_rpm:*:" + model
5885
} else if model == "" {
5986
pattern = fmt.Sprintf("group_model_rpm:%s:*", group)
87+
overLimitPattern = fmt.Sprintf("over_limit_rpm:%s:*", group)
6088
} else {
6189
pattern = fmt.Sprintf("group_model_rpm:%s:%s", group, model)
90+
overLimitPattern = fmt.Sprintf("over_limit_rpm:%s:%s", group, model)
6291
}
6392

6493
rdb := common.RDB
6594
result, err := rdb.Eval(
6695
ctx,
6796
getRequestCountScript,
68-
[]string{},
69-
pattern,
97+
[]string{pattern, overLimitPattern},
7098
time.Minute.Microseconds(),
7199
time.Now().UnixMicro(),
72100
).Int64()
@@ -77,27 +105,41 @@ func GetRPM(ctx context.Context, group, model string) (int64, error) {
77105
}
78106

79107
func redisRateLimitRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) {
80-
result, err := PushRequest(ctx, group, model, duration)
108+
result, _, err := PushRequest(ctx, group, model, maxRequestNum, duration)
81109
if err != nil {
82110
return false, err
83111
}
84112
return result <= maxRequestNum, nil
85113
}
86114

87-
func PushRequest(ctx context.Context, group, model string, duration time.Duration) (int64, error) {
115+
func PushRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (int64, int64, error) {
88116
result, err := common.RDB.Eval(
89117
ctx,
90118
pushRequestScript,
91119
[]string{
92120
fmt.Sprintf(groupModelRPMKey, group, model),
121+
fmt.Sprintf(overLimitRPMKey, group, model),
93122
},
94123
duration.Microseconds(),
95124
time.Now().UnixMicro(),
96-
).Int64()
125+
maxRequestNum,
126+
).Text()
97127
if err != nil {
98-
return 0, err
128+
return 0, 0, err
99129
}
100-
return result, nil
130+
count, overLimitCount, ok := strings.Cut(result, ":")
131+
if !ok {
132+
return 0, 0, errors.New("invalid result")
133+
}
134+
countInt, err := strconv.ParseInt(count, 10, 64)
135+
if err != nil {
136+
return 0, 0, err
137+
}
138+
overLimitCountInt, err := strconv.ParseInt(overLimitCount, 10, 64)
139+
if err != nil {
140+
return 0, 0, err
141+
}
142+
return countInt, overLimitCountInt, nil
101143
}
102144

103145
func RateLimit(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) {

service/aiproxy/middleware/distributor.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model
101101
return ErrRequestRateLimitExceeded
102102
}
103103
} else if common.RedisEnabled {
104-
_, err := rpmlimit.PushRequest(c.Request.Context(), group.ID, mc.Model, time.Minute)
104+
_, _, err := rpmlimit.PushRequest(c.Request.Context(), group.ID, mc.Model, 1, time.Minute)
105105
if err != nil {
106106
log.Errorf("push request error: %s", err.Error())
107107
}

0 commit comments

Comments
 (0)