@@ -2,7 +2,10 @@ package rpmlimit
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"fmt"
7
+ "strconv"
8
+ "strings"
6
9
"time"
7
10
8
11
"github.com/labring/sealos/service/aiproxy/common"
@@ -13,33 +16,54 @@ var inMemoryRateLimiter InMemoryRateLimiter
13
16
14
17
const (
15
18
groupModelRPMKey = "group_model_rpm:%s:%s"
19
+ overLimitRPMKey = "over_limit_rpm:%s:%s"
16
20
)
17
21
18
22
var pushRequestScript = `
19
23
local key = KEYS[1]
24
+ local over_limit_key = KEYS[2]
20
25
local window = tonumber(ARGV[1])
21
26
local current_time = tonumber(ARGV[2])
27
+ local max_requests = tonumber(ARGV[3])
22
28
local cutoff = current_time - window
23
29
24
30
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
25
- redis.call('ZADD', key, current_time, current_time)
26
- redis.call('PEXPIRE', key, window / 1000)
27
- return redis.call('ZCOUNT', key, cutoff, current_time)
31
+ redis.call('ZREMRANGEBYSCORE', over_limit_key, '-inf', cutoff)
32
+ local count = redis.call('ZCOUNT', key, cutoff, current_time)
33
+ local over_limit_count = redis.call('ZCOUNT', over_limit_key, cutoff, current_time)
34
+
35
+ if count < max_requests then
36
+ redis.call('ZADD', key, current_time, current_time)
37
+ redis.call('PEXPIRE', key, window / 1000)
38
+ count = count + 1
39
+ else
40
+ redis.call('ZADD', over_limit_key, current_time, current_time)
41
+ redis.call('PEXPIRE', over_limit_key, window / 1000)
42
+ over_limit_count = over_limit_count + 1
43
+ end
44
+
45
+ return string.format("%d:%d", count, over_limit_count)
28
46
`
29
47
30
48
var getRequestCountScript = `
31
- local pattern = ARGV[1]
32
- local window = tonumber(ARGV[2])
33
- local current_time = tonumber(ARGV[3])
49
+ local pattern = KEYS[1]
50
+ local over_limit_pattern = KEYS[2]
51
+ local window = tonumber(ARGV[1])
52
+ local current_time = tonumber(ARGV[2])
34
53
local cutoff = current_time - window
35
54
36
- local keys = redis.call('KEYS', pattern)
37
55
local total = 0
38
56
57
+ local keys = redis.call('KEYS', pattern)
39
58
for _, key in ipairs(keys) do
40
59
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
41
- local count = redis.call('ZCOUNT', key, cutoff, current_time)
42
- total = total + count
60
+ total = total + redis.call('ZCOUNT', key, cutoff, current_time)
61
+ end
62
+
63
+ local over_limit_keys = redis.call('KEYS', over_limit_pattern)
64
+ for _, key in ipairs(over_limit_keys) do
65
+ redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
66
+ total = total + redis.call('ZCOUNT', key, cutoff, current_time)
43
67
end
44
68
45
69
return total
@@ -51,22 +75,26 @@ func GetRPM(ctx context.Context, group, model string) (int64, error) {
51
75
}
52
76
53
77
var pattern string
78
+ var overLimitPattern string
54
79
if group == "" && model == "" {
55
80
pattern = "group_model_rpm:*:*"
81
+ overLimitPattern = "over_limit_rpm:*:*"
56
82
} else if group == "" {
57
83
pattern = "group_model_rpm:*:" + model
84
+ overLimitPattern = "over_limit_rpm:*:" + model
58
85
} else if model == "" {
59
86
pattern = fmt .Sprintf ("group_model_rpm:%s:*" , group )
87
+ overLimitPattern = fmt .Sprintf ("over_limit_rpm:%s:*" , group )
60
88
} else {
61
89
pattern = fmt .Sprintf ("group_model_rpm:%s:%s" , group , model )
90
+ overLimitPattern = fmt .Sprintf ("over_limit_rpm:%s:%s" , group , model )
62
91
}
63
92
64
93
rdb := common .RDB
65
94
result , err := rdb .Eval (
66
95
ctx ,
67
96
getRequestCountScript ,
68
- []string {},
69
- pattern ,
97
+ []string {pattern , overLimitPattern },
70
98
time .Minute .Microseconds (),
71
99
time .Now ().UnixMicro (),
72
100
).Int64 ()
@@ -77,27 +105,41 @@ func GetRPM(ctx context.Context, group, model string) (int64, error) {
77
105
}
78
106
79
107
func redisRateLimitRequest (ctx context.Context , group , model string , maxRequestNum int64 , duration time.Duration ) (bool , error ) {
80
- result , err := PushRequest (ctx , group , model , duration )
108
+ result , _ , err := PushRequest (ctx , group , model , maxRequestNum , duration )
81
109
if err != nil {
82
110
return false , err
83
111
}
84
112
return result <= maxRequestNum , nil
85
113
}
86
114
87
- func PushRequest (ctx context.Context , group , model string , duration time.Duration ) (int64 , error ) {
115
+ func PushRequest (ctx context.Context , group , model string , maxRequestNum int64 , duration time.Duration ) (int64 , int64 , error ) {
88
116
result , err := common .RDB .Eval (
89
117
ctx ,
90
118
pushRequestScript ,
91
119
[]string {
92
120
fmt .Sprintf (groupModelRPMKey , group , model ),
121
+ fmt .Sprintf (overLimitRPMKey , group , model ),
93
122
},
94
123
duration .Microseconds (),
95
124
time .Now ().UnixMicro (),
96
- ).Int64 ()
125
+ maxRequestNum ,
126
+ ).Text ()
97
127
if err != nil {
98
- return 0 , err
128
+ return 0 , 0 , err
99
129
}
100
- return result , nil
130
+ count , overLimitCount , ok := strings .Cut (result , ":" )
131
+ if ! ok {
132
+ return 0 , 0 , errors .New ("invalid result" )
133
+ }
134
+ countInt , err := strconv .ParseInt (count , 10 , 64 )
135
+ if err != nil {
136
+ return 0 , 0 , err
137
+ }
138
+ overLimitCountInt , err := strconv .ParseInt (overLimitCount , 10 , 64 )
139
+ if err != nil {
140
+ return 0 , 0 , err
141
+ }
142
+ return countInt , overLimitCountInt , nil
101
143
}
102
144
103
145
func RateLimit (ctx context.Context , group , model string , maxRequestNum int64 , duration time.Duration ) (bool , error ) {
0 commit comments