Skip to content

Commit 1687281

Browse files
committed
fix: redis script key expire
1 parent 9ceb307 commit 1687281

File tree

4 files changed

+39
-31
lines changed

4 files changed

+39
-31
lines changed

service/aiproxy/common/rpmlimit/rate-limit.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ local cutoff = current_time - window
2323
2424
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff)
2525
redis.call('ZADD', key, current_time, current_time)
26-
redis.call('PEXPIRE', key, window)
26+
redis.call('PEXPIRE', key, window / 1000)
2727
return redis.call('ZCOUNT', key, cutoff, current_time)
2828
`
2929

service/aiproxy/controller/relay.go

-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@ func RelayNotImplemented(c *gin.Context) {
399399
"error": &model.Error{
400400
Message: "API not implemented",
401401
Type: middleware.ErrorTypeAIPROXY,
402-
Param: "",
403402
Code: "api_not_implemented",
404403
},
405404
})

service/aiproxy/monitor/model.go

+35-29
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ var (
3232
clearAllModelErrorsScript = redis.NewScript(clearAllModelErrorsLuaScript)
3333
)
3434

35-
func buildStatsKey(model string, channelID interface{}) string {
36-
return fmt.Sprintf("%s%s%s%v%s", modelKeyPrefix, model, channelKeyPart, channelID, statsKeySuffix)
37-
}
38-
3935
// GetModelsErrorRate gets the error rate of every model across all channels
4036
func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
4137
if !common.RedisEnabled {
@@ -48,11 +44,8 @@ func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
4844
iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
4945
for iter.Next(ctx) {
5046
key := iter.Val()
51-
parts := strings.Split(key, ":")
52-
if len(parts) != 3 || parts[2] != "total_stats" {
53-
continue
54-
}
55-
model := parts[1]
47+
model := strings.TrimPrefix(key, modelKeyPrefix)
48+
model = strings.TrimSuffix(model, modelTotalStatsSuffix)
5649

5750
rate, err := getModelErrorRateScript.Run(
5851
ctx,
@@ -101,7 +94,6 @@ func AddRequest(ctx context.Context, model string, channelID int64, isError bool
10194
errorFlag,
10295
now,
10396
config.GetModelErrorAutoBanRate(),
104-
time.Second.Milliseconds()*15,
10597
canAutoBan(),
10698
).Int64()
10799
if err != nil {
@@ -110,24 +102,42 @@ func AddRequest(ctx context.Context, model string, channelID int64, isError bool
110102
return val == 3, val == 1, nil
111103
}
112104

105+
// buildStatsKey formats the key of a per-model, per-channel stats entry by
// concatenating modelKeyPrefix, model, channelKeyPart, channelID and
// statsKeySuffix, in that order.
// Callers in this file also pass "*" for either argument to build a SCAN
// match pattern instead of a concrete key.
func buildStatsKey(model string, channelID string) string {
106+
return fmt.Sprintf("%s%s%s%v%s", modelKeyPrefix, model, channelKeyPart, channelID, statsKeySuffix)
107+
}
108+
109+
// getModelChannelID extracts the model name and channel ID from a key laid out
// as modelKeyPrefix + model + channelKeyPart + channelID + statsKeySuffix.
// It returns (model, channelID, true) on success, and ("", 0, false) when
// channelKeyPart is not found or the channel ID is not a base-10 int64.
// NOTE(review): TrimPrefix/TrimSuffix leave the string unchanged when the
// affix is absent, so a key lacking the expected prefix or suffix is not
// rejected here — confirm callers only pass keys matched by a buildStatsKey
// pattern.
func getModelChannelID(key string) (string, int64, bool) {
110+
content := strings.TrimPrefix(key, modelKeyPrefix)
111+
content = strings.TrimSuffix(content, statsKeySuffix)
112+
// Cut splits at the first occurrence of channelKeyPart.
model, channelIDStr, ok := strings.Cut(content, channelKeyPart)
113+
if !ok {
114+
return "", 0, false
115+
}
116+
channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
117+
if err != nil {
118+
return "", 0, false
119+
}
120+
return model, channelID, true
121+
}
122+
113123
// GetChannelModelErrorRates gets error rates for a specific channel
114124
func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
115125
if !common.RedisEnabled {
116126
return map[string]float64{}, nil
117127
}
118128

119129
result := make(map[string]float64)
120-
pattern := buildStatsKey("*", channelID)
130+
pattern := buildStatsKey("*", strconv.FormatInt(channelID, 10))
121131
now := time.Now().UnixMilli()
122132

123133
iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
124134
for iter.Next(ctx) {
125135
key := iter.Val()
126-
parts := strings.Split(key, ":")
127-
if len(parts) != 5 || parts[4] != "stats" {
136+
137+
model, _, ok := getModelChannelID(key)
138+
if !ok {
128139
continue
129140
}
130-
model := parts[1]
131141

132142
rate, err := getChannelModelErrorRateScript.Run(
133143
ctx,
@@ -206,7 +216,8 @@ func GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error)
206216

207217
for iter.Next(ctx) {
208218
key := iter.Val()
209-
model := strings.Split(key, ":")[1]
219+
model := strings.TrimPrefix(key, modelKeyPrefix)
220+
model = strings.TrimSuffix(model, bannedKeySuffix)
210221

211222
channels, err := getBannedChannelsScript.Run(
212223
ctx,
@@ -233,20 +244,15 @@ func GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]flo
233244
}
234245

235246
result := make(map[int64]map[string]float64)
236-
pattern := modelKeyPrefix + "*" + channelKeyPart + "*" + statsKeySuffix
247+
pattern := buildStatsKey("*", "*")
237248
now := time.Now().UnixMilli()
238249

239250
iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
240251
for iter.Next(ctx) {
241252
key := iter.Val()
242-
parts := strings.Split(key, ":")
243-
if len(parts) != 5 || parts[4] != "stats" {
244-
continue
245-
}
246253

247-
model := parts[1]
248-
channelID, err := strconv.ParseInt(parts[3], 10, 64)
249-
if err != nil {
254+
model, channelID, ok := getModelChannelID(key)
255+
if !ok {
250256
continue
251257
}
252258

@@ -281,14 +287,14 @@ local channel_id = ARGV[1]
281287
local is_error = tonumber(ARGV[2])
282288
local now_ts = tonumber(ARGV[3])
283289
local max_error_rate = tonumber(ARGV[4])
284-
local statsExpiry = tonumber(ARGV[5])
285-
local can_auto_ban = tonumber(ARGV[6])
290+
local can_auto_ban = tonumber(ARGV[5])
286291
287292
local banned_key = "model:" .. model .. ":banned"
288293
local stats_key = "model:" .. model .. ":channel:" .. channel_id .. ":stats"
289294
local model_stats_key = "model:" .. model .. ":total_stats"
290295
local maxSliceCount = 12
291-
local current_slice = math.floor(now_ts / 10000)
296+
local statsExpiry = maxSliceCount * 10 * 1000
297+
local current_slice = math.floor(now_ts / 10 / 1000)
292298
293299
local function parse_req_err(value)
294300
if not value then return 0, 0 end
@@ -367,7 +373,7 @@ return check_channel_error()
367373
local model_stats_key = KEYS[1]
368374
local now_ts = tonumber(ARGV[1])
369375
local maxSliceCount = 12
370-
local current_slice = math.floor(now_ts / 10000)
376+
local current_slice = math.floor(now_ts / 10 / 1000)
371377
local min_valid_slice = current_slice - maxSliceCount
372378
373379
local function parse_req_err(value)
@@ -389,14 +395,14 @@ for i = 1, #all_slices, 2 do
389395
end
390396
391397
if total_req == 0 then return 0 end
392-
return total_err / total_req
398+
return string.format("%.2f", total_err / total_req)
393399
`
394400

395401
getChannelModelErrorRateLuaScript = `
396402
local stats_key = KEYS[1]
397403
local now_ts = tonumber(ARGV[1])
398404
local maxSliceCount = 12
399-
local current_slice = math.floor(now_ts / 10000)
405+
local current_slice = math.floor(now_ts / 10 / 1000)
400406
local min_valid_slice = current_slice - maxSliceCount
401407
402408
local function parse_req_err(value)

service/aiproxy/relay/controller/dohelper.go

+3
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ func doRequest(a adaptor.Adaptor, c *gin.Context, meta *meta.Meta, req *http.Req
190190
if errors.Is(err, context.DeadlineExceeded) {
191191
return nil, openai.ErrorWrapperWithMessage("do request failed: request timeout", "request_timeout", http.StatusGatewayTimeout)
192192
}
193+
if errors.Is(err, io.EOF) {
194+
return nil, openai.ErrorWrapperWithMessage("do request failed: "+err.Error(), "request_failed", http.StatusServiceUnavailable)
195+
}
193196
return nil, openai.ErrorWrapperWithMessage("do request failed: "+err.Error(), "request_failed", http.StatusBadRequest)
194197
}
195198
return resp, nil

0 commit comments

Comments
 (0)