Skip to content

Commit 259be05

Browse files
committed
feat: mem model monitor
1 parent 7d4e34b commit 259be05

File tree

2 files changed

+355
-55
lines changed

2 files changed

+355
-55
lines changed

service/aiproxy/monitor/memmodel.go

+298
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
package monitor
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"github.com/labring/sealos/service/aiproxy/common/config"
9+
)
10+
11+
var memModelMonitor *MemModelMonitor
12+
13+
func init() {
14+
memModelMonitor = NewMemModelMonitor()
15+
}
16+
17+
// Tuning parameters for the in-memory monitor.
const (
	// timeWindow is the width of one statistics slice; request timestamps
	// are truncated to a multiple of it.
	timeWindow = 10 * time.Second

	// maxSliceCount is the number of timeWindow-sized slices considered
	// when pruning and summing statistics.
	maxSliceCount = 12

	// banDuration is how long a ban lasts from the moment it is placed.
	banDuration = 5 * time.Minute

	// minRequestCount is the minimum number of requests in the window
	// before an error rate is evaluated at all.
	minRequestCount = 20
)
23+
24+
type (
	// MemModelMonitor keeps per-model request/error statistics in memory.
	MemModelMonitor struct {
		// mu guards models and the ModelData/ChannelStats reachable from it.
		mu sync.RWMutex
		// models maps a model name to its statistics.
		models map[string]*ModelData
	}

	// ModelData holds the statistics recorded for a single model.
	ModelData struct {
		// channels maps a channel ID to that channel's statistics.
		channels map[int64]*ChannelStats
		// totalStats aggregates requests across all channels of the model.
		totalStats *TimeWindowStats
	}

	// ChannelStats holds the statistics and ban state of one channel for
	// one model.
	ChannelStats struct {
		// timeWindows holds the windowed request/error counters.
		timeWindows *TimeWindowStats
		// bannedUntil is the instant the current ban ends; the channel
		// counts as banned only while this lies after the current time.
		bannedUntil time.Time
	}

	// TimeWindowStats is a list of fixed-width time slices with request
	// and error counters, guarded by its own mutex.
	TimeWindowStats struct {
		slices []*timeSlice
		mu     sync.Mutex
	}

	// timeSlice counts the requests and errors seen in one window.
	timeSlice struct {
		// windowStart is the truncated start time of the window.
		windowStart time.Time
		requests    int
		errors      int
	}
)
49+
50+
func NewTimeWindowStats() *TimeWindowStats {
51+
return &TimeWindowStats{
52+
slices: make([]*timeSlice, 0, maxSliceCount),
53+
}
54+
}
55+
56+
func NewMemModelMonitor() *MemModelMonitor {
57+
return &MemModelMonitor{
58+
models: make(map[string]*ModelData),
59+
}
60+
}
61+
62+
func (m *MemModelMonitor) AddRequest(model string, channelID int64, isError, tryBan bool) (beyondThreshold, banExecution bool) {
63+
m.mu.Lock()
64+
defer m.mu.Unlock()
65+
66+
now := time.Now()
67+
68+
var modelData *ModelData
69+
var exists bool
70+
if modelData, exists = m.models[model]; !exists {
71+
modelData = &ModelData{
72+
channels: make(map[int64]*ChannelStats),
73+
totalStats: NewTimeWindowStats(),
74+
}
75+
m.models[model] = modelData
76+
}
77+
78+
var channel *ChannelStats
79+
if channel, exists = modelData.channels[channelID]; !exists {
80+
channel = &ChannelStats{
81+
timeWindows: NewTimeWindowStats(),
82+
}
83+
modelData.channels[channelID] = channel
84+
}
85+
86+
modelData.totalStats.AddRequest(now, isError)
87+
channel.timeWindows.AddRequest(now, isError)
88+
89+
return m.checkAndBan(now, channel, tryBan)
90+
}
91+
92+
func (m *MemModelMonitor) checkAndBan(now time.Time, channel *ChannelStats, tryBan bool) (beyondThreshold, banExecution bool) {
93+
canBan := config.GetEnableModelErrorAutoBan()
94+
if tryBan && canBan {
95+
if channel.bannedUntil.After(now) {
96+
return false, false
97+
}
98+
channel.bannedUntil = now.Add(banDuration)
99+
return false, true
100+
}
101+
102+
req, err := channel.timeWindows.GetStats(maxSliceCount)
103+
if req < minRequestCount {
104+
return false, false
105+
}
106+
107+
errorRate := float64(err) / float64(req)
108+
if errorRate >= config.GetModelErrorAutoBanRate() {
109+
if !canBan || channel.bannedUntil.After(now) {
110+
return true, false
111+
}
112+
channel.bannedUntil = now.Add(banDuration)
113+
return false, true
114+
}
115+
channel.bannedUntil = time.Time{}
116+
return false, false
117+
}
118+
119+
func getErrorRateFromStats(stats *TimeWindowStats) float64 {
120+
req, err := stats.GetStats(maxSliceCount)
121+
if req < minRequestCount {
122+
return 0
123+
}
124+
return float64(err) / float64(req)
125+
}
126+
127+
func (m *MemModelMonitor) GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
128+
m.mu.RLock()
129+
defer m.mu.RUnlock()
130+
131+
result := make(map[string]float64)
132+
for model, data := range m.models {
133+
result[model] = getErrorRateFromStats(data.totalStats)
134+
}
135+
return result, nil
136+
}
137+
138+
func (m *MemModelMonitor) GetModelChannelErrorRate(ctx context.Context, model string) (map[int64]float64, error) {
139+
m.mu.RLock()
140+
defer m.mu.RUnlock()
141+
142+
result := make(map[int64]float64)
143+
if data, exists := m.models[model]; exists {
144+
for channelID, channel := range data.channels {
145+
result[channelID] = getErrorRateFromStats(channel.timeWindows)
146+
}
147+
}
148+
return result, nil
149+
}
150+
151+
func (m *MemModelMonitor) GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
152+
m.mu.RLock()
153+
defer m.mu.RUnlock()
154+
155+
result := make(map[string]float64)
156+
for model, data := range m.models {
157+
if channel, exists := data.channels[channelID]; exists {
158+
result[model] = getErrorRateFromStats(channel.timeWindows)
159+
}
160+
}
161+
return result, nil
162+
}
163+
164+
func (m *MemModelMonitor) GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) {
165+
m.mu.RLock()
166+
defer m.mu.RUnlock()
167+
168+
result := make(map[int64]map[string]float64)
169+
for model, data := range m.models {
170+
for channelID, channel := range data.channels {
171+
if _, exists := result[channelID]; !exists {
172+
result[channelID] = make(map[string]float64)
173+
}
174+
result[channelID][model] = getErrorRateFromStats(channel.timeWindows)
175+
}
176+
}
177+
return result, nil
178+
}
179+
180+
func (m *MemModelMonitor) GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, error) {
181+
m.mu.RLock()
182+
defer m.mu.RUnlock()
183+
184+
var banned []int64
185+
if data, exists := m.models[model]; exists {
186+
now := time.Now()
187+
for channelID, channel := range data.channels {
188+
if channel.bannedUntil.After(now) {
189+
banned = append(banned, channelID)
190+
} else {
191+
channel.bannedUntil = time.Time{}
192+
}
193+
}
194+
}
195+
return banned, nil
196+
}
197+
198+
func (m *MemModelMonitor) GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error) {
199+
m.mu.RLock()
200+
defer m.mu.RUnlock()
201+
202+
result := make(map[string][]int64)
203+
now := time.Now()
204+
205+
for model, data := range m.models {
206+
for channelID, channel := range data.channels {
207+
if channel.bannedUntil.After(now) {
208+
if _, exists := result[model]; !exists {
209+
result[model] = []int64{}
210+
}
211+
result[model] = append(result[model], channelID)
212+
} else {
213+
channel.bannedUntil = time.Time{}
214+
}
215+
}
216+
}
217+
return result, nil
218+
}
219+
220+
func (m *MemModelMonitor) ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
221+
m.mu.Lock()
222+
defer m.mu.Unlock()
223+
224+
if data, exists := m.models[model]; exists {
225+
delete(data.channels, int64(channelID))
226+
}
227+
return nil
228+
}
229+
230+
func (m *MemModelMonitor) ClearChannelAllModelErrors(ctx context.Context, channelID int) error {
231+
m.mu.Lock()
232+
defer m.mu.Unlock()
233+
234+
for _, data := range m.models {
235+
delete(data.channels, int64(channelID))
236+
}
237+
return nil
238+
}
239+
240+
func (m *MemModelMonitor) ClearAllModelErrors(ctx context.Context) error {
241+
m.mu.Lock()
242+
defer m.mu.Unlock()
243+
244+
m.models = make(map[string]*ModelData)
245+
return nil
246+
}
247+
248+
func (t *TimeWindowStats) AddRequest(now time.Time, isError bool) {
249+
t.mu.Lock()
250+
defer t.mu.Unlock()
251+
252+
currentWindow := now.Truncate(timeWindow)
253+
254+
cutoff := now.Add(-timeWindow * time.Duration(maxSliceCount))
255+
validSlices := t.slices[:0]
256+
for _, s := range t.slices {
257+
if s.windowStart.After(cutoff) || s.windowStart.Equal(cutoff) {
258+
validSlices = append(validSlices, s)
259+
}
260+
}
261+
t.slices = validSlices
262+
263+
var slice *timeSlice
264+
for i := range t.slices {
265+
if t.slices[i].windowStart.Equal(currentWindow) {
266+
slice = t.slices[i]
267+
break
268+
}
269+
}
270+
271+
if slice == nil {
272+
slice = &timeSlice{windowStart: currentWindow}
273+
t.slices = append(t.slices, slice)
274+
}
275+
276+
slice.requests++
277+
if isError {
278+
slice.errors++
279+
}
280+
}
281+
282+
func (t *TimeWindowStats) GetStats(maxSlice int) (totalReq, totalErr int) {
283+
t.mu.Lock()
284+
defer t.mu.Unlock()
285+
286+
cutoff := time.Now().Add(-timeWindow * time.Duration(maxSlice))
287+
288+
validSlices := t.slices[:0]
289+
for _, s := range t.slices {
290+
if s.windowStart.After(cutoff) || s.windowStart.Equal(cutoff) {
291+
validSlices = append(validSlices, s)
292+
totalReq += s.requests
293+
totalErr += s.errors
294+
}
295+
}
296+
t.slices = validSlices
297+
return
298+
}

0 commit comments

Comments
 (0)