From 317510e577c454c95f5541bcae3ddd88c400bb45 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Wed, 20 Dec 2023 18:02:57 +0800 Subject: [PATCH 01/13] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=AF=8F=E6=97=A5?= =?UTF-8?q?=E5=85=8D=E8=B4=B9=E4=BD=BF=E7=94=A8=E9=A2=9D=E5=BA=A6=EF=BC=88?= =?UTF-8?q?=E4=B8=8D=E9=9C=80=E8=A6=81=E7=99=BB=E5=BD=95=E5=B0=B1=E5=8F=AF?= =?UTF-8?q?=E4=BB=A5=E8=AE=BF=E9=97=AE=E7=9A=84=E9=A2=9D=E5=BA=A6=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 9 +++ config/config.go | 14 ++++ config/flag.go | 5 ++ internal/queue/group_chat.go | 2 +- pkg/ai/chat/chat.go | 10 ++- pkg/ai/chat/chat_test.go | 6 +- pkg/rate/limiter.go | 1 + server/controllers/openai.go | 120 +++++++++++++++++++++++++---------- server/provider.go | 7 +- 9 files changed, 134 insertions(+), 40 deletions(-) diff --git a/config.yaml b/config.yaml index 2cbcb78..2910edb 100644 --- a/config.yaml +++ b/config.yaml @@ -435,3 +435,12 @@ virtual-model-beichou-prompt: "" #price-table-file: /data/webroot/aidea-server/etc/coins-table.yaml price-table-file: "" +######## 免费模型 ######## +# 是否启用免费聊天功能,启用后,未登录可以免费使用部分模型 +free-chat-enabled: false +# 每日免费次数,基于客户端 IP 限制 +free-chat-daily-limit: 5 +# 每日全局免费次数,总的限制次数,不管是哪个 IP +free-chat-daily-global-limit: 1000 +# 免费模型,所有的请求都会被替换为该模型 +free-chat-model: gpt-3.5-turbo \ No newline at end of file diff --git a/config/config.go b/config/config.go index c2331fe..0dcf9d3 100644 --- a/config/config.go +++ b/config/config.go @@ -261,6 +261,15 @@ type Config struct { FontPath string `json:"font_path" yaml:"font_path"` // 服务状态页面 ServiceStatusPage string `json:"service_status_page" yaml:"service_status_page"` + + // 免费 Chat 请求 (仅限 IOS) + FreeChatEnabled bool `json:"free_chat_enabled" yaml:"free_chat_enabled"` + // 免费 Chat 每日限制(每 IP) + FreeChatDailyLimit int `json:"free_chat_daily_limit" yaml:"free_chat_daily_limit"` + // 免费 Chat 每日全局限制(不区分 IP) + FreeChatDailyGlobalLimit int `json:"free_chat_daily_global_limit" yaml:"free_chat_daily_global_limit"` + // 免费 Chat 模型 + FreeChatModel string `json:"free_chat_model" yaml:"free_chat_model"` } func (conf *Config) SupportProxy() bool { @@ -531,6 +540,11 @@ func Register(ins *app.App) { FontPath: ctx.String("font-path"), ServiceStatusPage: ctx.String("service-status-page"), + + FreeChatEnabled: ctx.Bool("free-chat-enabled"), + FreeChatDailyLimit: ctx.Int("free-chat-daily-limit"), + FreeChatDailyGlobalLimit: ctx.Int("free-chat-daily-global-limit"), + FreeChatModel: ctx.String("free-chat-model"), } }) } diff --git a/config/flag.go b/config/flag.go index 09e8a46..79ebc47 100644 --- a/config/flag.go +++ b/config/flag.go @@ -217,4 +217,9 @@ func initCmdFlags(ins *app.App) { ins.AddStringFlag("font-path", "", "字体文件路径") ins.AddStringFlag("service-status-page", "", "服务状态页面,留空则不启用服务状态页面") + + ins.AddBoolFlag("free-chat-enabled", "是否启用免费聊天功能,启用后,未登录可以免费使用部分模型") + ins.AddIntFlag("free-chat-daily-limit", 5, "每日免费次数,基于客户端 IP 限制") + ins.AddIntFlag("free-chat-daily-global-limit", 1000, "每日全局免费次数,总的限制次数,不管是哪个 IP") + ins.AddStringFlag("free-chat-model", "gpt-3.5-turbo", "免费模型,所有的请求都会被替换为该模型") } diff --git a/internal/queue/group_chat.go b/internal/queue/group_chat.go index eb03626..96f96a1 100644 --- a/internal/queue/group_chat.go +++ b/internal/queue/group_chat.go @@ -110,7 +110,7 @@ func BuildGroupChatHandler(conf *config.Config, ct chat.Chat, rep *repo2.Reposit req, _, err := (chat.Request{ Model: payload.ModelID, Messages: payload.ContextMessages, - }).Init().Fix(ct, 5) + }).Init().Fix(ct, 5, 1024*200) if err != nil { panic(fmt.Errorf("fix chat request failed: %w", err)) } diff --git a/pkg/ai/chat/chat.go b/pkg/ai/chat/chat.go index 5419b65..6b8b91c 100644 --- a/pkg/ai/chat/chat.go +++ b/pkg/ai/chat/chat.go @@ -161,18 +161,24 @@ func (req Request) Init() Request { } // Fix 修复请求内容,注意:上下文长度修复后,最终的上下文数量不包含 system 消息和用户最后一条消息 -func (req Request) Fix(chat Chat, maxContextLength int64) (*Request, int64, error) { +func (req Request) Fix(chat Chat, maxContextLength int64, maxTokenCount int) (*Request, int64, error) { // 自动缩减上下文长度至满足模型要求的最大长度,尽可能避免出现超过模型上下文长度的问题 systemMessages := array.Filter(req.Messages, func(item Message, _ int) bool { return item.Role == "system" }) systemMessageLen, _ := MessageTokenCount(systemMessages, req.Model) + // 模型允许的 Tokens 数量和请求参数指定的 Tokens 数量,取最小值 + modelTokenLimit := chat.MaxContextLength(req.Model) - systemMessageLen + if modelTokenLimit < maxTokenCount { + maxTokenCount = modelTokenLimit + } + messages, inputTokens, err := ReduceMessageContext( ReduceMessageContextUpToContextWindow( array.Filter(req.Messages, func(item Message, _ int) bool { return item.Role != "system" }), int(maxContextLength), ), req.Model, - chat.MaxContextLength(req.Model)-systemMessageLen, + maxTokenCount, ) if err != nil { return nil, 0, errors.New("超过模型最大允许的上下文长度限制,请尝试“新对话”或缩短输入内容长度") diff --git a/pkg/ai/chat/chat_test.go b/pkg/ai/chat/chat_test.go index 35ca7d5..0209143 100644 --- a/pkg/ai/chat/chat_test.go +++ b/pkg/ai/chat/chat_test.go @@ -36,19 +36,19 @@ func TestRequestFix(t *testing.T) { }.Init() { - fixed, _, err := req.Fix(ChatTestClient{}, 0) + fixed, _, err := req.Fix(ChatTestClient{}, 0, 1024*200) assert.NoError(t, err) assert.Equal(t, 2, len(fixed.Messages)) } { - fixed, _, err := req.Fix(ChatTestClient{}, 1) + fixed, _, err := req.Fix(ChatTestClient{}, 1, 1024*200) assert.NoError(t, err) assert.Equal(t, 4, len(fixed.Messages)) } { - fixed, _, err := req.Fix(ChatTestClient{}, 2) + fixed, _, err := req.Fix(ChatTestClient{}, 2, 1024*200) assert.NoError(t, err) assert.Equal(t, 6, len(fixed.Messages)) } diff --git a/pkg/rate/limiter.go b/pkg/rate/limiter.go index ecdf7c6..61ea58c 100644 --- a/pkg/rate/limiter.go +++ b/pkg/rate/limiter.go @@ -11,6 +11,7 @@ import ( ) var ErrRateLimitExceeded = errors.New("请求频率过高,请稍后再试") +var ErrDailyFreeLimitExceeded = errors.New("超过每日免费次数") func NewLimiter(rdb *redis.Client) *redis_rate.Limiter { return redis_rate.NewLimiter(rdb) diff --git a/server/controllers/openai.go b/server/controllers/openai.go index 3bc5077..2c9d4c4 100644 --- a/server/controllers/openai.go +++ b/server/controllers/openai.go @@ -195,16 +195,46 @@ func (m FinalMessage) ToJSON() string { // Chat 聊天接口,接口参数参考 https://platform.openai.com/docs/api-reference/chat/create // 该接口会返回一个 SSE 流,接口参数 stream 总是为 true(忽略客户端设置) -func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user *auth.User, quotaRepo *repo2.QuotaRepo, w http.ResponseWriter, client *auth.ClientInfo) { +func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user *auth.UserOptional, quotaRepo *repo2.QuotaRepo, w http.ResponseWriter, client *auth.ClientInfo) { + if user.User == nil && ctl.conf.FreeChatEnabled { + // 匿名用户访问 + user.User = &auth.User{ + ID: 0, + Name: "anonymous", + } + } + + if user.User == nil { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error": "用户未登录,请先登录后再试"}`)) + return + } + + // 流控,避免单一用户过度使用 + if err := ctl.rateLimitPass(ctx, client, user.User); err != nil { + if errors.Is(err, rate.ErrDailyFreeLimitExceeded) { + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusTooManyRequests) + } + _, _ = w.Write([]byte(fmt.Sprintf(`{"error": %s}`, strconv.Quote(err.Error())))) + return + } + sw, req, err := streamwriter.New[chat2.Request]( webCtx.Input("ws") == "true", ctl.conf.EnableCORS, webCtx.Request().Raw(), w, ) if err != nil { - log.F(log.M{"user": user.ID, "client": client}).Errorf("create stream writer failed: %s", err) + log.F(log.M{"user": user.User.ID, "client": client}).Errorf("create stream writer failed: %s", err) return } defer sw.Close() + // 匿名用户,使用免费模型代替 + if user.User.ID == 0 && ctl.conf.FreeChatModel != "" { + req.Model = ctl.conf.FreeChatModel + } + // 请求参数预处理 var inputTokenCount, maxContextLen int64 @@ -219,8 +249,8 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user inputTokenCount = int64(icnt) } else { - maxContextLen = ctl.loadRoomContextLen(ctx, req.RoomID, user.ID) - req, inputTokenCount, err = req.Fix(ctl.chat, maxContextLen) + maxContextLen = ctl.loadRoomContextLen(ctx, req.RoomID, user.User.ID) + req, inputTokenCount, err = req.Fix(ctl.chat, maxContextLen, ternary.If(user.User.ID > 0, 1000*200, 1000)) if err != nil { misc.NoError(sw.WriteErrorStream(err, http.StatusBadRequest)) return @@ -234,16 +264,18 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user return } - // 基于模型的流控,避免单一模型用户过度使用 - if err := ctl.rateLimitPass(ctx, user, req, sw); err != nil { - return - } - // 免费模型 // 获取当前用户剩余的智慧果数量,如果不足,则返回错误 - leftCount, maxFreeCount := ctl.userSrv.FreeChatRequestCounts(ctx, user.ID, req.Model) + var leftCount, maxFreeCount int + if user.User.ID > 0 { + leftCount, maxFreeCount = ctl.userSrv.FreeChatRequestCounts(ctx, user.User.ID, req.Model) + } else { + // 匿名用户,每次都是免费的,不限制次数,通过流控来限制访问 + leftCount, maxFreeCount = 1, 0 + } + if leftCount <= 0 { - quota, needCoins, err := ctl.queryChatQuota(ctx, quotaRepo, user, sw, webCtx, req, inputTokenCount, maxFreeCount) + quota, needCoins, err := ctl.queryChatQuota(ctx, quotaRepo, user.User, sw, webCtx, req, inputTokenCount, maxFreeCount) if err != nil { return } @@ -260,20 +292,20 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user } // 冻结本次所需要的智慧果 - if err := ctl.userSrv.FreezeUserQuota(ctx, user.ID, needCoins); err != nil { - log.F(log.M{"user_id": user.ID, "quota": needCoins}).Errorf("freeze user quota failed: %s", err) + if err := ctl.userSrv.FreezeUserQuota(ctx, user.User.ID, needCoins); err != nil { + log.F(log.M{"user_id": user.User.ID, "quota": needCoins}).Errorf("freeze user quota failed: %s", err) } else { defer func(ctx context.Context) { // 解冻智慧果 - if err := ctl.userSrv.UnfreezeUserQuota(ctx, user.ID, needCoins); err != nil { - log.F(log.M{"user_id": user.ID, "quota": needCoins}).Errorf("unfreeze user quota failed: %s", err) + if err := ctl.userSrv.UnfreezeUserQuota(ctx, user.User.ID, needCoins); err != nil { + log.F(log.M{"user_id": user.User.ID, "quota": needCoins}).Errorf("unfreeze user quota failed: %s", err) } }(ctx) } } // 内容安全检测 - if err := ctl.contentSafety(req, user, sw); err != nil { + if err := ctl.contentSafety(req, user.User, sw); err != nil { return } @@ -283,7 +315,7 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user startTime := time.Now() defer func() { log.F(log.M{ - "user_id": user.ID, + "user_id": user.User.ID, "client": client, "room_id": req.RoomID, "elapse": time.Since(startTime).Seconds(), @@ -298,10 +330,10 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user }() // 写入用户消息 - questionID := ctl.saveChatQuestion(ctx, user, req) + questionID := ctl.saveChatQuestion(ctx, user.User, req) // 发起聊天请求并返回 SSE/WS 流 - replyText, err := ctl.handleChat(ctx, req, user, sw, webCtx, questionID, 0) + replyText, err := ctl.handleChat(ctx, req, user.User, sw, webCtx, questionID, 0) if errors.Is(err, ErrChatResponseHasSent) { return } @@ -312,9 +344,9 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user if errors.Is(err, ErrChatResponseEmpty) || (errors.Is(err, ErrChatResponseGapTimeout) && replyText == "") { // 如果用户等待时间超过 60s,则不再重试,避免用户等待时间过长 if startTime.Add(60 * time.Second).After(time.Now()) { - log.F(log.M{"req": req, "user_id": user.ID}).Warningf("聊天响应为空,尝试再次请求,模型:%s", req.Model) + log.F(log.M{"req": req, "user_id": user.User.ID}).Warningf("聊天响应为空,尝试再次请求,模型:%s", req.Model) - replyText, err = ctl.handleChat(ctx, req, user, sw, webCtx, questionID, 1) + replyText, err = ctl.handleChat(ctx, req, user.User, sw, webCtx, questionID, 1) if errors.Is(err, ErrChatResponseHasSent) { return } @@ -323,7 +355,7 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user chatErrorMessage := ternary.IfLazy(err == nil, func() string { return "" }, func() string { return err.Error() }) if chatErrorMessage != "" { - log.F(log.M{"req": req, "user_id": user.ID, "reply": replyText, "elapse": time.Since(startTime).Seconds()}). + log.F(log.M{"req": req, "user_id": user.User.ID, "reply": replyText, "elapse": time.Since(startTime).Seconds()}). Errorf("聊天失败,模型:%s,错误:%s", req.Model, chatErrorMessage) } @@ -335,14 +367,14 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user defer cancel() // 写入用户消息 - answerID := ctl.saveChatAnswer(ctx, user, replyText, quotaConsumed, realTokenConsumed, req, questionID, chatErrorMessage) + answerID := ctl.saveChatAnswer(ctx, user.User, replyText, quotaConsumed, realTokenConsumed, req, questionID, chatErrorMessage) if errors.Is(ErrChatResponseEmpty, err) { misc.NoError(sw.WriteErrorStream(err, http.StatusInternalServerError)) } else { if !ctl.apiMode { // final 消息为定制消息,用于告诉 AIdea 客户端当前的资源消耗情况以及服务端信息 - finalWord := ctl.buildFinalSystemMessage(questionID, answerID, user, quotaConsumed, realTokenConsumed, req, maxContextLen, chatErrorMessage) + finalWord := ctl.buildFinalSystemMessage(questionID, answerID, user.User, quotaConsumed, realTokenConsumed, req, maxContextLen, chatErrorMessage) misc.NoError(sw.WriteStream(finalWord)) } } @@ -354,9 +386,9 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - if err := ctl.userSrv.UpdateFreeChatCount(ctx, user.ID, req.Model); err != nil { + if err := ctl.userSrv.UpdateFreeChatCount(ctx, user.User.ID, req.Model); err != nil { log.WithFields(log.Fields{ - "user_id": user.ID, + "user_id": user.User.ID, "model": req.Model, }).Errorf("update free chat count failed: %s", err) } @@ -369,7 +401,7 @@ func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - if err := quotaRepo.QuotaConsume(ctx, user.ID, quotaConsumed, repo2.NewQuotaUsedMeta("chat", req.Model)); err != nil { + if err := quotaRepo.QuotaConsume(ctx, user.User.ID, quotaConsumed, repo2.NewQuotaUsedMeta("chat", req.Model)); err != nil { log.Errorf("used quota add failed: %s", err) } }() @@ -529,7 +561,7 @@ func (*OpenAIController) buildFinalSystemMessage( Error: chatErrorMessage, } - if len(req.Messages) >= int(maxContextLen*2) { + if len(req.Messages) >= int(maxContextLen*2)-1 { if req.RoomID <= 1 { finalMsg.Info = fmt.Sprintf("本次请求消耗了 %d 个 Token。\n\nAI 记住的对话信息越多,消耗的 Token 和智慧果也越多。\n\n如果新问题和之前的对话无关,请在“聊一聊”页面创建新对话。", realTokenConsumed) } else { @@ -581,16 +613,38 @@ func (ctl *OpenAIController) queryChatQuota( return quota, coins.GetOpenAITextCoins(req.ResolveCalFeeModel(ctl.conf), inputTokenCount) + 3, nil } -func (ctl *OpenAIController) rateLimitPass(ctx context.Context, user *auth.User, req *chat2.Request, sw *streamwriter.StreamWriter) error { +func (ctl *OpenAIController) rateLimitPass(ctx context.Context, client *auth.ClientInfo, user *auth.User) error { if ctl.conf.EnableModelRateLimit { - if err := ctl.limiter.Allow(ctx, fmt.Sprintf("chat-limit:u:%d:m:%s:minute", user.ID, req.Model), redis_rate.PerMinute(5)); err != nil { + if err := ctl.limiter.Allow(ctx, fmt.Sprintf("chat-limit:u:%d:minute", user.ID), redis_rate.PerMinute(10)); err != nil { if errors.Is(err, rate.ErrRateLimitExceeded) { - misc.NoError(sw.WriteErrorStream(errors.New("操作频率过高,请稍后再试"), http.StatusBadRequest)) return rate.ErrRateLimitExceeded } - log.F(log.M{"user_id": user.ID, "req": req}).Errorf("check rate limit failed: %s", err) + log.F(log.M{"user_id": user.ID}).Errorf("聊天请求频率过高: %s", err) + } + } + + // 匿名用户每日免费次数限制 + if ctl.conf.FreeChatEnabled && user.ID == 0 { + lim := redis_rate.Limit{Rate: ctl.conf.FreeChatDailyLimit, Burst: ctl.conf.FreeChatDailyLimit, Period: time.Hour * 24} + if err := ctl.limiter.Allow(ctx, fmt.Sprintf("chat-limit:anonymous:%s:daily", client.IP), lim); err != nil { + log.F(log.M{"ip": client.IP}).Errorf("今日免费次数已用完(IP): %s", err) + return rate.ErrDailyFreeLimitExceeded } + + // 全局限制免费次数,这里是总次数,不区分用户 + if ctl.conf.FreeChatDailyGlobalLimit > 0 { + dailyGlobalLimitKey := fmt.Sprintf("chat-limit:free:daily:%s", time.Now().Format("2006-01-02")) + todayCount, _ := ctl.limiter.OperationCount(ctx, dailyGlobalLimitKey) + if todayCount > int64(ctl.conf.FreeChatDailyGlobalLimit) { + log.F(log.M{"ip": client.IP}).Errorf("今日免费次数已用完(全局)") + return rate.ErrDailyFreeLimitExceeded + } + + _ = ctl.limiter.OperationIncr(ctx, dailyGlobalLimitKey, time.Hour*24) + } + + log.F(log.M{"ip": client.IP}).Debugf("free request") } return nil @@ -674,7 +728,7 @@ func (ctl *OpenAIController) saveChatQuestion(ctx context.Context, user *auth.Us func (ctl *OpenAIController) loadRoomContextLen(ctx context.Context, roomID int64, userID int64) int64 { var maxContextLength int64 = 3 - if roomID > 0 { + if roomID > 0 && userID > 0 { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/server/provider.go b/server/provider.go index 78fdaba..94f66fc 100644 --- a/server/provider.go +++ b/server/provider.go @@ -79,8 +79,8 @@ func routes(resolver infra.Resolver, router web.Router, mw web.RequestMiddleware // 需要鉴权的 URLs needAuthPrefix := []string{ - "/v1/chat", // OpenAI chat "/v1/audio", // OpenAI audio to text + "/v1/images", // OpenAI image generation "/v1/group-chat", // 群聊 "/v1/users", // 用户管理 "/v1/api-keys", // API Key 管理 @@ -411,6 +411,11 @@ func BuildCounterVec(namespace, name, help string, tags []string) *prometheus.Co func readFromWebContext(webCtx web.Context, key string) string { val := webCtx.Input(key) if val != "" { + // TODO 临时处理,从请求参数中读取 Authorization 头需要添加 Bearer 前缀,否则后续的一些鉴权逻辑会有问题 + if strings.ToLower(key) == "authorization" { + return "Bearer " + val + } + return val } From de57630629cf1c4bcf7a3fe0f776e62c58601af0 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Mon, 25 Dec 2023 14:56:46 +0800 Subject: [PATCH 02/13] update --- pkg/ai/stabilityai/video.go | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 pkg/ai/stabilityai/video.go diff --git a/pkg/ai/stabilityai/video.go b/pkg/ai/stabilityai/video.go new file mode 100644 index 0000000..927a30c --- /dev/null +++ b/pkg/ai/stabilityai/video.go @@ -0,0 +1,49 @@ +package stabilityai + +type VideoRequest struct { + // ImagePath The source image used in the video generation process. Please ensure that the source image is in the correct format and dimensions. + // Supported Formats: + // - image/jpeg + // - image/png + // Supported Dimensions: + // - 1024x576 + // - 576x1024 + // - 768x768 + ImagePath string `json:"image_path,omitempty"` + // Seed A specific value that is used to guide the 'randomness' of the generation. + // (Omit this parameter or pass 0 to use a random seed.) + // number [ 0 .. 2147483648 ], default 0 + Seed int `json:"seed,omitempty"` + // CfgScale How strongly the video sticks to the original image. + // Use lower values to allow the model more freedom to make changes and higher values to correct motion distortions. + // number [ 0 .. 10 ], default 2.5 + CfgScale int `json:"cfg_scale,omitempty"` + // MotionBucketID Lower values generally result in less motion in the output video, + // while higher values generally result in more motion. + // This parameter corresponds to the motion_bucket_id parameter from the paper. + // number [ 1 .. 255 ], default 40 + MotionBucketID int `json:"motion_bucket_id,omitempty"` +} + +type VideoTaskResponse struct { + ID string `json:"id,omitempty"` +} + +type VideoResponse struct { + // 200 + + // Video The generated video. + Video string `json:"video,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Seed int `json:"seed,omitempty"` + + // 202 + ID string `json:"id,omitempty"` + // Status: in-progress + Status string `json:"status,omitempty"` +} + +type VideoError struct { + Name string `json:"name,omitempty"` + Errors []string `json:"errors,omitempty"` +} From 11394126007a8a3de2e6c9b476a01378906a952b Mon Sep 17 00:00:00 2001 From: mylxsw Date: Mon, 25 Dec 2023 15:03:49 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20Claude=202=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=90=8D=E7=A7=B0=E4=B8=BA=202.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/ai/anthropic/anthropic.go | 13 +++++++++++++ pkg/ai/anthropic/anthropic_test.go | 2 +- pkg/ai/chat/models.go | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pkg/ai/anthropic/anthropic.go b/pkg/ai/anthropic/anthropic.go index c15890e..1ba2c0b 100644 --- a/pkg/ai/anthropic/anthropic.go +++ b/pkg/ai/anthropic/anthropic.go @@ -37,8 +37,20 @@ func New(serverURL, apiKey string, client *http.Client) *Anthropic { return &Anthropic{apiKey: apiKey, serverURL: serverURL, client: client} } +func (ai *Anthropic) resolveModel(model Model) Model { + switch model { + case "claude-instant-1": + return "claude-instant-1.2" + case "claude-2": + return "claude-2.1" + default: + return model + } +} + func (ai *Anthropic) Chat(ctx context.Context, req Request) (*Response, error) { req.Stream = false + req.Model = ai.resolveModel(req.Model) if req.MaxTokensToSample <= 0 { req.MaxTokensToSample = 4000 } @@ -79,6 +91,7 @@ func (ai *Anthropic) Chat(ctx context.Context, req Request) (*Response, error) { func (ai *Anthropic) ChatStream(ctx context.Context, req Request) (<-chan Response, error) { req.Stream = true + req.Model = ai.resolveModel(req.Model) if req.MaxTokensToSample <= 0 { req.MaxTokensToSample = 4000 } diff --git a/pkg/ai/anthropic/anthropic_test.go b/pkg/ai/anthropic/anthropic_test.go index 42d1295..114cd49 100644 --- a/pkg/ai/anthropic/anthropic_test.go +++ b/pkg/ai/anthropic/anthropic_test.go @@ -46,7 +46,7 @@ func TestAnthropic_ChatStream(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - resp, err := client.ChatStream(ctx, anthropic.NewRequest(anthropic.ModelClaudeInstant, []anthropic.Message{ + resp, err := client.ChatStream(ctx, anthropic.NewRequest(anthropic.ModelClaude2, []anthropic.Message{ { Role: "user", Content: "你是一名占卜师,我给你名字,你帮我占卜运势", diff --git a/pkg/ai/chat/models.go b/pkg/ai/chat/models.go index 68eb1aa..baffa46 100644 --- a/pkg/ai/chat/models.go +++ b/pkg/ai/chat/models.go @@ -553,7 +553,7 @@ func anthropicModels(conf *config.Config) []Model { }, { ID: "Anthropic:" + string(anthropic.ModelClaude2), - Name: "Claude 2.0", + Name: "Claude 2.1", ShortName: "Claude2", Description: "Anthropic's most powerful model. Particularly good at creative writing.", Category: "Anthropic", From 63d2d804926057b0e4c1ddfc1644900f95f94bff Mon Sep 17 00:00:00 2001 From: mylxsw Date: Tue, 26 Dec 2023 12:42:39 +0800 Subject: [PATCH 04/13] stabilityai image-to-video sdk --- pkg/ai/stabilityai/video.go | 148 +++++++++++++++++++++++++++++++ pkg/ai/stabilityai/video_test.go | 96 ++++++++++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 pkg/ai/stabilityai/video_test.go diff --git a/pkg/ai/stabilityai/video.go b/pkg/ai/stabilityai/video.go index 927a30c..b4dc971 100644 --- a/pkg/ai/stabilityai/video.go +++ b/pkg/ai/stabilityai/video.go @@ -1,5 +1,24 @@ package stabilityai +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/hashicorp/go-uuid" + "github.com/mylxsw/aidea-server/pkg/uploader" + "github.com/mylxsw/asteria/log" + "github.com/mylxsw/go-utils/must" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "os" + "path/filepath" + "strconv" +) + type VideoRequest struct { // ImagePath The source image used in the video generation process. Please ensure that the source image is in the correct format and dimensions. // Supported Formats: @@ -43,7 +62,136 @@ type VideoResponse struct { Status string `json:"status,omitempty"` } +func (res *VideoResponse) SaveToLocalFiles(ctx context.Context, savePath string) (string, error) { + data, err := base64.StdEncoding.DecodeString(res.Video) + if err != nil { + return "", fmt.Errorf("decode base64 failed: %w", err) + } + + key := filepath.Join(savePath, fmt.Sprintf("%s.%s", must.Must(uuid.GenerateUUID()), "mp4")) + if err := os.WriteFile(key, data, os.ModePerm); err != nil { + return "", fmt.Errorf("write image to file failed: %w", err) + } + + return key, nil +} + +func (res *VideoResponse) UploadResources(ctx context.Context, up *uploader.Uploader, uid int64) (string, error) { + data, err := base64.StdEncoding.DecodeString(res.Video) + if err != nil { + return "", fmt.Errorf("decode base64 failed: %w", err) + } + + ret, err := up.UploadStream(ctx, int(uid), uploader.DefaultUploadExpireAfterDays, data, "mp4") + if err != nil { + return "", fmt.Errorf("upload image to qiniu failed: %w", err) + } + + return ret, nil +} + type VideoError struct { Name string `json:"name,omitempty"` Errors []string `json:"errors,omitempty"` } + +// ImageToVideo Generate a video from an image. +// https://platform.stability.ai/docs/api-reference#tag/v2alphageneration +func (ai *StabilityAI) ImageToVideo(ctx context.Context, imageToVideoReq VideoRequest) (*VideoTaskResponse, error) { + data := &bytes.Buffer{} + writer := multipart.NewWriter(data) + + // Write the init image to the request + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", + fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "image", "image.png")) + h.Set("Content-Type", "image/png") + + imageWriter, _ := writer.CreatePart(h) + imageFile, imageErr := os.Open(imageToVideoReq.ImagePath) + if imageErr != nil { + _ = writer.Close() + return nil, imageErr + } + + _, _ = io.Copy(imageWriter, imageFile) + + if imageToVideoReq.Seed > 0 { + _ = writer.WriteField("seed", strconv.Itoa(imageToVideoReq.Seed)) + } + + if imageToVideoReq.CfgScale > 0 { + _ = writer.WriteField("cfg_scale", strconv.Itoa(imageToVideoReq.CfgScale)) + } + + if imageToVideoReq.MotionBucketID > 0 { + _ = writer.WriteField("motion_bucket_id", strconv.Itoa(imageToVideoReq.MotionBucketID)) + } + + _ = writer.Close() + + // Execute the request + payload := bytes.NewReader(data.Bytes()) + req, _ := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/v2alpha/generation/image-to-video", ai.conf.StabilityAIServer[0]), payload) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.Header.Add("Accept", "application/json") + req.Header.Add("Authorization", "Bearer "+ai.conf.StabilityAIKey) + if ai.conf.StabilityAIOrganization != "" { + req.Header.Add("Organization", ai.conf.StabilityAIOrganization) + } + + resp, err := ai.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %v", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody := must.Must(io.ReadAll(resp.Body)) + log.F(log.M{ + "status_code": resp.StatusCode, + "status": resp.Status, + "body": string(respBody), + }).Errorf("failed to decode response body: %v", err) + + return nil, fmt.Errorf("请求失败: %v", string(respBody)) + } + + var body VideoTaskResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("failed to decode response body: %v", err) + } + + return &body, nil +} + +// ImageToVideoResult Get the result of a video generation task. +func (ai *StabilityAI) ImageToVideoResult(ctx context.Context, taskID string) (*VideoResponse, error) { + // Build the request + req, _ := http.NewRequestWithContext(ctx, "GET", ai.conf.StabilityAIServer[0]+"/v2alpha/generation/image-to-video/result/"+taskID, nil) + req.Header.Add("Authorization", "Bearer "+ai.conf.StabilityAIKey) + if ai.conf.StabilityAIOrganization != "" { + req.Header.Add("Organization", ai.conf.StabilityAIOrganization) + } + + req.Header.Add("Accept", "application/json") + + // Execute the request + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + return nil, fmt.Errorf("请求失败:%v", string(must.Must(io.ReadAll(resp.Body)))) + } + + var ret VideoResponse + if err := json.NewDecoder(resp.Body).Decode(&ret); err != nil { + return nil, fmt.Errorf("failed to decode response body: %v", err) + } + + return &ret, nil +} diff --git a/pkg/ai/stabilityai/video_test.go b/pkg/ai/stabilityai/video_test.go new file mode 100644 index 0000000..f4f9376 --- /dev/null +++ b/pkg/ai/stabilityai/video_test.go @@ -0,0 +1,96 @@ +package stabilityai_test + +import ( + "context" + "github.com/mylxsw/aidea-server/config" + "github.com/mylxsw/aidea-server/pkg/ai/stabilityai" + "net/http" + "os" + "testing" + "time" +) + +func TestStabilityAI_ImageToVideo(t *testing.T) { + conf := config.Config{ + StabilityAIServer: []string{"https://api.stability.ai"}, + StabilityAIKey: os.Getenv("STABILITY_API_KEY"), + } + + req := stabilityai.VideoRequest{ + ImagePath: "/Users/mylxsw/Downloads/94967c86e10b16829dd4d63cd16b79f5.png", + } + + st := stabilityai.NewStabilityAIWithClient(&conf, &http.Client{Timeout: 60 * time.Second}) + resp, err := st.ImageToVideo(context.TODO(), req) + if err != nil { + t.Fatal(err) + } + + t.Log(resp) + + time.Sleep(5 * time.Second) + + for { + res, err := st.ImageToVideoResult(context.TODO(), resp.ID) + if err != nil { + t.Fatal(err) + } + + t.Log(res) + + if res.Status == "in-progress" { + t.Log("in progress") + time.Sleep(5 * time.Second) + continue + } + + if res.FinishReason == "SUCCESS" { + filepath, err := res.SaveToLocalFiles(context.TODO(), "/tmp") + if err != nil { + t.Fatal(err) + } + + t.Logf("saved as %s", filepath) + break + } + + t.Log("unknown status") + time.Sleep(5 * time.Second) + } +} + +func TestStabilityAI_ImageToVideoResult(t *testing.T) { + conf := config.Config{ + StabilityAIServer: []string{"https://api.stability.ai"}, + StabilityAIKey: os.Getenv("STABILITY_API_KEY"), + } + + id := "57e47215ec64c9ff7c5f3e850e9759249ef1de1da72f6f7e20b89d4a1a527764" + + st := stabilityai.NewStabilityAIWithClient(&conf, &http.Client{Timeout: 60 * time.Second}) + for { + res, err := st.ImageToVideoResult(context.TODO(), id) + if err != nil { + t.Fatal(err) + } + + if res.Video != "" { + filepath, err := res.SaveToLocalFiles(context.TODO(), "/tmp") + if err != nil { + t.Fatal(err) + } + + t.Logf("saved as %s", filepath) + break + } + + if res.Status == "in-progress" { + t.Log("in progress") + time.Sleep(5 * time.Second) + continue + } + + t.Log("unknown status") + time.Sleep(5 * time.Second) + } +} From ac5a14057e9172d0347563e0f93ec0add082ad16 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Wed, 27 Dec 2023 00:07:38 +0800 Subject: [PATCH 05/13] =?UTF-8?q?=E5=9B=BE=E7=94=9F=E8=A7=86=E9=A2=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/coins/price.go | 14 + internal/queue/consumer/provider.go | 1 + internal/queue/image_to_video.go | 319 +++++++++++++++++++++++ internal/queue/provider.go | 20 +- pkg/ai/stabilityai/video.go | 1 + pkg/ai/stabilityai/video_test.go | 18 +- server/controllers/creative-island.go | 3 +- server/controllers/v2/creative-island.go | 225 +++++++++++----- 8 files changed, 519 insertions(+), 82 deletions(-) create mode 100644 internal/queue/image_to_video.go diff --git a/internal/coins/price.go b/internal/coins/price.go index df9f06b..0bead26 100644 --- a/internal/coins/price.go +++ b/internal/coins/price.go @@ -21,6 +21,11 @@ var coinTables = map[string]CoinTable{ "dall-e-2": 20, }, + "video": { + "default": 200, + "stability-image-to-video": 200, + }, + "openai": { // 1000 Token 计费 "gpt-3.5-turbo": 3, // valid $0.002/1K tokens -> ¥0.014/1K tokens @@ -193,6 +198,15 @@ func GetUnifiedImageGenCoins(model string) int { return int(coinTables["image"]["default"]) } +// GetUnifiedVideoGenCoins 统一的视频生成计费 +func GetUnifiedVideoGenCoins(model string) int { + if price, ok := coinTables["video"][model]; ok { + return int(price) + } + + return int(coinTables["video"]["default"]) +} + func GetTextToVoiceCoins(model string, wordCount int) int64 { if price, ok := coinTables["speech"][model]; ok { return int64(math.Ceil(float64(price) * float64(wordCount) / 1000.0)) diff --git a/internal/queue/consumer/provider.go b/internal/queue/consumer/provider.go index 5638f05..384f13a 100644 --- a/internal/queue/consumer/provider.go +++ b/internal/queue/consumer/provider.go @@ -116,6 +116,7 @@ func (p Provider) Boot(resolver infra.Resolver) { mux.HandleFunc(queue.TypeGroupChat, queue.BuildGroupChatHandler(conf, ct, rep, userSvc)) mux.HandleFunc(queue.TypeDalleCompletion, queue.BuildDalleCompletionHandler(dalleClient, uploader, rep)) mux.HandleFunc(queue.TypeArtisticTextCompletion, queue.BuildArtisticTextCompletionHandler(leptonClient, translater, uploader, rep, openaiClient)) + mux.HandleFunc(queue.TypeImageToVideoCompletion, queue.BuildImageToVideoCompletionHandler(stabaiClient, rep)) }) } diff --git a/internal/queue/image_to_video.go b/internal/queue/image_to_video.go new file mode 100644 index 0000000..036fd00 --- /dev/null +++ b/internal/queue/image_to_video.go @@ -0,0 +1,319 @@ +package queue + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/hibiken/asynq" + "github.com/mylxsw/aidea-server/internal/coins" + "github.com/mylxsw/aidea-server/pkg/ai/stabilityai" + "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/pkg/repo/model" + "github.com/mylxsw/aidea-server/pkg/uploader" + "github.com/mylxsw/asteria/log" + "os" + "time" +) + +type ImageToVideoCompletionPayload struct { + ID string `json:"id,omitempty"` + Quota int64 `json:"quota,omitempty"` + UID int64 `json:"uid,omitempty"` + + Seed int64 `json:"seed,omitempty"` + Image string `json:"image,omitempty"` + + CreatedAt time.Time `json:"created_at,omitempty"` +} + +func (payload *ImageToVideoCompletionPayload) GetTitle() string { + return "图片转视频" +} + +func (payload *ImageToVideoCompletionPayload) GetID() string { + return payload.ID +} + +func (payload *ImageToVideoCompletionPayload) SetID(id string) { + payload.ID = id +} + +func (payload *ImageToVideoCompletionPayload) GetUID() int64 { + return payload.UID +} + +func (payload *ImageToVideoCompletionPayload) GetQuota() int64 { + return payload.Quota +} + +func (payload *ImageToVideoCompletionPayload) GetModel() string { + return "stability-image-to-video" +} + +func NewImageToVideoCompletionTask(payload any) *asynq.Task { + data, _ := json.Marshal(payload) + return asynq.NewTask(TypeImageToVideoCompletion, data) +} + +type ImageToVideoPendingTaskPayload struct { + TaskID string `json:"task_id,omitempty"` + Payload ImageToVideoCompletionPayload `json:"payload,omitempty"` +} + +func (p ImageToVideoPendingTaskPayload) GetImage() string { + return p.Payload.Image +} + +func (p ImageToVideoPendingTaskPayload) GetID() string { + return p.Payload.GetID() +} + +func (p ImageToVideoPendingTaskPayload) GetUID() int64 { + return p.Payload.UID +} + +func (p ImageToVideoPendingTaskPayload) GetQuota() int64 { + return p.Payload.Quota +} + +func (p ImageToVideoPendingTaskPayload) GetModel() string { + return p.Payload.GetModel() +} + +type ImageToVideoResponse interface { + GetID() string + GetState() string + IsFinished() bool + IsProcessing() bool + UploadResources(ctx context.Context, up *uploader.Uploader, uid int64) ([]string, error) + GetImages() []string +} + +func BuildImageToVideoCompletionHandler( + client *stabilityai.StabilityAI, + rep *repo.Repository, +) TaskHandler { + return func(ctx context.Context, task *asynq.Task) (err error) { + var payload ImageToVideoCompletionPayload + if err := json.Unmarshal(task.Payload(), &payload); err != nil { + return err + } + + if payload.CreatedAt.Add(5 * time.Minute).Before(time.Now()) { + rep.Queue.Update(context.TODO(), payload.GetID(), repo.QueueTaskStatusFailed, ErrorResult{Errors: []string{"任务处理超时"}}) + log.WithFields(log.Fields{"payload": payload}).Errorf("task expired") + return nil + } + + defer func() { + if err2 := recover(); err2 != nil { + log.With(task).Errorf("panic: %v", err2) + err = err2.(error) + + // 更新创作岛历史记录 + if err := rep.Creative.UpdateRecordByTaskID(ctx, payload.GetUID(), payload.GetID(), repo.CreativeRecordUpdateRequest{ + Answer: err.Error(), + Status: repo.CreativeStatusFailed, + }); err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) + } + } + + if err != nil { + if err := rep.Queue.Update( + context.TODO(), + payload.GetID(), + repo.QueueTaskStatusFailed, + ErrorResult{ + Errors: []string{err.Error()}, + }, + ); err != nil { + log.With(task).Errorf("update queue status failed: %s", err) + } + } + }() + + targetImage, err := uploader.DownloadRemoteFile(ctx, payload.Image) + if err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("download remote file failed: %s", err) + panic(err) + } + defer os.Remove(targetImage) + + req := stabilityai.VideoRequest{ + ImagePath: targetImage, + Seed: int(payload.Seed), + } + resp, err := client.ImageToVideo(ctx, req) + if err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("create task failed: %s", err) + panic(err) + } + + if err := rep.Queue.CreatePendingTask(ctx, &repo.PendingTask{ + TaskID: payload.GetID(), + TaskType: TypeImageToVideoCompletion, + NextExecuteAt: time.Now().Add(time.Duration(30) * time.Second), + DeadlineAt: time.Now().Add(30 * time.Minute), + Status: repo.PendingTaskStatusProcessing, + Payload: ImageToVideoPendingTaskPayload{TaskID: resp.ID, Payload: payload}, + }); err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("create pending task failed: %s", err) + panic(err) + } + + return rep.Queue.Update( + context.TODO(), + payload.GetID(), + repo.QueueTaskStatusRunning, + nil, + ) + } +} + +func imageToVideoJobProcesser(que *Queue, client *stabilityai.StabilityAI, up *uploader.Uploader, rep *repo.Repository) PendingTaskHandler { + return func(task *model.QueueTasksPending) (update *repo.PendingTaskUpdate, err error) { + var payload ImageToVideoPendingTaskPayload + if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil { + return nil, err + } + + taskRes, err := client.ImageToVideoResult(context.TODO(), payload.TaskID) + if err != nil { + log.With(payload).Errorf("query fromston job result failed: %v", err) + return &repo.PendingTaskUpdate{ + NextExecuteAt: time.Now().Add(5 * time.Second), + Status: repo.PendingTaskStatusProcessing, + ExecuteTimes: task.ExecuteTimes + 1, + }, nil + } + + defer func() { + if err2 := recover(); err2 != nil { + log.With(task).Errorf("panic: %v", err2) + err = err2.(error) + + // 更新创作岛历史记录 + if err := rep.Creative.UpdateRecordByTaskID(context.TODO(), payload.Payload.GetUID(), payload.Payload.GetID(), repo.CreativeRecordUpdateRequest{ + Answer: err.Error(), + Status: repo.CreativeStatusFailed, + }); err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) + } + + update = &repo.PendingTaskUpdate{Status: repo.PendingTaskStatusFailed} + } + + if err != nil { + if err := rep.Queue.Update( + context.TODO(), + payload.Payload.GetID(), + repo.QueueTaskStatusFailed, + ErrorResult{ + Errors: []string{err.Error()}, + }, + ); err != nil { + log.With(task).Errorf("update queue status failed: %s", err) + } + } + }() + + if taskRes.Video == "" { + if taskRes.Status == "in-progress" { + return &repo.PendingTaskUpdate{ + NextExecuteAt: time.Now().Add(5 * time.Second), + Status: repo.PendingTaskStatusProcessing, + ExecuteTimes: task.ExecuteTimes + 1, + }, nil + } + + log.WithFields(log.Fields{"payload": payload, "res": taskRes}).Errorf("no success task found") + panic(errors.New("no success task found")) + } + + // 任务已经完成,开始处理结果 + // 更新创作岛历史记录 + if err := handleImageToVideoTask(que, payload, taskRes, up, rep); err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) + return nil, err + } + + return &repo.PendingTaskUpdate{Status: repo.PendingTaskStatusSuccess}, nil + } +} + +type ImageToVideoTaskPayload interface { + GetID() string + GetUID() int64 + GetQuota() int64 + GetModel() string + GetImage() string +} + +func handleImageToVideoTask( + que *Queue, + payload ImageToVideoTaskPayload, + tasks *stabilityai.VideoResponse, + up *uploader.Uploader, + rep *repo.Repository, +) error { + videoURL, err := tasks.UploadResources(context.TODO(), up, payload.GetUID()) + if err != nil { + return fmt.Errorf("upload resources failed: %s", err) + } + + resources := make([]string, 0) + resources = append(resources, videoURL) + + if len(resources) == 0 { + log.WithFields(log.Fields{ + "payload": payload, + }).Errorf("没有生成任何视频") + panic(errors.New("没有生成任何视频")) + } + + // 更新创作岛历史记录状态,写入生成的资源地址 + retJson, err := json.Marshal(resources) + if err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) + panic(err) + } + + // 重新计算配额消耗,以实际发生计算 + quotaConsumed := int64(coins.GetUnifiedVideoGenCoins(payload.GetModel()) * len(resources)) + + req := repo.CreativeRecordUpdateRequest{ + Answer: string(retJson), + QuotaUsed: quotaConsumed, + Status: repo.CreativeStatusSuccess, + } + if err := rep.Creative.UpdateRecordByTaskID(context.TODO(), payload.GetUID(), payload.GetID(), req); err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) + panic(err) + } + + // 更新用户配额 + modelUsed := []string{payload.GetModel(), "upload"} + if err := rep.Quota.QuotaConsume( + context.TODO(), + payload.GetUID(), + payload.GetQuota(), + repo.NewQuotaUsedMeta(payload.GetModel(), modelUsed...), + ); err != nil { + log.Errorf("used quota add failed: %s", err) + return err + } + + // 更新队列任务状态 + return rep.Queue.Update( + context.TODO(), + payload.GetID(), + repo.QueueTaskStatusSuccess, + CompletionResult{ + OriginImage: payload.GetImage(), + Resources: resources, + ValidBefore: time.Now().Add(7 * 24 * time.Hour), + }, + ) +} diff --git a/internal/queue/provider.go b/internal/queue/provider.go index 169d05e..6bbf671 100644 --- a/internal/queue/provider.go +++ b/internal/queue/provider.go @@ -6,7 +6,8 @@ import ( "github.com/mylxsw/aidea-server/pkg/ai/dashscope" "github.com/mylxsw/aidea-server/pkg/ai/fromston" "github.com/mylxsw/aidea-server/pkg/ai/leap" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/pkg/ai/stabilityai" + "github.com/mylxsw/aidea-server/pkg/repo" "github.com/mylxsw/aidea-server/pkg/service" "github.com/mylxsw/aidea-server/pkg/uploader" "time" @@ -40,9 +41,10 @@ func (Provider) Boot(app infra.Resolver) { leapClient *leap.LeapAI, fromstonClient *fromston.Fromston, dashscopeClient *dashscope.DashScope, + stabilityClient *stabilityai.StabilityAI, up *uploader.Uploader, queue *Queue, - rep *repo2.Repository, + rep *repo.Repository, userSvc *service.UserService, rds *redis.Client, ) { @@ -50,11 +52,12 @@ func (Provider) Boot(app infra.Resolver) { manager.Register(TypeLeapAICompletion, leapAsyncJobProcesser(leapClient, up, rep)) manager.Register(TypeFromStonCompletion, fromStonAsyncJobProcesser(queue, fromstonClient, up, rep)) manager.Register(TypeDashscopeImageCompletion, dashscopeImageAsyncJobProcesser(queue, dashscopeClient, up, rep)) + manager.Register(TypeImageToVideoCompletion, imageToVideoJobProcesser(queue, stabilityClient, up, rep)) // 注册创作岛更新后,自动释放冻结的智慧果任务 - rep.Creative.RegisterRecordStatusUpdateCallback(func(taskID string, userID int64, status repo2.CreativeStatus) { + rep.Creative.RegisterRecordStatusUpdateCallback(func(taskID string, userID int64, status repo.CreativeStatus) { key := fmt.Sprintf("creative-island:%d:task:%s:quota-freeze", userID, taskID) - if status == repo2.CreativeStatusSuccess || status == repo2.CreativeStatusFailed { + if status == repo.CreativeStatusSuccess || status == repo.CreativeStatusFailed { freezedValue, err := rds.Get(context.TODO(), key).Int64() if err != nil { log.F(log.M{"task_id": taskID, "user_id": userID, "status": status}).Errorf("获取创作岛任务冻结的智慧果数量失败:%s", err) @@ -92,6 +95,7 @@ const ( TypeBindPhone = "bind_phone" TypeGroupChat = "group_chat" TypeArtisticTextCompletion = "artistic_text:completion" + TypeImageToVideoCompletion = "image_to_video:completion" ) func ResolveTaskType(category, model string) string { @@ -101,6 +105,10 @@ func ResolveTaskType(category, model string) string { case "deepai": return TypeDeepAICompletion case "stabilityai": + if model == "stability-image-to-video" { + return TypeImageToVideoCompletion + } + return TypeStabilityAICompletion case "fromston": return TypeFromStonCompletion @@ -148,11 +156,11 @@ type Payload interface { // Queue 任务队列 type Queue struct { client *asynq.Client - queueRepo *repo2.QueueRepo + queueRepo *repo.QueueRepo } // NewQueue 创建一个任务队列 -func NewQueue(client *asynq.Client, queueRepo *repo2.QueueRepo) *Queue { +func NewQueue(client *asynq.Client, queueRepo *repo.QueueRepo) *Queue { return &Queue{client: client, queueRepo: queueRepo} } diff --git a/pkg/ai/stabilityai/video.go b/pkg/ai/stabilityai/video.go index b4dc971..7a92928 100644 --- a/pkg/ai/stabilityai/video.go +++ b/pkg/ai/stabilityai/video.go @@ -108,6 +108,7 @@ func (ai *StabilityAI) ImageToVideo(ctx context.Context, imageToVideoReq VideoRe h.Set("Content-Type", "image/png") imageWriter, _ := writer.CreatePart(h) + imageFile, imageErr := os.Open(imageToVideoReq.ImagePath) if imageErr != nil { _ = writer.Close() diff --git a/pkg/ai/stabilityai/video_test.go b/pkg/ai/stabilityai/video_test.go index f4f9376..7f4a6da 100644 --- a/pkg/ai/stabilityai/video_test.go +++ b/pkg/ai/stabilityai/video_test.go @@ -17,7 +17,7 @@ func TestStabilityAI_ImageToVideo(t *testing.T) { } req := stabilityai.VideoRequest{ - ImagePath: "/Users/mylxsw/Downloads/94967c86e10b16829dd4d63cd16b79f5.png", + ImagePath: "/Users/mylxsw/Downloads/IMG_8649.png", } st := stabilityai.NewStabilityAIWithClient(&conf, &http.Client{Timeout: 60 * time.Second}) @@ -36,15 +36,7 @@ func TestStabilityAI_ImageToVideo(t *testing.T) { t.Fatal(err) } - t.Log(res) - - if res.Status == "in-progress" { - t.Log("in progress") - time.Sleep(5 * time.Second) - continue - } - - if res.FinishReason == "SUCCESS" { + if res.Video != "" { filepath, err := res.SaveToLocalFiles(context.TODO(), "/tmp") if err != nil { t.Fatal(err) @@ -54,6 +46,12 @@ func TestStabilityAI_ImageToVideo(t *testing.T) { break } + if res.Status == "in-progress" { + t.Log("in progress") + time.Sleep(5 * time.Second) + continue + } + t.Log("unknown status") time.Sleep(5 * time.Second) } diff --git a/server/controllers/creative-island.go b/server/controllers/creative-island.go index 903b1ad..d685828 100644 --- a/server/controllers/creative-island.go +++ b/server/controllers/creative-island.go @@ -3,6 +3,7 @@ package controllers import ( "context" "encoding/json" + "errors" "fmt" openaiHelper "github.com/mylxsw/aidea-server/pkg/ai/openai" "github.com/mylxsw/aidea-server/pkg/misc" @@ -297,7 +298,7 @@ func (ctl *CreativeIslandController) Item(ctx context.Context, webCtx web.Contex id := webCtx.PathVar("id") island, err := ctl.creativeRepo.Island(ctx, id) if err != nil { - if err == repo2.ErrNotFound { + if errors.Is(err, repo2.ErrNotFound) { return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrNotFound), http.StatusNotFound) } diff --git a/server/controllers/v2/creative-island.go b/server/controllers/v2/creative-island.go index d7379e8..b446f02 100644 --- a/server/controllers/v2/creative-island.go +++ b/server/controllers/v2/creative-island.go @@ -6,8 +6,8 @@ import ( "errors" "fmt" "github.com/mylxsw/aidea-server/pkg/misc" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" - service2 "github.com/mylxsw/aidea-server/pkg/service" + "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/pkg/service" "github.com/mylxsw/aidea-server/pkg/uploader" "github.com/mylxsw/aidea-server/pkg/youdao" "math/rand" @@ -39,13 +39,13 @@ const ( // CreativeIslandController 创作岛 type CreativeIslandController struct { conf *config.Config - quotaRepo *repo2.QuotaRepo `autowire:"@"` - queue *queue.Queue `autowire:"@"` - trans youdao.Translater `autowire:"@"` - creativeRepo *repo2.CreativeRepo `autowire:"@"` - securitySrv *service2.SecurityService `autowire:"@"` - userSvc *service2.UserService `autowire:"@"` - rds *redis.Client `autowire:"@"` + quotaRepo *repo.QuotaRepo `autowire:"@"` + queue *queue.Queue `autowire:"@"` + trans youdao.Translater `autowire:"@"` + creativeRepo *repo.CreativeRepo `autowire:"@"` + securitySrv *service.SecurityService `autowire:"@"` + userSvc *service.UserService `autowire:"@"` + rds *redis.Client `autowire:"@"` } // NewCreativeIslandController create a new CreativeIslandController @@ -77,7 +77,8 @@ func (ctl *CreativeIslandController) Register(router web.Router) { // 文生图、图生图 router.Post("/", ctl.Completions) router.Post("/evaluate", ctl.CompletionsEvaluate) - + // 图生视频 + router.Post("/image-to-video", ctl.ImageToVideo) // 图片放大 router.Post("/upscale", ctl.ImageUpscale) // 图片上色 @@ -116,6 +117,18 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte }, } + if client != nil && misc.VersionNewer(client.Version, "1.0.9") && ctl.conf.EnableStabilityAI { + items = append(items, CreativeIslandItem{ + ID: "image-to-video", + Title: "图生视频", + TitleColor: "FFFFFFFF", + PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-to-video-dark.jpg-thumb1000", + RouteURI: "/creative-draw/create-video", + Note: "每次生成视频将消耗 200 智慧果", + Size: SizeLarge, + }) + } + if client != nil && misc.VersionNewer(client.Version, "1.0.8") && ctl.conf.EnableLeptonAI { items = append(items, CreativeIslandItem{ ID: "artistic-text", @@ -206,13 +219,13 @@ func (ctl *CreativeIslandController) Models(ctx context.Context, webCtx web.Cont // loadAllModels 加载所有的模型 // TODO 加缓存 -func (ctl *CreativeIslandController) loadAllModels(ctx context.Context) []repo2.ImageModel { +func (ctl *CreativeIslandController) loadAllModels(ctx context.Context) []repo.ImageModel { models, err := ctl.creativeRepo.Models(ctx) if err != nil { log.Errorf("get models failed: %v", err) } - return array.Filter(models, func(m repo2.ImageModel, _ int) bool { + return array.Filter(models, func(m repo.ImageModel, _ int) bool { if m.Vendor == "leapai" { return ctl.conf.EnableLeapAI } @@ -252,7 +265,7 @@ func (ctl *CreativeIslandController) ImageStyles(ctx context.Context, webCtx web // 查询所有可用的模型,转换为 map[模型ID]模型ID availableModels := array.ToMap( - array.Map(ctl.loadAllModels(ctx), func(item repo2.ImageModel, _ int) string { + array.Map(ctl.loadAllModels(ctx), func(item repo.ImageModel, _ int) string { return item.ModelId }), func(val string, _ int) string { @@ -261,7 +274,7 @@ func (ctl *CreativeIslandController) ImageStyles(ctx context.Context, webCtx web ) // 过滤掉当前没有启用的模型 - filters = array.Filter(filters, func(item repo2.ImageFilter, _ int) bool { + filters = array.Filter(filters, func(item repo.ImageFilter, _ int) bool { _, ok := availableModels[item.ModelId] return ok }) @@ -361,7 +374,7 @@ func (ctl *CreativeIslandController) ShareHistoryItem(ctx context.Context, webCt err := ctl.creativeRepo.ShareCreativeHistoryToGallery(ctx, user.ID, user.Name, int64(hid)) if err != nil { - if errors.Is(err, repo2.ErrNotFound) { + if errors.Is(err, repo.ErrNotFound) { return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrNotFound), http.StatusNotFound) } @@ -411,7 +424,7 @@ func (ctl *CreativeIslandController) Histories(ctx context.Context, webCtx web.C perPage = 20 } - items, meta, err := ctl.creativeRepo.HistoryRecordPaginate(ctx, user.ID, repo2.CreativeHistoryQuery{ + items, meta, err := ctl.creativeRepo.HistoryRecordPaginate(ctx, user.ID, repo.CreativeHistoryQuery{ Page: page, PerPage: perPage, IslandId: AllInOneIslandID, @@ -423,7 +436,7 @@ func (ctl *CreativeIslandController) Histories(ctx context.Context, webCtx web.C } // 以下字段不需要返回给前端 - items = array.Map(items, func(item repo2.CreativeHistoryItem, _ int) repo2.CreativeHistoryItem { + items = array.Map(items, func(item repo.CreativeHistoryItem, _ int) repo.CreativeHistoryItem { // Arguments 只保留必须的 image 字段,用于客户端区分是文生图还是图生图 var arguments map[string]any _ = json.Unmarshal([]byte(item.Arguments), &arguments) @@ -441,7 +454,7 @@ func (ctl *CreativeIslandController) Histories(ctx context.Context, webCtx web.C item.QuotaUsed = 0 switch item.IslandType { - case int64(repo2.IslandTypeImage): + case int64(repo.IslandTypeImage): if arguments != nil { if _, ok := arguments["image"]; ok { item.IslandTitle = "图生图" @@ -451,15 +464,15 @@ func (ctl *CreativeIslandController) Histories(ctx context.Context, webCtx web.C if item.IslandTitle == "" { item.IslandTitle = "文生图" } - case int64(repo2.IslandTypeUpscale): + case int64(repo.IslandTypeUpscale): item.IslandTitle = "高清修复" - case int64(repo2.IslandTypeImageColorization): + case int64(repo.IslandTypeImageColorization): item.IslandTitle = "图片上色" } // 客户端目前不支持封禁状态展示,这里转换为失败 - if item.Status == int64(repo2.CreativeStatusForbid) { - item.Status = int64(repo2.CreativeStatusFailed) + if item.Status == int64(repo.CreativeStatusForbid) { + item.Status = int64(repo.CreativeStatusFailed) } return item @@ -483,7 +496,7 @@ func (ctl *CreativeIslandController) Histories(ctx context.Context, webCtx web.C } type CreativeHistoryItemResp struct { - repo2.CreativeHistoryItem + repo.CreativeHistoryItem ShowBetaFeature bool `json:"show_beta_feature,omitempty"` } @@ -501,7 +514,7 @@ func (ctl *CreativeIslandController) HistoryItem(ctx context.Context, webCtx web item, err := ctl.creativeRepo.FindHistoryRecord(ctx, userId, int64(hid)) if err != nil { - if errors.Is(err, repo2.ErrNotFound) { + if errors.Is(err, repo.ErrNotFound) { return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrNotFound), http.StatusNotFound) } @@ -510,8 +523,8 @@ func (ctl *CreativeIslandController) HistoryItem(ctx context.Context, webCtx web } // 客户端目前不支持封禁状态展示,这里转换为失败 - if item.Status == int64(repo2.CreativeStatusForbid) { - item.Status = int64(repo2.CreativeStatusFailed) + if item.Status == int64(repo.CreativeStatusForbid) { + item.Status = int64(repo.CreativeStatusFailed) } return webCtx.JSON(CreativeHistoryItemResp{ @@ -716,7 +729,7 @@ func (ctl *CreativeIslandController) resolveImageCompletionRequest(ctx context.C } func (ctl *CreativeIslandController) getAllModels(ctx context.Context) []VendorModel { - return array.Map(ctl.loadAllModels(ctx), func(m repo2.ImageModel, _ int) VendorModel { + return array.Map(ctl.loadAllModels(ctx), func(m repo.ImageModel, _ int) VendorModel { return VendorModel{ ID: m.ModelId, Name: m.ModelName, @@ -769,7 +782,7 @@ func (ctl *CreativeIslandController) getAllImageStyles(ctx context.Context) []Im return []ImageStyle{} } - return array.Map(filters, func(f repo2.ImageFilter, _ int) ImageStyle { + return array.Map(filters, func(f repo.ImageFilter, _ int) ImageStyle { return ImageStyle{ ID: f.Id, Name: f.Name, @@ -822,40 +835,40 @@ func (ctl *CreativeIslandController) getStyleByModelID(ctx context.Context, mode } type VendorModel struct { - ID string `json:"id"` - Name string `json:"name"` - Vendor string `json:"vendor,omitempty"` - Model string `json:"-"` - Enabled bool `json:"-"` - Upscale bool `json:"upscale,omitempty"` - ShowStyle bool `json:"show_style,omitempty"` - ShowImageStrength bool `json:"show_image_strength,omitempty"` - IntroURL string `json:"intro_url,omitempty"` - RatioDimensions map[string]repo2.Dimension `json:"-"` + ID string `json:"id"` + Name string `json:"name"` + Vendor string `json:"vendor,omitempty"` + Model string `json:"-"` + Enabled bool `json:"-"` + Upscale bool `json:"upscale,omitempty"` + ShowStyle bool `json:"show_style,omitempty"` + ShowImageStrength bool `json:"show_image_strength,omitempty"` + IntroURL string `json:"intro_url,omitempty"` + RatioDimensions map[string]repo.Dimension `json:"-"` } -func (vm VendorModel) defaultDimension(ratio string) repo2.Dimension { +func (vm VendorModel) defaultDimension(ratio string) repo.Dimension { switch ratio { case "1:1": - return repo2.Dimension{Width: 512, Height: 512} + return repo.Dimension{Width: 512, Height: 512} case "4:3": - return repo2.Dimension{Width: 768, Height: 576} + return repo.Dimension{Width: 768, Height: 576} case "3:4": - return repo2.Dimension{Width: 576, Height: 768} + return repo.Dimension{Width: 576, Height: 768} case "3:2": - return repo2.Dimension{Width: 768, Height: 512} + return repo.Dimension{Width: 768, Height: 512} case "2:3": - return repo2.Dimension{Width: 512, Height: 768} + return repo.Dimension{Width: 512, Height: 768} case "16:9": - return repo2.Dimension{Width: 1024, Height: 576} + return repo.Dimension{Width: 1024, Height: 576} } - return repo2.Dimension{Width: 512, Height: 512} + return repo.Dimension{Width: 512, Height: 512} } -func (vm VendorModel) GetDimension(ratio string) repo2.Dimension { +func (vm VendorModel) GetDimension(ratio string) repo.Dimension { if vm.RatioDimensions == nil { - vm.RatioDimensions = map[string]repo2.Dimension{} + vm.RatioDimensions = map[string]repo.Dimension{} } dimension, ok := vm.RatioDimensions[ratio] @@ -927,14 +940,14 @@ func (ctl *CreativeIslandController) ImageUpscale(ctx context.Context, webCtx we log.F(log.M{"user_id": user.ID, "quota": req.Quota, "task_id": taskID}).Errorf("创作岛用户配额已冻结,更新 Redis 任务与配额关系失败: %s", err) } - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: AllInOneIslandID, - IslandType: repo2.IslandTypeUpscale, + IslandType: repo.IslandTypeUpscale, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } - arg := repo2.CreativeRecordArguments{ + arg := repo.CreativeRecordArguments{ Image: image, UpscaleBy: upscaleBy, } @@ -998,14 +1011,14 @@ func (ctl *CreativeIslandController) ImageColorize(ctx context.Context, webCtx w log.F(log.M{"user_id": user.ID, "quota": req.Quota, "task_id": taskID}).Errorf("创作岛用户配额已冻结,更新 Redis 任务与配额关系失败: %s", err) } - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: AllInOneIslandID, - IslandType: repo2.IslandTypeImageColorization, + IslandType: repo.IslandTypeImageColorization, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } - arg := repo2.CreativeRecordArguments{ + arg := repo.CreativeRecordArguments{ Image: image, } @@ -1134,15 +1147,15 @@ func (ctl *CreativeIslandController) ArtisticText(ctx context.Context, webCtx we log.F(log.M{"user_id": user.ID, "quota": req.Quota, "task_id": taskID}).Errorf("创作岛用户配额已冻结,更新 Redis 任务与配额关系失败: %s", err) } - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: AllInOneIslandID, - IslandType: repo2.IslandTypeArtisticText, + IslandType: repo.IslandTypeArtisticText, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, Prompt: prompt, } - arg := repo2.CreativeRecordArguments{ + arg := repo.CreativeRecordArguments{ NegativePrompt: negativePrompt, ArtisticType: optType, StylePreset: stylePreset, @@ -1232,16 +1245,16 @@ func (ctl *CreativeIslandController) Completions(ctx context.Context, webCtx web } // buildHistorySaveRecord 构建保存历史记录的 CreativeItem -func (*CreativeIslandController) buildHistorySaveRecord(req *queue.ImageCompletionPayload, taskID string) (repo2.CreativeItem, repo2.CreativeRecordArguments) { - creativeItem := repo2.CreativeItem{ +func (*CreativeIslandController) buildHistorySaveRecord(req *queue.ImageCompletionPayload, taskID string) (repo.CreativeItem, repo.CreativeRecordArguments) { + creativeItem := repo.CreativeItem{ IslandId: AllInOneIslandID, - IslandType: repo2.IslandTypeImage, + IslandType: repo.IslandTypeImage, IslandModel: req.Model, Prompt: req.Prompt, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } - return creativeItem, repo2.CreativeRecordArguments{ + return creativeItem, repo.CreativeRecordArguments{ NegativePrompt: req.NegativePrompt, PromptTags: req.PromptTags, Width: req.Width, @@ -1262,3 +1275,85 @@ func (*CreativeIslandController) buildHistorySaveRecord(req *queue.ImageCompleti Seed: req.Seed, } } + +// ImageToVideo 图片生成视频 +// 请求参数: +// - image 图片上传后的地址 +// - seed 随机种子 +func (ctl *CreativeIslandController) ImageToVideo(ctx context.Context, webCtx web.Context, user *auth.User) web.Response { + image := webCtx.Input("image") + if image != "" && !str.HasPrefixes(image, []string{"http://", "https://"}) { + return webCtx.JSONError("invalid image", http.StatusBadRequest) + } + + // 图片地址检查 + if !strings.HasPrefix(image, ctl.conf.StorageDomain) { + return webCtx.JSONError("invalid image", http.StatusBadRequest) + } + + image = uploader.BuildImageURLWithFilter(image, "resize1024x576", ctl.conf.StorageDomain) + + // 检查用户是否有足够的智慧果 + quota, err := ctl.userSvc.UserQuota(ctx, user.ID) + if err != nil { + log.Errorf("get user quota failed: %s", err) + return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrInternalError), http.StatusInternalServerError) + } + + quotaConsume := int64(coins.GetUnifiedVideoGenCoins("")) + if quota.Rest-quota.Freezed < quotaConsume { + return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrQuotaNotEnough), http.StatusPaymentRequired) + } + + seed := webCtx.Int64Input("seed", -1) + if seed < 0 || seed > 2147483647 { + seed = -1 + } + + req := queue.ImageToVideoCompletionPayload{ + Quota: quotaConsume, + CreatedAt: time.Now(), + Image: image, + UID: user.ID, + Seed: seed, + } + + // 加入异步任务队列 + taskID, err := ctl.queue.Enqueue(&req, queue.NewImageToVideoCompletionTask) + if err != nil { + log.Errorf("enqueue task failed: %s", err) + return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrInternalError), http.StatusInternalServerError) + } + log.WithFields(log.Fields{"task_id": taskID}).Debugf("enqueue task success: %s", taskID) + + // 冻结智慧果 + if err := ctl.userSvc.FreezeUserQuota(ctx, user.ID, req.Quota); err != nil { + log.F(log.M{"user_id": user.ID, "quota": req.Quota, "task_id": taskID}).Errorf("创作岛冻结用户配额失败: %s", err) + } + + if err := ctl.rds.SetEx(ctx, fmt.Sprintf("creative-island:%d:task:%s:quota-freeze", user.ID, taskID), req.Quota, 5*time.Minute).Err(); err != nil { + log.F(log.M{"user_id": user.ID, "quota": req.Quota, "task_id": taskID}).Errorf("创作岛用户配额已冻结,更新 Redis 任务与配额关系失败: %s", err) + } + + creativeItem := repo.CreativeItem{ + IslandId: AllInOneIslandID, + IslandType: repo.IslandTypeVideo, + TaskId: taskID, + Status: repo.CreativeStatusPending, + } + + arg := repo.CreativeRecordArguments{ + Image: image, + } + + // 保存历史记录 + if _, err := ctl.creativeRepo.CreateRecordWithArguments(ctx, user.ID, &creativeItem, &arg); err != nil { + log.Errorf("create creative item failed: %v", err) + return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrInternalError), http.StatusInternalServerError) + } + + return webCtx.JSON(web.M{ + "task_id": taskID, // 任务 ID + "wait": 30, // 等待时间 + }) +} From e430520b940be3bc6c74a21ea6272a0ab0fe68af Mon Sep 17 00:00:00 2001 From: mylxsw Date: Wed, 27 Dec 2023 00:13:35 +0800 Subject: [PATCH 06/13] update --- internal/queue/payment.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/internal/queue/payment.go b/internal/queue/payment.go index 6816ae3..c8d3d26 100644 --- a/internal/queue/payment.go +++ b/internal/queue/payment.go @@ -3,10 +3,11 @@ package queue import ( "context" "encoding/json" + "errors" "fmt" "github.com/mylxsw/aidea-server/pkg/dingding" "github.com/mylxsw/aidea-server/pkg/mail" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/pkg/repo" "time" "github.com/hibiken/asynq" @@ -57,7 +58,7 @@ func NewPaymentTask(payload any) *asynq.Task { } func BuildPaymentHandler( - rep *repo2.Repository, + rep *repo.Repository, mailer *mail.Sender, que *Queue, ding *dingding.Dingding, @@ -78,7 +79,7 @@ func BuildPaymentHandler( if err := rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusFailed, + repo.QueueTaskStatusFailed, ErrorResult{ Errors: []string{err.Error()}, }, @@ -86,7 +87,7 @@ func BuildPaymentHandler( log.With(task).Errorf("update queue status failed: %s", err) } - if err := rep.Event.UpdateEvent(ctx, payload.EventID, repo2.EventStatusFailed); err != nil { + if err := rep.Event.UpdateEvent(ctx, payload.EventID, repo.EventStatusFailed); err != nil { log.WithFields(log.Fields{"event_id": payload.EventID}).Errorf("update event status failed: %s", err) } } @@ -95,7 +96,7 @@ func BuildPaymentHandler( // 查询事件记录 event, err := rep.Event.GetEvent(ctx, payload.EventID) if err != nil { - if err == repo2.ErrNotFound { + if errors.Is(err, repo.ErrNotFound) { log.WithFields(log.Fields{"event_id": payload.EventID}).Errorf("event not found") return nil } @@ -104,12 +105,12 @@ func BuildPaymentHandler( return err } - if event.Status != repo2.EventStatusWaiting { + if event.Status != repo.EventStatusWaiting { log.WithFields(log.Fields{"event_id": payload.EventID}).Warningf("event status is not waiting") return nil } - if event.EventType != repo2.EventTypePaymentCompleted { + if event.EventType != repo.EventTypePaymentCompleted { log.With(payload).Errorf("event type is not payment_completed") return nil } @@ -127,7 +128,7 @@ func BuildPaymentHandler( return err } - if err := rep.Event.UpdateEvent(ctx, payload.EventID, repo2.EventStatusSucceed); err != nil { + if err := rep.Event.UpdateEvent(ctx, payload.EventID, repo.EventStatusSucceed); err != nil { log.WithFields(log.Fields{"event_id": payload.EventID}).Errorf("update event status failed: %s", err) } @@ -136,7 +137,7 @@ func BuildPaymentHandler( mailPayload := &MailPayload{ To: []string{payload.Email}, Subject: "充值已到账", - Body: fmt.Sprintf("您充值的 %d 个智慧果已到账,有效期至 %s,请尽快使用。", product.Quota, repo2.TimeInDate(expiredAt).Format(time.RFC3339)), + Body: fmt.Sprintf("您充值的 %d 个智慧果已到账,有效期至 %s,请尽快使用。", product.Quota, repo.TimeInDate(expiredAt).Format(time.RFC3339)), CreatedAt: time.Now(), } @@ -148,7 +149,7 @@ func BuildPaymentHandler( // 邀请人奖励 user, err := rep.User.GetUserByID(ctx, payload.UserID) if err != nil { - if err != repo2.ErrUserAccountDisabled { + if !errors.Is(err, repo.ErrUserAccountDisabled) { log.WithFields(log.Fields{"user_id": payload.UserID}).Errorf("引荐人奖励,查询用户信息失败: %s", err) } } else { @@ -168,11 +169,11 @@ func BuildPaymentHandler( `用户(ID:%d)充值了 %d 个智慧果,有效期至 %s,充值订单号为 %s,充值来源为 %s。`, payload.UserID, product.Quota, - repo2.TimeInDate(expiredAt).Format("2006-01-02"), + repo.TimeInDate(expiredAt).Format("2006-01-02"), payload.PaymentID, payload.Source, ) - if err := ding.Send(dingding.NewMarkdownMessage(payload.Env+": 有用户充值啦", content, []string{})); err != nil { + if err := ding.Send(dingding.NewMarkdownMessage(fmt.Sprintf("%s: 充值 %d 个智慧果", payload.Env, product.Quota), content, []string{})); err != nil { log.Errorf("发送钉钉通知失败: %s", err) } }() @@ -180,7 +181,7 @@ func BuildPaymentHandler( return rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusSuccess, + repo.QueueTaskStatusSuccess, EmptyResult{}, ) } From 94650a96278442661cf6b7b24e2327fe77c04e09 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Wed, 27 Dec 2023 12:29:33 +0800 Subject: [PATCH 07/13] =?UTF-8?q?=E8=A7=86=E9=A2=91=E5=90=88=E6=88=90?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/coins/price.go | 12 + pkg/repo/creative.go | 337 ++++++++++++----------- server/controllers/creative-island.go | 102 ++++--- server/controllers/v2/creative-island.go | 28 +- 4 files changed, 270 insertions(+), 209 deletions(-) diff --git a/internal/coins/price.go b/internal/coins/price.go index 0bead26..4a76eef 100644 --- a/internal/coins/price.go +++ b/internal/coins/price.go @@ -198,6 +198,18 @@ func GetUnifiedImageGenCoins(model string) int { return int(coinTables["image"]["default"]) } +// GetImageGenCoinsExcept 获取除了指定价格的所有图片生成模型 +func GetImageGenCoinsExcept(coins int64) map[string]int64 { + coinsTable := make(map[string]int64) + for model, price := range coinTables["image"] { + if price != coins { + coinsTable[model] = price + } + } + + return coinsTable +} + // GetUnifiedVideoGenCoins 统一的视频生成计费 func GetUnifiedVideoGenCoins(model string) int { if price, ok := coinTables["video"][model]; ok { diff --git a/pkg/repo/creative.go b/pkg/repo/creative.go index 9f346e3..40f156e 100644 --- a/pkg/repo/creative.go +++ b/pkg/repo/creative.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "github.com/mylxsw/aidea-server/pkg/misc" - model2 "github.com/mylxsw/aidea-server/pkg/repo/model" + "github.com/mylxsw/aidea-server/pkg/repo/model" "strings" "time" @@ -177,7 +177,7 @@ type CreativeIsland struct { Extension CreativeIslandExt `json:"extension,omitempty"` } -func buildCreativeIslandFromModel(item model2.CreativeIslandN) CreativeIsland { +func buildCreativeIslandFromModel(item model.CreativeIslandN) CreativeIsland { var ext CreativeIslandExt if !item.Ext.IsZero() && item.Ext.String != "" { if err := json.Unmarshal([]byte(item.Ext.ValueOrZero()), &ext); err != nil { @@ -220,22 +220,22 @@ func buildCreativeIslandFromModel(item model2.CreativeIslandN) CreativeIsland { func (r *CreativeRepo) Islands(ctx context.Context) ([]CreativeIsland, error) { q := query.Builder(). - Where(model2.FieldCreativeIslandStatus, int64(IslandStatusEnabled)). - OrderBy(model2.FieldCreativeIslandPriority, "DESC"). - OrderBy(model2.FieldCreativeIslandId, "ASC") - items, err := model2.NewCreativeIslandModel(r.db).Get(ctx, q) + Where(model.FieldCreativeIslandStatus, int64(IslandStatusEnabled)). + OrderBy(model.FieldCreativeIslandPriority, "DESC"). + OrderBy(model.FieldCreativeIslandId, "ASC") + items, err := model.NewCreativeIslandModel(r.db).Get(ctx, q) if err != nil { return nil, err } - return array.Map(items, func(item model2.CreativeIslandN, _ int) CreativeIsland { + return array.Map(items, func(item model.CreativeIslandN, _ int) CreativeIsland { return buildCreativeIslandFromModel(item) }), nil } func (r *CreativeRepo) Island(ctx context.Context, islandId string) (*CreativeIsland, error) { - q := query.Builder().Where(model2.FieldCreativeIslandIslandId, islandId) - item, err := model2.NewCreativeIslandModel(r.db).First(ctx, q) + q := query.Builder().Where(model.FieldCreativeIslandIslandId, islandId) + item, err := model.NewCreativeIslandModel(r.db).First(ctx, q) if err != nil { if err == query.ErrNoResult { return nil, ErrNotFound @@ -248,16 +248,16 @@ func (r *CreativeRepo) Island(ctx context.Context, islandId string) (*CreativeIs } func (r *CreativeRepo) CreateRecord(ctx context.Context, userId int64, item *CreativeItem) (int64, error) { - return model2.NewCreativeHistoryModel(r.db).Create(ctx, query.KV{ - model2.FieldCreativeHistoryUserId: userId, - model2.FieldCreativeHistoryIslandId: item.IslandId, - model2.FieldCreativeHistoryIslandType: int64(item.IslandType), - model2.FieldCreativeHistoryIslandModel: item.IslandModel, - model2.FieldCreativeHistoryArguments: item.Arguments, - model2.FieldCreativeHistoryPrompt: item.Prompt, - model2.FieldCreativeHistoryAnswer: item.Answer, - model2.FieldCreativeHistoryTaskId: item.TaskId, - model2.FieldCreativeHistoryStatus: int64(item.Status), + return model.NewCreativeHistoryModel(r.db).Create(ctx, query.KV{ + model.FieldCreativeHistoryUserId: userId, + model.FieldCreativeHistoryIslandId: item.IslandId, + model.FieldCreativeHistoryIslandType: int64(item.IslandType), + model.FieldCreativeHistoryIslandModel: item.IslandModel, + model.FieldCreativeHistoryArguments: item.Arguments, + model.FieldCreativeHistoryPrompt: item.Prompt, + model.FieldCreativeHistoryAnswer: item.Answer, + model.FieldCreativeHistoryTaskId: item.TaskId, + model.FieldCreativeHistoryStatus: int64(item.Status), }) } @@ -267,16 +267,16 @@ func (r *CreativeRepo) CreateRecordWithArguments(ctx context.Context, userId int item.Arguments = string(arguments) } - id, err := model2.NewCreativeHistoryModel(r.db).Create(ctx, query.KV{ - model2.FieldCreativeHistoryUserId: userId, - model2.FieldCreativeHistoryIslandId: item.IslandId, - model2.FieldCreativeHistoryIslandType: int64(item.IslandType), - model2.FieldCreativeHistoryIslandModel: item.IslandModel, - model2.FieldCreativeHistoryArguments: item.Arguments, - model2.FieldCreativeHistoryPrompt: item.Prompt, - model2.FieldCreativeHistoryAnswer: item.Answer, - model2.FieldCreativeHistoryTaskId: item.TaskId, - model2.FieldCreativeHistoryStatus: int64(item.Status), + id, err := model.NewCreativeHistoryModel(r.db).Create(ctx, query.KV{ + model.FieldCreativeHistoryUserId: userId, + model.FieldCreativeHistoryIslandId: item.IslandId, + model.FieldCreativeHistoryIslandType: int64(item.IslandType), + model.FieldCreativeHistoryIslandModel: item.IslandModel, + model.FieldCreativeHistoryArguments: item.Arguments, + model.FieldCreativeHistoryPrompt: item.Prompt, + model.FieldCreativeHistoryAnswer: item.Answer, + model.FieldCreativeHistoryTaskId: item.TaskId, + model.FieldCreativeHistoryStatus: int64(item.Status), }) if err != nil { return 0, err @@ -294,10 +294,10 @@ func (r *CreativeRepo) CreateRecordWithArguments(ctx context.Context, userId int } func (r *CreativeRepo) UpdateRecordByID(ctx context.Context, userId, id int64, answer string, quotaUsed int64, status CreativeStatus) error { - q := query.Builder().Where(model2.FieldCreativeHistoryId, id). - Where(model2.FieldCreativeHistoryUserId, userId) + q := query.Builder().Where(model.FieldCreativeHistoryId, id). + Where(model.FieldCreativeHistoryUserId, userId) - _, err := model2.NewCreativeHistoryModel(r.db).Update(ctx, q, model2.CreativeHistoryN{ + _, err := model.NewCreativeHistoryModel(r.db).Update(ctx, q, model.CreativeHistoryN{ Answer: null.StringFrom(answer), Status: null.IntFrom(int64(status)), QuotaUsed: null.IntFrom(quotaUsed), @@ -306,8 +306,8 @@ func (r *CreativeRepo) UpdateRecordByID(ctx context.Context, userId, id int64, a } func (r *CreativeRepo) UpdateRecordStatusByID(ctx context.Context, id int64, answer string, status CreativeStatus) error { - q := query.Builder().Where(model2.FieldCreativeHistoryId, id) - _, err := model2.NewCreativeHistoryModel(r.db).Update(ctx, q, model2.CreativeHistoryN{ + q := query.Builder().Where(model.FieldCreativeHistoryId, id) + _, err := model.NewCreativeHistoryModel(r.db).Update(ctx, q, model.CreativeHistoryN{ Status: null.IntFrom(int64(status)), Answer: null.StringFrom(answer), }) @@ -315,20 +315,20 @@ func (r *CreativeRepo) UpdateRecordStatusByID(ctx context.Context, id int64, ans } func (r *CreativeRepo) UpdateRecordAnswerByTaskID(ctx context.Context, userId int64, taskID string, answer string) error { - q := query.Builder().Where(model2.FieldCreativeHistoryTaskId, taskID). - Where(model2.FieldCreativeHistoryUserId, userId) + q := query.Builder().Where(model.FieldCreativeHistoryTaskId, taskID). + Where(model.FieldCreativeHistoryUserId, userId) - _, err := model2.NewCreativeHistoryModel(r.db).Update(ctx, q, model2.CreativeHistoryN{ + _, err := model.NewCreativeHistoryModel(r.db).Update(ctx, q, model.CreativeHistoryN{ Answer: null.StringFrom(answer), }) return err } func (r *CreativeRepo) UpdateRecordAnswerByID(ctx context.Context, userId int64, historyID int64, answer string) error { - q := query.Builder().Where(model2.FieldCreativeHistoryId, historyID). - Where(model2.FieldCreativeHistoryUserId, userId) + q := query.Builder().Where(model.FieldCreativeHistoryId, historyID). + Where(model.FieldCreativeHistoryUserId, userId) - _, err := model2.NewCreativeHistoryModel(r.db).Update(ctx, q, model2.CreativeHistoryN{ + _, err := model.NewCreativeHistoryModel(r.db).Update(ctx, q, model.CreativeHistoryN{ Answer: null.StringFrom(answer), }) return err @@ -347,10 +347,10 @@ type CreativeRecordUpdateExtArgs struct { } func (r *CreativeRepo) UpdateRecordArgumentsByTaskID(ctx context.Context, userId int64, taskID string, ext CreativeRecordUpdateExtArgs) error { - q := query.Builder().Where(model2.FieldCreativeHistoryTaskId, taskID). - Where(model2.FieldCreativeHistoryUserId, userId) + q := query.Builder().Where(model.FieldCreativeHistoryTaskId, taskID). + Where(model.FieldCreativeHistoryUserId, userId) - original, err := model2.NewCreativeHistoryModel(r.db).First(ctx, q) + original, err := model.NewCreativeHistoryModel(r.db).First(ctx, q) if err != nil { return err } @@ -371,11 +371,11 @@ func (r *CreativeRepo) UpdateRecordArgumentsByTaskID(ctx context.Context, userId } argData, _ := json.Marshal(arg) - update := model2.CreativeHistoryN{ + update := model.CreativeHistoryN{ Arguments: null.StringFrom(string(argData)), } - _, err = model2.NewCreativeHistoryModel(r.db).Update(ctx, q, update) + _, err = model.NewCreativeHistoryModel(r.db).Update(ctx, q, update) return err } @@ -394,17 +394,17 @@ func (r *CreativeRepo) UpdateRecordByTaskID(ctx context.Context, userId int64, t } }() - q := query.Builder().Where(model2.FieldCreativeHistoryTaskId, taskID). - Where(model2.FieldCreativeHistoryUserId, userId) + q := query.Builder().Where(model.FieldCreativeHistoryTaskId, taskID). + Where(model.FieldCreativeHistoryUserId, userId) - update := model2.CreativeHistoryN{ + update := model.CreativeHistoryN{ Answer: null.StringFrom(req.Answer), Status: null.IntFrom(int64(req.Status)), QuotaUsed: null.IntFrom(req.QuotaUsed), } if req.ExtArguments != nil { - original, err := model2.NewCreativeHistoryModel(r.db).First(ctx, q) + original, err := model.NewCreativeHistoryModel(r.db).First(ctx, q) if err != nil { return err } @@ -428,17 +428,17 @@ func (r *CreativeRepo) UpdateRecordByTaskID(ctx context.Context, userId int64, t update.Arguments = null.StringFrom(string(argData)) } - _, err := model2.NewCreativeHistoryModel(r.db).Update(ctx, q, update) + _, err := model.NewCreativeHistoryModel(r.db).Update(ctx, q, update) return err } -func (r *CreativeRepo) FindHistoryRecordByTaskId(ctx context.Context, userId int64, taskId string) (*model2.CreativeHistory, error) { +func (r *CreativeRepo) FindHistoryRecordByTaskId(ctx context.Context, userId int64, taskId string) (*model.CreativeHistory, error) { q := query.Builder(). - Where(model2.FieldCreativeHistoryUserId, userId). - Where(model2.FieldCreativeHistoryTaskId, taskId). - OrderBy(model2.FieldCreativeHistoryId, "DESC") + Where(model.FieldCreativeHistoryUserId, userId). + Where(model.FieldCreativeHistoryTaskId, taskId). + OrderBy(model.FieldCreativeHistoryId, "DESC") - item, err := model2.NewCreativeHistoryModel(r.db).First(ctx, q) + item, err := model.NewCreativeHistoryModel(r.db).First(ctx, q) if err != nil { if err == query.ErrNoResult { return nil, ErrNotFound @@ -452,13 +452,13 @@ func (r *CreativeRepo) FindHistoryRecordByTaskId(ctx context.Context, userId int func (r *CreativeRepo) FindHistoryRecord(ctx context.Context, userId, id int64) (*CreativeHistoryItem, error) { q := query.Builder(). - Where(model2.FieldCreativeHistoryId, id) + Where(model.FieldCreativeHistoryId, id) if userId > 0 { - q = q.Where(model2.FieldCreativeHistoryUserId, userId) + q = q.Where(model.FieldCreativeHistoryUserId, userId) } - item, err := model2.NewCreativeHistoryModel(r.db).First(ctx, q) + item, err := model.NewCreativeHistoryModel(r.db).First(ctx, q) if err != nil { if errors.Is(err, query.ErrNoResult) { return nil, ErrNotFound @@ -510,40 +510,40 @@ type CreativeHistoryQuery struct { func (r *CreativeRepo) HistoryRecordPaginate(ctx context.Context, userId int64, req CreativeHistoryQuery) ([]CreativeHistoryItem, query.PaginateMeta, error) { q := query.Builder(). - OrderBy(model2.FieldCreativeHistoryId, "DESC") + OrderBy(model.FieldCreativeHistoryId, "DESC") if userId > 0 { - q = q.Where(model2.FieldCreativeHistoryUserId, userId) + q = q.Where(model.FieldCreativeHistoryUserId, userId) } switch req.Mode { case "creative-island": - q = q.Where(model2.FieldCreativeHistoryIslandType, int64(IslandTypeText)) + q = q.Where(model.FieldCreativeHistoryIslandType, int64(IslandTypeText)) case "image-draw": - q = q.Where(model2.FieldCreativeHistoryIslandType, int64(IslandTypeImage)) + q = q.Where(model.FieldCreativeHistoryIslandType, int64(IslandTypeImage)) default: } if req.IslandId != "" { - q = q.Where(model2.FieldCreativeHistoryIslandId, req.IslandId) + q = q.Where(model.FieldCreativeHistoryIslandId, req.IslandId) } if req.IslandModel != "" { - q = q.Where(model2.FieldCreativeHistoryIslandModel, req.IslandModel) + q = q.Where(model.FieldCreativeHistoryIslandModel, req.IslandModel) } - items, meta, err := model2.NewCreativeHistoryModel(r.db).Paginate(ctx, req.Page, req.PerPage, q) + items, meta, err := model.NewCreativeHistoryModel(r.db).Paginate(ctx, req.Page, req.PerPage, q) if err != nil { return nil, query.PaginateMeta{}, err } islandIDNames := make(map[string]string) - islandQ := query.Builder().Select(model2.FieldCreativeIslandIslandId, model2.FieldCreativeIslandTitle) + islandQ := query.Builder().Select(model.FieldCreativeIslandIslandId, model.FieldCreativeIslandTitle) if req.IslandId != "" { - islandQ = islandQ.Where(model2.FieldCreativeIslandIslandId, req.IslandId) + islandQ = islandQ.Where(model.FieldCreativeIslandIslandId, req.IslandId) } - islands, err := model2.NewCreativeIslandModel(r.db).Get( + islands, err := model.NewCreativeIslandModel(r.db).Get( ctx, islandQ, ) @@ -553,7 +553,7 @@ func (r *CreativeRepo) HistoryRecordPaginate(ctx context.Context, userId int64, } } - ret := array.Map(items, func(item model2.CreativeHistoryN, _ int) CreativeHistoryItem { + ret := array.Map(items, func(item model.CreativeHistoryN, _ int) CreativeHistoryItem { answer := item.Answer.ValueOrZero() if item.IslandType.ValueOrZero() == int64(IslandTypeText) { answer = misc.SubString(answer, 100) @@ -581,53 +581,61 @@ func (r *CreativeRepo) HistoryRecordPaginate(ctx context.Context, userId int64, func (r *CreativeRepo) DeleteHistoryRecord(ctx context.Context, userId, id int64) error { q := query.Builder(). - Where(model2.FieldCreativeHistoryId, id). - Where(model2.FieldCreativeHistoryUserId, userId) + Where(model.FieldCreativeHistoryId, id). + Where(model.FieldCreativeHistoryUserId, userId) - _, err := model2.NewCreativeHistoryModel(r.db).Delete(ctx, q) + _, err := model.NewCreativeHistoryModel(r.db).Delete(ctx, q) return err } func (r *CreativeRepo) UserGallery(ctx context.Context, userID int64, islandModel string, limit int64) ([]CreativeHistoryItem, error) { q := query.Builder(). // Where(model.FieldCreativeHistoryStatus, int64(CreativeStatusSuccess)). - Where(model2.FieldCreativeHistoryIslandType, int64(IslandTypeImage)). + WhereIn( + model.FieldCreativeHistoryIslandType, + int64(IslandTypeImage), + int64(IslandTypeArtisticText), + int64(IslandTypeImageColorization), + int64(IslandTypeUpscale), + int64(IslandTypeVideo), + ). Select( - model2.FieldCreativeHistoryId, - model2.FieldCreativeHistoryIslandId, - model2.FieldCreativeHistoryIslandType, - model2.FieldCreativeHistoryAnswer, - model2.FieldCreativeHistoryStatus, - model2.FieldCreativeHistoryUserId, - model2.FieldCreativeHistoryCreatedAt, - model2.FieldCreativeHistoryUpdatedAt, + model.FieldCreativeHistoryId, + model.FieldCreativeHistoryIslandId, + model.FieldCreativeHistoryIslandType, + model.FieldCreativeHistoryAnswer, + model.FieldCreativeHistoryArguments, + model.FieldCreativeHistoryStatus, + model.FieldCreativeHistoryUserId, + model.FieldCreativeHistoryCreatedAt, + model.FieldCreativeHistoryUpdatedAt, ). - OrderBy(model2.FieldCreativeHistoryId, "DESC"). + OrderBy(model.FieldCreativeHistoryId, "DESC"). Limit(limit) if islandModel != "" { - q = q.Where(model2.FieldCreativeHistoryIslandModel, islandModel) + q = q.Where(model.FieldCreativeHistoryIslandModel, islandModel) } if userID != 0 { - q = q.Where(model2.FieldCreativeHistoryUserId, userID) + q = q.Where(model.FieldCreativeHistoryUserId, userID) } - items, err := model2.NewCreativeHistoryModel(r.db).Get(ctx, q) + items, err := model.NewCreativeHistoryModel(r.db).Get(ctx, q) if err != nil { return nil, err } - islandIDNames := make(map[string]string) - islandQ := query.Builder().Select(model2.FieldCreativeIslandIslandId, model2.FieldCreativeIslandTitle) - islands, err := model2.NewCreativeIslandModel(r.db).Get(ctx, islandQ) - if err == nil { - for _, island := range islands { - islandIDNames[island.IslandId.ValueOrZero()] = island.Title.ValueOrZero() - } + islandTypeNames := map[int64]string{ + int64(IslandTypeImage): "图片生成", + int64(IslandTypeVideo): "视频合成", + int64(IslandTypeAudio): "语音合成", + int64(IslandTypeUpscale): "超分辨率", + int64(IslandTypeImageColorization): "图片上色", + int64(IslandTypeArtisticText): "艺术字", } - ret := array.Map(items, func(item model2.CreativeHistoryN, _ int) CreativeHistoryItem { + ret := array.Map(items, func(item model.CreativeHistoryN, _ int) CreativeHistoryItem { return CreativeHistoryItem{ Id: item.Id.ValueOrZero(), IslandId: item.IslandId.ValueOrZero(), @@ -637,7 +645,8 @@ func (r *CreativeRepo) UserGallery(ctx context.Context, userID int64, islandMode UpdatedAt: item.UpdatedAt.ValueOrZero(), Status: item.Status.ValueOrZero(), UserID: item.UserId.ValueOrZero(), - IslandName: islandIDNames[item.IslandId.ValueOrZero()], + Arguments: item.Arguments.ValueOrZero(), + IslandName: islandTypeNames[item.IslandType.ValueOrZero()], } }) @@ -694,55 +703,55 @@ const ( CreativeGalleryStatusDeleted = 3 ) -func (r *CreativeRepo) Gallery(ctx context.Context, page, perPage int64) ([]model2.CreativeGallery, query.PaginateMeta, error) { - ids, meta, err := model2.NewCreativeGalleryRandomModel(r.db).Paginate(ctx, page, perPage, query.Builder()) +func (r *CreativeRepo) Gallery(ctx context.Context, page, perPage int64) ([]model.CreativeGallery, query.PaginateMeta, error) { + ids, meta, err := model.NewCreativeGalleryRandomModel(r.db).Paginate(ctx, page, perPage, query.Builder()) if err != nil { return nil, meta, err } - randomIds := array.Map(ids, func(item model2.CreativeGalleryRandomN, _ int) any { + randomIds := array.Map(ids, func(item model.CreativeGalleryRandomN, _ int) any { return item.GalleryId.ValueOrZero() }) if len(randomIds) == 0 { meta.LastPage = 1 - return []model2.CreativeGallery{}, meta, nil + return []model.CreativeGallery{}, meta, nil } q := query.Builder(). - WhereIn(model2.FieldCreativeGalleryId, randomIds). + WhereIn(model.FieldCreativeGalleryId, randomIds). Select( - model2.FieldCreativeGalleryId, - model2.FieldCreativeGalleryUserId, - model2.FieldCreativeGalleryUsername, - model2.FieldCreativeGalleryCreativeType, - model2.FieldCreativeGalleryPrompt, - model2.FieldCreativeGalleryAnswer, - model2.FieldCreativeGalleryTags, - model2.FieldCreativeGalleryRefCount, - model2.FieldCreativeGalleryStarLevel, - model2.FieldCreativeGalleryHotValue, - model2.FieldCreativeGalleryCreatedAt, - model2.FieldCreativeGalleryUpdatedAt, + model.FieldCreativeGalleryId, + model.FieldCreativeGalleryUserId, + model.FieldCreativeGalleryUsername, + model.FieldCreativeGalleryCreativeType, + model.FieldCreativeGalleryPrompt, + model.FieldCreativeGalleryAnswer, + model.FieldCreativeGalleryTags, + model.FieldCreativeGalleryRefCount, + model.FieldCreativeGalleryStarLevel, + model.FieldCreativeGalleryHotValue, + model.FieldCreativeGalleryCreatedAt, + model.FieldCreativeGalleryUpdatedAt, ). - OrderBy(model2.FieldCreativeGalleryHotValue, "DESC"). + OrderBy(model.FieldCreativeGalleryHotValue, "DESC"). OrderByRaw("RAND()") - items, err := model2.NewCreativeGalleryModel(r.db).Get(ctx, q) + items, err := model.NewCreativeGalleryModel(r.db).Get(ctx, q) if err != nil { return nil, meta, err } - return array.Map(items, func(item model2.CreativeGalleryN, _ int) model2.CreativeGallery { + return array.Map(items, func(item model.CreativeGalleryN, _ int) model.CreativeGallery { return item.ToCreativeGallery() }), meta, err } -func (r *CreativeRepo) GalleryByID(ctx context.Context, id int64) (*model2.CreativeGallery, error) { +func (r *CreativeRepo) GalleryByID(ctx context.Context, id int64) (*model.CreativeGallery, error) { q := query.Builder(). - Where(model2.FieldCreativeGalleryId, id) + Where(model.FieldCreativeGalleryId, id) - item, err := model2.NewCreativeGalleryModel(r.db).First(ctx, q) + item, err := model.NewCreativeGalleryModel(r.db).First(ctx, q) if err != nil { if errors.Is(err, query.ErrNoResult) { return nil, ErrNotFound @@ -775,10 +784,10 @@ func (r *CreativeRepo) ShareCreativeHistoryToGallery(ctx context.Context, userID return eloquent.Transaction(r.db, func(tx query.Database) error { // 查询创作岛历史纪录信息 q := query.Builder(). - Where(model2.FieldCreativeHistoryId, id). - Where(model2.FieldCreativeHistoryUserId, userID) + Where(model.FieldCreativeHistoryId, id). + Where(model.FieldCreativeHistoryUserId, userID) - item, err := model2.NewCreativeHistoryModel(tx).First(ctx, q) + item, err := model.NewCreativeHistoryModel(tx).First(ctx, q) if err != nil { if errors.Is(err, query.ErrNoResult) { return ErrNotFound @@ -788,9 +797,9 @@ func (r *CreativeRepo) ShareCreativeHistoryToGallery(ctx context.Context, userID } // 查询是否已经在 Gallery 中 - existItem, err := model2.NewCreativeGalleryModel(tx).First( + existItem, err := model.NewCreativeGalleryModel(tx).First( ctx, - query.Builder().Where(model2.FieldCreativeGalleryCreativeHistoryId, id), + query.Builder().Where(model.FieldCreativeGalleryCreativeHistoryId, id), ) if err != nil && !errors.Is(err, query.ErrNoResult) { return err @@ -800,19 +809,19 @@ func (r *CreativeRepo) ShareCreativeHistoryToGallery(ctx context.Context, userID // 已经存在,且已经删除,则恢复 if existItem.Status.ValueOrZero() == CreativeGalleryStatusDeleted { item.Shared = null.IntFrom(int64(IslandHistorySharedStatusShared)) - if err := item.Save(ctx, model2.FieldCreativeHistoryShared); err != nil { + if err := item.Save(ctx, model.FieldCreativeHistoryShared); err != nil { return err } existItem.Status = null.IntFrom(CreativeGalleryStatusOK) - return existItem.Save(ctx, model2.FieldCreativeGalleryStatus) + return existItem.Save(ctx, model.FieldCreativeGalleryStatus) } return nil } item.Shared = null.IntFrom(int64(IslandHistorySharedStatusShared)) - if err := item.Save(ctx, model2.FieldCreativeHistoryShared); err != nil { + if err := item.Save(ctx, model.FieldCreativeHistoryShared); err != nil { return err } @@ -830,16 +839,16 @@ func (r *CreativeRepo) ShareCreativeHistoryToGallery(ctx context.Context, userID } meta, _ := json.Marshal(arg.ToGalleryMeta()) - _, err = model2.NewCreativeGalleryModel(tx).Create(ctx, query.KV{ - model2.FieldCreativeGalleryUserId: userID, - model2.FieldCreativeGalleryUsername: username, - model2.FieldCreativeGalleryCreativeHistoryId: id, - model2.FieldCreativeGalleryCreativeType: item.IslandType.ValueOrZero(), - model2.FieldCreativeGalleryPrompt: prompt, - model2.FieldCreativeGalleryAnswer: item.Answer.ValueOrZero(), - model2.FieldCreativeGalleryStatus: CreativeGalleryStatusOK, - model2.FieldCreativeGalleryNegativePrompt: arg.NegativePrompt, - model2.FieldCreativeGalleryMeta: string(meta), + _, err = model.NewCreativeGalleryModel(tx).Create(ctx, query.KV{ + model.FieldCreativeGalleryUserId: userID, + model.FieldCreativeGalleryUsername: username, + model.FieldCreativeGalleryCreativeHistoryId: id, + model.FieldCreativeGalleryCreativeType: item.IslandType.ValueOrZero(), + model.FieldCreativeGalleryPrompt: prompt, + model.FieldCreativeGalleryAnswer: item.Answer.ValueOrZero(), + model.FieldCreativeGalleryStatus: CreativeGalleryStatusOK, + model.FieldCreativeGalleryNegativePrompt: arg.NegativePrompt, + model.FieldCreativeGalleryMeta: string(meta), }) return err }) @@ -848,12 +857,12 @@ func (r *CreativeRepo) ShareCreativeHistoryToGallery(ctx context.Context, userID func (r *CreativeRepo) CancelCreativeHistoryShare(ctx context.Context, userID int64, historyID int64) error { return eloquent.Transaction(r.db, func(tx query.Database) error { q := query.Builder(). - Where(model2.FieldCreativeGalleryCreativeHistoryId, historyID) + Where(model.FieldCreativeGalleryCreativeHistoryId, historyID) if userID > 0 { - q = q.Where(model2.FieldCreativeGalleryUserId, userID) + q = q.Where(model.FieldCreativeGalleryUserId, userID) } - item, err := model2.NewCreativeGalleryModel(tx).First(ctx, q) + item, err := model.NewCreativeGalleryModel(tx).First(ctx, q) if err != nil { if errors.Is(err, query.ErrNoResult) { return nil @@ -862,11 +871,11 @@ func (r *CreativeRepo) CancelCreativeHistoryShare(ctx context.Context, userID in return err } - historyItem, err := model2.NewCreativeHistoryModel(tx).First( + historyItem, err := model.NewCreativeHistoryModel(tx).First( ctx, query.Builder(). - Where(model2.FieldCreativeHistoryId, historyID). - Where(model2.FieldCreativeHistoryUserId, item.UserId), + Where(model.FieldCreativeHistoryId, historyID). + Where(model.FieldCreativeHistoryUserId, item.UserId), ) if err != nil && !errors.Is(err, query.ErrNoResult) { return err @@ -874,18 +883,18 @@ func (r *CreativeRepo) CancelCreativeHistoryShare(ctx context.Context, userID in if historyItem != nil { historyItem.Shared = null.IntFrom(int64(IslandHistorySharedStatusNotShared)) - if err := historyItem.Save(ctx, model2.FieldCreativeHistoryShared); err != nil { + if err := historyItem.Save(ctx, model.FieldCreativeHistoryShared); err != nil { return err } } item.Status = null.IntFrom(CreativeGalleryStatusDeleted) - return item.Save(ctx, model2.FieldCreativeGalleryStatus) + return item.Save(ctx, model.FieldCreativeGalleryStatus) }) } type ImageModel struct { - model2.ImageModel + model.ImageModel ImageMeta ImageModelMeta `json:"image_meta"` } @@ -905,8 +914,8 @@ type Dimension struct { } func (r *CreativeRepo) Model(ctx context.Context, vendor, realModel string) (*ImageModel, error) { - q := query.Builder().Where(model2.FieldImageModelVendor, vendor).Where(model2.FieldImageModelRealModel, realModel) - mod, err := model2.NewImageModelModel(r.db).First(ctx, q) + q := query.Builder().Where(model.FieldImageModelVendor, vendor).Where(model.FieldImageModelRealModel, realModel) + mod, err := model.NewImageModelModel(r.db).First(ctx, q) if err != nil { if err == query.ErrNoResult { return nil, ErrNotFound @@ -929,16 +938,16 @@ func (r *CreativeRepo) Model(ctx context.Context, vendor, realModel string) (*Im func (r *CreativeRepo) Models(ctx context.Context) ([]ImageModel, error) { q := query.Builder(). - Where(model2.FieldImageModelStatus, 1). - OrderBy(model2.FieldImageModelVendor, "ASC"). - OrderBy(model2.FieldImageModelModelName, "ASC") + Where(model.FieldImageModelStatus, 1). + OrderBy(model.FieldImageModelVendor, "ASC"). + OrderBy(model.FieldImageModelModelName, "ASC") - items, err := model2.NewImageModelModel(r.db).Get(ctx, q) + items, err := model.NewImageModelModel(r.db).Get(ctx, q) if err != nil { return nil, err } - return array.Map(items, func(item model2.ImageModelN, _ int) ImageModel { + return array.Map(items, func(item model.ImageModelN, _ int) ImageModel { m := item.ToImageModel() var meta ImageModelMeta if m.Meta != "" { @@ -954,7 +963,7 @@ func (r *CreativeRepo) Models(ctx context.Context) ([]ImageModel, error) { } type ImageFilter struct { - model2.ImageFilter + model.ImageFilter Vendor string `json:"-"` ImageMeta ImageFilterMeta `json:"meta"` } @@ -998,10 +1007,10 @@ func (meta ImageFilterMeta) ShouldUseTemplate(prompt string) bool { // modelVendors 查询所有的模型(模型 id->模型服务商) func (r *CreativeRepo) modelVendors(ctx context.Context) (map[string]string, error) { q := query.Builder(). - Where(model2.FieldImageModelStatus, 1). - Select(model2.FieldImageModelModelId, model2.FieldImageModelVendor) + Where(model.FieldImageModelStatus, 1). + Select(model.FieldImageModelModelId, model.FieldImageModelVendor) - items, err := model2.NewImageModelModel(r.db).Get(ctx, q) + items, err := model.NewImageModelModel(r.db).Get(ctx, q) if err != nil { return nil, err } @@ -1016,10 +1025,10 @@ func (r *CreativeRepo) modelVendors(ctx context.Context) (map[string]string, err func (r *CreativeRepo) Filters(ctx context.Context) ([]ImageFilter, error) { q := query.Builder(). - Where(model2.FieldImageFilterStatus, 1). - OrderBy(model2.FieldImageFilterId, "DESC") + Where(model.FieldImageFilterStatus, 1). + OrderBy(model.FieldImageFilterId, "DESC") - items, err := model2.NewImageFilterModel(r.db).Get(ctx, q) + items, err := model.NewImageFilterModel(r.db).Get(ctx, q) if err != nil { return nil, err } @@ -1027,14 +1036,14 @@ func (r *CreativeRepo) Filters(ctx context.Context) ([]ImageFilter, error) { modelVenders, err := r.modelVendors(ctx) if err == nil { // 过滤掉模型不存在的风格 - items = array.Filter(items, func(item model2.ImageFilterN, _ int) bool { + items = array.Filter(items, func(item model.ImageFilterN, _ int) bool { return modelVenders[item.ModelId.ValueOrZero()] != "" }) } else { log.Errorf("get model venders failed: %v", err) } - return array.Map(items, func(item model2.ImageFilterN, _ int) ImageFilter { + return array.Map(items, func(item model.ImageFilterN, _ int) ImageFilter { m := item.ToImageFilter() var meta ImageFilterMeta if m.Meta != "" { @@ -1053,12 +1062,12 @@ func (r *CreativeRepo) Filters(ctx context.Context) ([]ImageFilter, error) { func (r *CreativeRepo) Filter(ctx context.Context, id int64) (*ImageFilter, error) { q := query.Builder(). - Where(model2.FieldImageFilterStatus, 1). - Where(model2.FieldImageFilterId, id) + Where(model.FieldImageFilterStatus, 1). + Where(model.FieldImageFilterId, id) - item, err := model2.NewImageFilterModel(r.db).First(ctx, q) + item, err := model.NewImageFilterModel(r.db).First(ctx, q) if err != nil { - if err == query.ErrNoResult { + if errors.Is(err, query.ErrNoResult) { return nil, ErrNotFound } diff --git a/server/controllers/creative-island.go b/server/controllers/creative-island.go index d685828..868596e 100644 --- a/server/controllers/creative-island.go +++ b/server/controllers/creative-island.go @@ -7,7 +7,7 @@ import ( "fmt" openaiHelper "github.com/mylxsw/aidea-server/pkg/ai/openai" "github.com/mylxsw/aidea-server/pkg/misc" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/pkg/repo" "github.com/mylxsw/aidea-server/pkg/service" "github.com/mylxsw/aidea-server/pkg/youdao" "math/rand" @@ -32,11 +32,11 @@ import ( // CreativeIslandController 创作岛 type CreativeIslandController struct { conf *config.Config - quotaRepo *repo2.QuotaRepo `autowire:"@"` + quotaRepo *repo.QuotaRepo `autowire:"@"` queue *queue.Queue `autowire:"@"` - queueRepo *repo2.QueueRepo `autowire:"@"` + queueRepo *repo.QueueRepo `autowire:"@"` trans youdao.Translater `autowire:"@"` - creativeRepo *repo2.CreativeRepo `autowire:"@"` + creativeRepo *repo.CreativeRepo `autowire:"@"` securitySrv *service.SecurityService `autowire:"@"` } @@ -65,6 +65,11 @@ func (ctl *CreativeIslandController) Register(router web.Router) { }) } +type CreativeHistoryItem struct { + repo.CreativeHistoryItem + PreviewImage string `json:"preview_image,omitempty"` +} + // gallery 创作岛项目的图库 func (ctl *CreativeIslandController) gallery(ctx context.Context, webCtx web.Context, user *auth.User) web.Response { mode := webCtx.InputWithDefault("mode", "default") @@ -84,13 +89,28 @@ func (ctl *CreativeIslandController) gallery(ctx context.Context, webCtx web.Con } return webCtx.JSON(web.M{ - "data": array.Map(items, func(item repo2.CreativeHistoryItem, _ int) repo2.CreativeHistoryItem { + "data": array.Map(items, func(item repo.CreativeHistoryItem, _ int) CreativeHistoryItem { if item.UserID != user.ID && userId != 0 { // 客户端处理:如果用户ID为0,则该项目不可点击 item.UserID = 0 } - return item + ret := CreativeHistoryItem{ + CreativeHistoryItem: item, + } + + if item.Arguments != "" && (item.IslandType == int64(repo.IslandTypeVideo) || (item.Answer == "" || item.Answer == "[]")) { + var arg map[string]any + _ = json.Unmarshal([]byte(item.Arguments), &arg) + + if arg["image"] != nil { + ret.PreviewImage = arg["image"].(string) + } + } + + item.Arguments = "" + + return ret }), }) } @@ -112,7 +132,7 @@ func (ctl *CreativeIslandController) histories(ctx context.Context, webCtx web.C return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrInvalidRequest), http.StatusBadRequest) } - items, meta, err := ctl.creativeRepo.HistoryRecordPaginate(ctx, user.ID, repo2.CreativeHistoryQuery{ + items, meta, err := ctl.creativeRepo.HistoryRecordPaginate(ctx, user.ID, repo.CreativeHistoryQuery{ Page: page, PerPage: perPage, Mode: mode, @@ -123,10 +143,10 @@ func (ctl *CreativeIslandController) histories(ctx context.Context, webCtx web.C } return webCtx.JSON(web.M{ - "data": array.Map(items, func(item repo2.CreativeHistoryItem, _ int) repo2.CreativeHistoryItem { + "data": array.Map(items, func(item repo.CreativeHistoryItem, _ int) repo.CreativeHistoryItem { // 客户端目前不支持封禁状态展示,这里转换为失败 - if item.Status == int64(repo2.CreativeStatusForbid) { - item.Status = int64(repo2.CreativeStatusFailed) + if item.Status == int64(repo.CreativeStatusForbid) { + item.Status = int64(repo.CreativeStatusFailed) } return item @@ -144,7 +164,7 @@ func (ctl *CreativeIslandController) itemHistories(ctx context.Context, webCtx w if id == "" { return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrInvalidRequest), http.StatusBadRequest) } - items, _, err := ctl.creativeRepo.HistoryRecordPaginate(ctx, user.ID, repo2.CreativeHistoryQuery{ + items, _, err := ctl.creativeRepo.HistoryRecordPaginate(ctx, user.ID, repo.CreativeHistoryQuery{ IslandId: id, Page: 1, PerPage: 100, @@ -155,10 +175,10 @@ func (ctl *CreativeIslandController) itemHistories(ctx context.Context, webCtx w } return webCtx.JSON(web.M{ - "data": array.Map(items, func(item repo2.CreativeHistoryItem, _ int) repo2.CreativeHistoryItem { + "data": array.Map(items, func(item repo.CreativeHistoryItem, _ int) repo.CreativeHistoryItem { // 客户端目前不支持封禁状态展示,这里转换为失败 - if item.Status == int64(repo2.CreativeStatusForbid) { - item.Status = int64(repo2.CreativeStatusFailed) + if item.Status == int64(repo.CreativeStatusForbid) { + item.Status = int64(repo.CreativeStatusFailed) } return item @@ -180,7 +200,7 @@ func (ctl *CreativeIslandController) historyItem(ctx context.Context, webCtx web item, err := ctl.creativeRepo.FindHistoryRecord(ctx, user.ID, int64(hid)) if err != nil { - if err == repo2.ErrNotFound { + if err == repo.ErrNotFound { return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrNotFound), http.StatusNotFound) } @@ -189,8 +209,8 @@ func (ctl *CreativeIslandController) historyItem(ctx context.Context, webCtx web } // 客户端目前不支持封禁状态展示,这里转换为失败 - if item.Status == int64(repo2.CreativeStatusForbid) { - item.Status = int64(repo2.CreativeStatusFailed) + if item.Status == int64(repo.CreativeStatusForbid) { + item.Status = int64(repo.CreativeStatusFailed) } return webCtx.JSON(item) @@ -235,7 +255,7 @@ func (ctl *CreativeIslandController) List(ctx context.Context, webCtx web.Contex return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrInternalError), http.StatusInternalServerError) } - islands = array.Filter(islands, func(item repo2.CreativeIsland, _ int) bool { + islands = array.Filter(islands, func(item repo.CreativeIsland, _ int) bool { if item.VersionMax == "" && item.VersionMin == "" { return true } @@ -258,26 +278,26 @@ func (ctl *CreativeIslandController) List(ctx context.Context, webCtx web.Contex switch mode { case "creative-island": categories = []string{"热门", "创作", "生活", "职场", "娱乐"} - islands = array.Filter(islands, func(item repo2.CreativeIsland, _ int) bool { + islands = array.Filter(islands, func(item repo.CreativeIsland, _ int) bool { return !array.In(CreativeIslandModelType(item.ModelType), imageTypes) }) - items = array.Map(islands, func(item repo2.CreativeIsland, _ int) CreativeIslandItem { + items = array.Map(islands, func(item repo.CreativeIsland, _ int) CreativeIslandItem { return CreativeIslandItemFromModel(item) }) case "image-draw": categories = []string{"图生图", "文生图"} backgroundImage = "https://img.freepik.com/free-vector/modern-colorful-soft-watercolor-texture-background_1035-22725.jpg" - islands = array.Filter(islands, func(item repo2.CreativeIsland, _ int) bool { + islands = array.Filter(islands, func(item repo.CreativeIsland, _ int) bool { return array.In(CreativeIslandModelType(item.ModelType), imageTypes) }) - items = array.Map(islands, func(item repo2.CreativeIsland, _ int) CreativeIslandItem { + items = array.Map(islands, func(item repo.CreativeIsland, _ int) CreativeIslandItem { // 不能暴漏给客户端的字段 item.Extension.AIPrompt = "" return CreativeIslandItemFromModel(item) }) default: categories = []string{"热门", "绘图", "创作", "生活", "职场", "娱乐"} - items = array.Map(islands, func(item repo2.CreativeIsland, _ int) CreativeIslandItem { + items = array.Map(islands, func(item repo.CreativeIsland, _ int) CreativeIslandItem { // 不能暴漏给客户端的字段 item.Extension.AIPrompt = "" return CreativeIslandItemFromModel(item) @@ -298,7 +318,7 @@ func (ctl *CreativeIslandController) Item(ctx context.Context, webCtx web.Contex id := webCtx.PathVar("id") island, err := ctl.creativeRepo.Island(ctx, id) if err != nil { - if errors.Is(err, repo2.ErrNotFound) { + if errors.Is(err, repo.ErrNotFound) { return webCtx.JSONError(common.Text(webCtx, ctl.trans, common.ErrNotFound), http.StatusNotFound) } @@ -472,21 +492,21 @@ func (ctl *CreativeIslandController) completionsDeepAI(ctx context.Context, webC } log.WithFields(log.Fields{"task_id": taskID}).Debugf("enqueue task success: %s", taskID) - arguments, _ := json.Marshal(repo2.CreativeRecordArguments{ + arguments, _ := json.Marshal(repo.CreativeRecordArguments{ NegativePrompt: negativePrompt, Width: int64(width), Height: int64(height), StylePreset: stylePreset, }) - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: item.ID, - IslandType: repo2.IslandTypeImage, + IslandType: repo.IslandTypeImage, IslandModel: stylePreset, Arguments: string(arguments), Prompt: prompt, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } if _, err := ctl.creativeRepo.CreateRecord(ctx, user.ID, &creativeItem); err != nil { @@ -607,7 +627,7 @@ func (ctl *CreativeIslandController) completionsStabilityAI(ctx context.Context, } log.WithFields(log.Fields{"task_id": taskID}).Debugf("enqueue task success: %s", taskID) - arguments, _ := json.Marshal(repo2.CreativeRecordArguments{ + arguments, _ := json.Marshal(repo.CreativeRecordArguments{ NegativePrompt: negativePrompt, Width: int64(width), Height: int64(height), @@ -617,14 +637,14 @@ func (ctl *CreativeIslandController) completionsStabilityAI(ctx context.Context, Image: image, }) - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: item.ID, - IslandType: repo2.IslandTypeImage, + IslandType: repo.IslandTypeImage, IslandModel: item.Model, Arguments: string(arguments), Prompt: prompt, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } if _, err := ctl.creativeRepo.CreateRecord(ctx, user.ID, &creativeItem); err != nil { @@ -752,7 +772,7 @@ func (ctl *CreativeIslandController) completionsLeapAI(ctx context.Context, webC } log.WithFields(log.Fields{"task_id": taskID}).Debugf("enqueue task success: %s", taskID) - arguments, _ := json.Marshal(repo2.CreativeRecordArguments{ + arguments, _ := json.Marshal(repo.CreativeRecordArguments{ NegativePrompt: negativePrompt, Width: int64(width), Height: int64(height), @@ -763,14 +783,14 @@ func (ctl *CreativeIslandController) completionsLeapAI(ctx context.Context, webC UpscaleBy: upscaleBy, }) - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: item.ID, - IslandType: repo2.IslandTypeImage, + IslandType: repo.IslandTypeImage, IslandModel: item.Model, Arguments: string(arguments), Prompt: prompt, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } if _, err := ctl.creativeRepo.CreateRecord(ctx, user.ID, &creativeItem); err != nil { @@ -871,14 +891,14 @@ func (ctl *CreativeIslandController) completionsOpenAI(ctx context.Context, webC "word_count": wordCount, }) - creativeItem := repo2.CreativeItem{ + creativeItem := repo.CreativeItem{ IslandId: item.ID, - IslandType: repo2.IslandTypeText, + IslandType: repo.IslandTypeText, IslandModel: item.Model, Arguments: string(arguments), Prompt: prompt, TaskId: taskID, - Status: repo2.CreativeStatusPending, + Status: repo.CreativeStatusPending, } if _, err := ctl.creativeRepo.CreateRecord(ctx, user.ID, &creativeItem); err != nil { @@ -928,10 +948,10 @@ type CreativeIslandItem struct { ShowImageStyleSelector bool `yaml:"show_image_style_selector,omitempty" json:"show_image_style_selector,omitempty"` NoPrompt bool `yaml:"no_prompt,omitempty" json:"no_prompt,omitempty"` - Extension repo2.CreativeIslandExt `yaml:"extension" json:"extension"` + Extension repo.CreativeIslandExt `yaml:"extension" json:"extension"` } -func CreativeIslandItemFromModel(item repo2.CreativeIsland) CreativeIslandItem { +func CreativeIslandItemFromModel(item repo.CreativeIsland) CreativeIslandItem { wordCount := item.WordCount if wordCount <= 0 { wordCount = 1000 diff --git a/server/controllers/v2/creative-island.go b/server/controllers/v2/creative-island.go index b446f02..c60dffc 100644 --- a/server/controllers/v2/creative-island.go +++ b/server/controllers/v2/creative-island.go @@ -106,6 +106,22 @@ const ( ) func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Context, user *auth.UserOptional, client *auth.ClientInfo) web.Response { + imageCost := int64(coins.GetUnifiedImageGenCoins("")) + videoCost := int64(coins.GetUnifiedVideoGenCoins("stability-image-to-video")) + + imageModelsCost := coins.GetImageGenCoinsExcept(imageCost) + imageModelsCostNote := "" + if len(imageModelsCost) > 0 { + ns := make([]string, 0) + for mod, cost := range imageModelsCost { + ns = append(ns, fmt.Sprintf("%s %d/张", strings.ToUpper(mod), cost)) + } + + if len(ns) > 0 { + imageModelsCostNote = fmt.Sprintf("(以下模型除外,%s)", strings.Join(ns, ",")) + } + } + items := []CreativeIslandItem{ { ID: "text-to-image", @@ -113,18 +129,19 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-text-to-image.jpeg-thumb1000", RouteURI: "/creative-draw/create?mode=text-to-image&id=text-to-image", + Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果%s。", imageCost, imageModelsCostNote), Size: SizeLarge, }, } - if client != nil && misc.VersionNewer(client.Version, "1.0.9") && ctl.conf.EnableStabilityAI { + if client != nil && misc.VersionNewer(client.Version, "1.0.10") && ctl.conf.EnableStabilityAI { items = append(items, CreativeIslandItem{ ID: "image-to-video", Title: "图生视频", TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-to-video-dark.jpg-thumb1000", RouteURI: "/creative-draw/create-video", - Note: "每次生成视频将消耗 200 智慧果", + Note: fmt.Sprintf("图生视频功能能够将静态的图片转换为动态的视频,转换后的视频时长为 2s。生成每个视频将消耗 %d 智慧果。", videoCost), Size: SizeLarge, }) } @@ -136,6 +153,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/art-text-bg.jpg-thumb1000", RouteURI: "/creative-draw/artistic-text?type=text&id=artistic-text", + Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeLarge, }) items = append(items, CreativeIslandItem{ @@ -144,6 +162,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/art-qr-bg.jpg-thumb1000", RouteURI: "/creative-draw/artistic-text?type=qr&id=artistic-qr", + Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) } @@ -155,6 +174,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-image-to-image.jpeg-thumb1000", RouteURI: "/creative-draw/create?mode=image-to-image&id=image-to-image", Tag: ternary.If(client != nil && client.IsIOS(), "", "BETA"), + Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) @@ -165,7 +185,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/super-res.jpeg-thumb1000", RouteURI: "/creative-draw/create-upscale", - Note: "图片的高清修复功能能够把低分辨率的照片升级到高分辨率,让图片的清晰度得到明显提升。", + Note: fmt.Sprintf("图片的高清修复功能能够把低分辨率的照片升级到高分辨率,让图片的清晰度得到明显提升。\n生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) @@ -175,7 +195,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-colorizev2.jpeg-thumb1000", RouteURI: "/creative-draw/create-colorize", - Note: "图片上色功能能够把黑白照片变成彩色照片,让照片的色彩更加丰富。", + Note: fmt.Sprintf("图片上色功能能够把黑白照片变成彩色照片,让照片的色彩更加丰富。\n生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) } From c011550e84754ffc4a733c498210676b4baaedee Mon Sep 17 00:00:00 2001 From: mylxsw Date: Wed, 27 Dec 2023 17:01:11 +0800 Subject: [PATCH 08/13] =?UTF-8?q?=E5=85=8D=E8=B4=B9=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E4=BB=85=E9=99=90=20iOS=20=E5=B9=B3=E5=8F=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/queue/payment.go | 2 +- server/controllers/openai.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/queue/payment.go b/internal/queue/payment.go index c8d3d26..eea89e2 100644 --- a/internal/queue/payment.go +++ b/internal/queue/payment.go @@ -173,7 +173,7 @@ func BuildPaymentHandler( payload.PaymentID, payload.Source, ) - if err := ding.Send(dingding.NewMarkdownMessage(fmt.Sprintf("%s: 充值 %d 个智慧果", payload.Env, product.Quota), content, []string{})); err != nil { + if err := ding.Send(dingding.NewMarkdownMessage(fmt.Sprintf("用户 %d 充值 %d 个智慧果(%s)", payload.UserID, product.Quota, payload.Env), content, []string{})); err != nil { log.Errorf("发送钉钉通知失败: %s", err) } }() diff --git a/server/controllers/openai.go b/server/controllers/openai.go index 2c9d4c4..bc8d09e 100644 --- a/server/controllers/openai.go +++ b/server/controllers/openai.go @@ -196,7 +196,7 @@ func (m FinalMessage) ToJSON() string { // Chat 聊天接口,接口参数参考 https://platform.openai.com/docs/api-reference/chat/create // 该接口会返回一个 SSE 流,接口参数 stream 总是为 true(忽略客户端设置) func (ctl *OpenAIController) Chat(ctx context.Context, webCtx web.Context, user *auth.UserOptional, quotaRepo *repo2.QuotaRepo, w http.ResponseWriter, client *auth.ClientInfo) { - if user.User == nil && ctl.conf.FreeChatEnabled { + if user.User == nil && ctl.conf.FreeChatEnabled && client.IsIOS() { // 匿名用户访问 user.User = &auth.User{ ID: 0, From 1226fc931a81606f18eccc793273235ff078680a Mon Sep 17 00:00:00 2001 From: mylxsw Date: Thu, 28 Dec 2023 00:23:53 +0800 Subject: [PATCH 09/13] =?UTF-8?q?=E5=9B=BE=E7=89=87=E8=BD=AC=E8=A7=86?= =?UTF-8?q?=E9=A2=91=E6=97=B6=EF=BC=8C=E6=A0=B9=E6=8D=AE=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E5=AE=BD=E9=AB=98=E6=9D=A5=E5=AF=B9=E5=9B=BE=E7=89=87=E8=BF=9B?= =?UTF-8?q?=E8=A1=8C=E8=A3=81=E5=89=AA=EF=BC=8C=E6=9B=B4=E5=A5=BD=E7=9A=84?= =?UTF-8?q?=E9=80=82=E9=85=8D=E4=B8=89=E7=A7=8D=E4=B8=8D=E9=80=9A=E5=B0=BA?= =?UTF-8?q?=E5=AF=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/uploader/downloader.go | 49 ++++++++++++++++++++++++ pkg/uploader/downloader_test.go | 20 ++++++++++ server/controllers/v2/creative-island.go | 16 +++++++- 3 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 pkg/uploader/downloader_test.go diff --git a/pkg/uploader/downloader.go b/pkg/uploader/downloader.go index 385170c..3d68c1a 100644 --- a/pkg/uploader/downloader.go +++ b/pkg/uploader/downloader.go @@ -3,6 +3,7 @@ package uploader import ( "context" "encoding/base64" + "encoding/json" "fmt" "github.com/mylxsw/go-utils/ternary" "io" @@ -39,6 +40,18 @@ func BuildImageURLWithFilter(remoteURL string, filter, storageDomain string) str return remoteURL } +// RemoveImageFilter 移除图片的 filter +func RemoveImageFilter(imageURL string) string { + if str.HasSuffixes(strings.ToLower(imageURL), supportFilters) { + segs := strings.Split(imageURL, "-") + segs = segs[:len(segs)-1] + + return strings.Join(segs, "-") + } + + return imageURL +} + var ( ErrFileForbidden = fmt.Errorf("文件违规已被禁用") ) @@ -127,3 +140,39 @@ func DownloadRemoteFileAsBase64Raw(ctx context.Context, remoteURL string) (image } return base64.StdEncoding.EncodeToString(data), http.DetectContentType(data), nil } + +type ImageInfo struct { + // Size 文件大小,单位:Bytes + Size int64 `json:"size"` + // Format 图片类型,如png、jpeg、gif、bmp等 + Format string `json:"format"` + Width int64 `json:"width"` + Height int64 `json:"height"` +} + +func QueryImageInfo(imageURL string) (*ImageInfo, error) { + resp, err := http.Get(RemoveImageFilter(imageURL) + "?imageInfo") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + if resp.StatusCode == http.StatusForbidden { + return nil, ErrFileForbidden + } + return nil, fmt.Errorf("query remote file info failed: [%d] %s", resp.StatusCode, resp.Status) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var info ImageInfo + if err := json.Unmarshal(data, &info); err != nil { + return nil, err + } + + return &info, nil +} diff --git a/pkg/uploader/downloader_test.go b/pkg/uploader/downloader_test.go new file mode 100644 index 0000000..05ce2d9 --- /dev/null +++ b/pkg/uploader/downloader_test.go @@ -0,0 +1,20 @@ +package uploader_test + +import ( + "fmt" + "github.com/mylxsw/aidea-server/pkg/uploader" + "github.com/mylxsw/go-utils/must" + "testing" +) + +func TestQueryImageInfo(t *testing.T) { + info, err := uploader.QueryImageInfo("https://ssl.aicode.cc/ai-server/24/20230811/aigc14995226-1db0-ea85-6f1e-933d19ed01d6.png") + must.NoError(err) + + t.Log(info) +} + +func TestRemoveImageFilter(t *testing.T) { + fmt.Println(uploader.RemoveImageFilter("https://ssl.aicode.cc/ai-server/24/20230811/aigc14995226-1db0-ea85-6f1e-933d19ed01d6.png")) + fmt.Println(uploader.RemoveImageFilter("https://ssl.aicode.cc/ai-server/24/20230811/aigc14995226-1db0-ea85-6f1e-933d19ed01d6.png-thumb")) +} diff --git a/server/controllers/v2/creative-island.go b/server/controllers/v2/creative-island.go index c60dffc..a43f204 100644 --- a/server/controllers/v2/creative-island.go +++ b/server/controllers/v2/creative-island.go @@ -1311,7 +1311,21 @@ func (ctl *CreativeIslandController) ImageToVideo(ctx context.Context, webCtx we return webCtx.JSONError("invalid image", http.StatusBadRequest) } - image = uploader.BuildImageURLWithFilter(image, "resize1024x576", ctl.conf.StorageDomain) + imageFilter := "resize1024x576" + + // 查询图片信息 + info, err := uploader.QueryImageInfo(image) + if err == nil { + if info.Width == info.Height { + imageFilter = "resize768x768" + } else if info.Width > info.Height { + imageFilter = "resize1024x576" + } else { + imageFilter = "resize576x1024" + } + } + + image = uploader.BuildImageURLWithFilter(image, imageFilter, ctl.conf.StorageDomain) // 检查用户是否有足够的智慧果 quota, err := ctl.userSvc.UserQuota(ctx, user.ID) From d74720cdcdf7e7db4211e603fefe828442046a52 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Thu, 28 Dec 2023 17:26:13 +0800 Subject: [PATCH 10/13] =?UTF-8?q?=E8=A7=86=E9=A2=91=E5=A2=9E=E5=8A=A0=20wi?= =?UTF-8?q?dth/height=20=E5=B1=9E=E6=80=A7=E5=AD=98=E5=82=A8=EF=BC=8C?= =?UTF-8?q?=E6=96=B9=E4=BE=BF=E5=AE=A2=E6=88=B7=E7=AB=AF=E9=80=89=E6=8B=A9?= =?UTF-8?q?=E5=90=88=E9=80=82=E7=9A=84=E5=AE=BD=E9=AB=98=E6=AF=94=E5=B1=95?= =?UTF-8?q?=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/queue/image_to_video.go | 45 ++++++++++++------------ internal/queue/provider.go | 6 ++-- pkg/ai/stabilityai/video.go | 6 ++-- pkg/repo/creative.go | 2 ++ server/controllers/tasks.go | 32 +++++++++++------ server/controllers/v2/creative-island.go | 45 ++++++++++++++++++------ 6 files changed, 86 insertions(+), 50 deletions(-) diff --git a/internal/queue/image_to_video.go b/internal/queue/image_to_video.go index 036fd00..badfdb7 100644 --- a/internal/queue/image_to_video.go +++ b/internal/queue/image_to_video.go @@ -24,6 +24,19 @@ type ImageToVideoCompletionPayload struct { Seed int64 `json:"seed,omitempty"` Image string `json:"image,omitempty"` + Width int64 `json:"width,omitempty"` + Height int64 `json:"height,omitempty"` + + // CfgScale How strongly the video sticks to the original image. + // Use lower values to allow the model more freedom to make changes and higher values to correct motion distortions. + // number [ 0 .. 10 ], default 2.5 + CfgScale float64 `json:"cfg_scale,omitempty"` + // MotionBucketID Lower values generally result in less motion in the output video, + // while higher values generally result in more motion. + // This parameter corresponds to the motion_bucket_id parameter from the paper. + // number [ 1 .. 255 ], default 40 + MotionBucketID int `json:"motion_bucket_id,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` } @@ -81,15 +94,6 @@ func (p ImageToVideoPendingTaskPayload) GetModel() string { return p.Payload.GetModel() } -type ImageToVideoResponse interface { - GetID() string - GetState() string - IsFinished() bool - IsProcessing() bool - UploadResources(ctx context.Context, up *uploader.Uploader, uid int64) ([]string, error) - GetImages() []string -} - func BuildImageToVideoCompletionHandler( client *stabilityai.StabilityAI, rep *repo.Repository, @@ -142,8 +146,10 @@ func BuildImageToVideoCompletionHandler( defer os.Remove(targetImage) req := stabilityai.VideoRequest{ - ImagePath: targetImage, - Seed: int(payload.Seed), + ImagePath: targetImage, + Seed: int(payload.Seed), + CfgScale: payload.CfgScale, + MotionBucketID: payload.MotionBucketID, } resp, err := client.ImageToVideo(ctx, req) if err != nil { @@ -172,7 +178,7 @@ func BuildImageToVideoCompletionHandler( } } -func imageToVideoJobProcesser(que *Queue, client *stabilityai.StabilityAI, up *uploader.Uploader, rep *repo.Repository) PendingTaskHandler { +func imageToVideoJobProcesser(client *stabilityai.StabilityAI, up *uploader.Uploader, rep *repo.Repository) PendingTaskHandler { return func(task *model.QueueTasksPending) (update *repo.PendingTaskUpdate, err error) { var payload ImageToVideoPendingTaskPayload if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil { @@ -234,7 +240,7 @@ func imageToVideoJobProcesser(que *Queue, client *stabilityai.StabilityAI, up *u // 任务已经完成,开始处理结果 // 更新创作岛历史记录 - if err := handleImageToVideoTask(que, payload, taskRes, up, rep); err != nil { + if err := handleImageToVideoTask(&payload, taskRes, up, rep); err != nil { log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) return nil, err } @@ -243,17 +249,8 @@ func imageToVideoJobProcesser(que *Queue, client *stabilityai.StabilityAI, up *u } } -type ImageToVideoTaskPayload interface { - GetID() string - GetUID() int64 - GetQuota() int64 - GetModel() string - GetImage() string -} - func handleImageToVideoTask( - que *Queue, - payload ImageToVideoTaskPayload, + payload *ImageToVideoPendingTaskPayload, tasks *stabilityai.VideoResponse, up *uploader.Uploader, rep *repo.Repository, @@ -314,6 +311,8 @@ func handleImageToVideoTask( OriginImage: payload.GetImage(), Resources: resources, ValidBefore: time.Now().Add(7 * 24 * time.Hour), + Width: payload.Payload.Width, + Height: payload.Payload.Height, }, ) } diff --git a/internal/queue/provider.go b/internal/queue/provider.go index 6bbf671..c92fa2d 100644 --- a/internal/queue/provider.go +++ b/internal/queue/provider.go @@ -52,7 +52,7 @@ func (Provider) Boot(app infra.Resolver) { manager.Register(TypeLeapAICompletion, leapAsyncJobProcesser(leapClient, up, rep)) manager.Register(TypeFromStonCompletion, fromStonAsyncJobProcesser(queue, fromstonClient, up, rep)) manager.Register(TypeDashscopeImageCompletion, dashscopeImageAsyncJobProcesser(queue, dashscopeClient, up, rep)) - manager.Register(TypeImageToVideoCompletion, imageToVideoJobProcesser(queue, stabilityClient, up, rep)) + manager.Register(TypeImageToVideoCompletion, imageToVideoJobProcesser(stabilityClient, up, rep)) // 注册创作岛更新后,自动释放冻结的智慧果任务 rep.Creative.RegisterRecordStatusUpdateCallback(func(taskID string, userID int64, status repo.CreativeStatus) { @@ -108,7 +108,7 @@ func ResolveTaskType(category, model string) string { if model == "stability-image-to-video" { return TypeImageToVideoCompletion } - + return TypeStabilityAICompletion case "fromston": return TypeFromStonCompletion @@ -128,6 +128,8 @@ type CompletionResult struct { Resources []string `json:"resources"` OriginImage string `json:"origin_image,omitempty"` ValidBefore time.Time `json:"valid_before,omitempty"` + Width int64 `json:"width,omitempty"` + Height int64 `json:"height,omitempty"` } // ErrorResult 任务失败后的结果 diff --git a/pkg/ai/stabilityai/video.go b/pkg/ai/stabilityai/video.go index 7a92928..275133b 100644 --- a/pkg/ai/stabilityai/video.go +++ b/pkg/ai/stabilityai/video.go @@ -36,7 +36,7 @@ type VideoRequest struct { // CfgScale How strongly the video sticks to the original image. // Use lower values to allow the model more freedom to make changes and higher values to correct motion distortions. // number [ 0 .. 10 ], default 2.5 - CfgScale int `json:"cfg_scale,omitempty"` + CfgScale float64 `json:"cfg_scale,omitempty"` // MotionBucketID Lower values generally result in less motion in the output video, // while higher values generally result in more motion. // This parameter corresponds to the motion_bucket_id parameter from the paper. @@ -108,7 +108,7 @@ func (ai *StabilityAI) ImageToVideo(ctx context.Context, imageToVideoReq VideoRe h.Set("Content-Type", "image/png") imageWriter, _ := writer.CreatePart(h) - + imageFile, imageErr := os.Open(imageToVideoReq.ImagePath) if imageErr != nil { _ = writer.Close() @@ -122,7 +122,7 @@ func (ai *StabilityAI) ImageToVideo(ctx context.Context, imageToVideoReq VideoRe } if imageToVideoReq.CfgScale > 0 { - _ = writer.WriteField("cfg_scale", strconv.Itoa(imageToVideoReq.CfgScale)) + _ = writer.WriteField("cfg_scale", strconv.FormatFloat(imageToVideoReq.CfgScale, 'f', 2, 64)) } if imageToVideoReq.MotionBucketID > 0 { diff --git a/pkg/repo/creative.go b/pkg/repo/creative.go index 40f156e..68d25c6 100644 --- a/pkg/repo/creative.go +++ b/pkg/repo/creative.go @@ -676,6 +676,8 @@ type CreativeRecordArguments struct { Seed int64 `json:"seed,omitempty"` Text string `json:"text,omitempty"` ArtisticType string `json:"artistic_type,omitempty"` + CfgScale float64 `json:"cfg_scale,omitempty"` + MotionBucketID int `json:"motion_bucket_id,omitempty"` } func (arg CreativeRecordArguments) ToGalleryMeta() GalleryMeta { diff --git a/server/controllers/tasks.go b/server/controllers/tasks.go index 3052e60..91a05d6 100644 --- a/server/controllers/tasks.go +++ b/server/controllers/tasks.go @@ -3,7 +3,8 @@ package controllers import ( "context" "encoding/json" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" + "errors" + "github.com/mylxsw/aidea-server/pkg/repo" "github.com/mylxsw/aidea-server/pkg/youdao" "net/http" "time" @@ -19,7 +20,7 @@ import ( type TaskController struct { conf *config.Config - queueRepo *repo2.QueueRepo `autowire:"@"` + queueRepo *repo.QueueRepo `autowire:"@"` translater youdao.Translater `autowire:"@"` } @@ -31,16 +32,16 @@ func NewTaskController(resolver infra.Resolver, conf *config.Config) web.Control func (ctl *TaskController) Register(router web.Router) { router.Group("/tasks", func(router web.Router) { - router.Get("/{task_id}/status", ctl.taskStatus) + router.Get("/{task_id}/status", ctl.TaskStatus) }) } -// taskStatus 任务状态查询 -func (ctl *TaskController) taskStatus(ctx context.Context, webCtx web.Context, user *auth.User) web.Response { +// TaskStatus 任务状态查询 +func (ctl *TaskController) TaskStatus(ctx context.Context, webCtx web.Context, user *auth.User) web.Response { taskID := webCtx.PathVar("task_id") task, err := ctl.queueRepo.Task(ctx, taskID) if err != nil { - if err == repo2.ErrNotFound { + if errors.Is(err, repo.ErrNotFound) { return webCtx.JSONError(common.ErrNotFound, http.StatusNotFound) } return webCtx.JSONError(common.Text(webCtx, ctl.translater, common.ErrInternalError), http.StatusInternalServerError) @@ -50,22 +51,31 @@ func (ctl *TaskController) taskStatus(ctx context.Context, webCtx web.Context, u return webCtx.JSONError(common.Text(webCtx, ctl.translater, common.ErrNotFound), http.StatusNotFound) } - if repo2.QueueTaskStatus(task.Status) == repo2.QueueTaskStatusSuccess { + if repo.QueueTaskStatus(task.Status) == repo.QueueTaskStatusSuccess { var taskResult queue.CompletionResult if err := json.Unmarshal([]byte(task.Result), &taskResult); err != nil { log.With(task).Errorf("unmarshal task result failed: %v", err) return webCtx.JSONError(common.Text(webCtx, ctl.translater, common.ErrInternalError), http.StatusInternalServerError) } - - return webCtx.JSON(web.M{ + res := web.M{ "status": task.Status, "origin_image": taskResult.OriginImage, "resources": taskResult.Resources, "valid_before": taskResult.ValidBefore.Format(time.RFC3339), - }) + } + + if taskResult.Width > 0 { + res["width"] = taskResult.Width + } + + if taskResult.Height > 0 { + res["height"] = taskResult.Height + } + + return webCtx.JSON(res) } - if repo2.QueueTaskStatus(task.Status) == repo2.QueueTaskStatusFailed { + if repo.QueueTaskStatus(task.Status) == repo.QueueTaskStatusFailed { var errResult queue.ErrorResult if err := json.Unmarshal([]byte(task.Result), &errResult); err != nil { log.With(task).Errorf("unmarshal task result failed: %v", err) diff --git a/server/controllers/v2/creative-island.go b/server/controllers/v2/creative-island.go index a43f204..5cfd9b8 100644 --- a/server/controllers/v2/creative-island.go +++ b/server/controllers/v2/creative-island.go @@ -1311,21 +1311,21 @@ func (ctl *CreativeIslandController) ImageToVideo(ctx context.Context, webCtx we return webCtx.JSONError("invalid image", http.StatusBadRequest) } - imageFilter := "resize1024x576" + width, height := int64(1024), int64(576) // 查询图片信息 info, err := uploader.QueryImageInfo(image) if err == nil { if info.Width == info.Height { - imageFilter = "resize768x768" + width, height = 768, 768 } else if info.Width > info.Height { - imageFilter = "resize1024x576" + width, height = 1024, 576 } else { - imageFilter = "resize576x1024" + width, height = 576, 1024 } } - image = uploader.BuildImageURLWithFilter(image, imageFilter, ctl.conf.StorageDomain) + image = uploader.BuildImageURLWithFilter(image, fmt.Sprintf("resize%dx%d", width, height), ctl.conf.StorageDomain) // 检查用户是否有足够的智慧果 quota, err := ctl.userSvc.UserQuota(ctx, user.ID) @@ -1344,12 +1344,30 @@ func (ctl *CreativeIslandController) ImageToVideo(ctx context.Context, webCtx we seed = -1 } + // How strongly the video sticks to the original image. + // Use lower values to allow the model more freedom to make changes and higher values to correct motion distortions. + cfgScale := webCtx.Float64Input("cfg_scale", 2.5) + if cfgScale < 1 || cfgScale > 10 { + cfgScale = 2.5 + } + + // Lower values generally result in less motion in the output video, + // while higher values generally result in more motion + motionBucketID := webCtx.IntInput("motion_bucket_id", 40) + if motionBucketID < 1 || motionBucketID > 255 { + motionBucketID = 40 + } + req := queue.ImageToVideoCompletionPayload{ - Quota: quotaConsume, - CreatedAt: time.Now(), - Image: image, - UID: user.ID, - Seed: seed, + Quota: quotaConsume, + CreatedAt: time.Now(), + Image: image, + UID: user.ID, + Seed: seed, + CfgScale: cfgScale, + MotionBucketID: motionBucketID, + Width: width, + Height: height, } // 加入异步任务队列 @@ -1377,7 +1395,12 @@ func (ctl *CreativeIslandController) ImageToVideo(ctx context.Context, webCtx we } arg := repo.CreativeRecordArguments{ - Image: image, + Image: image, + Width: width, + Height: height, + Seed: seed, + MotionBucketID: motionBucketID, + CfgScale: cfgScale, } // 保存历史记录 From efcbd5737cd364a3e0e02df49c2f6d4f36d91cb9 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Fri, 29 Dec 2023 00:24:55 +0800 Subject: [PATCH 11/13] =?UTF-8?q?6pen=20=E6=8E=A5=E5=8F=A3=E7=94=9F?= =?UTF-8?q?=E6=88=90=E5=9B=BE=E7=89=87=E7=9B=B4=E6=8E=A5=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E4=B8=83=E7=89=9B=E4=BA=91=EF=BC=8C=E4=B8=8D?= =?UTF-8?q?=E5=86=8D=E5=BC=82=E6=AD=A5=E5=88=9B=E5=BB=BA=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/queue/consumer/provider.go | 2 +- internal/queue/fromston.go | 125 +++++++++++++---------- internal/queue/image_downloader.go | 14 ++- internal/queue/provider.go | 3 +- server/controllers/v2/creative-island.go | 4 + 5 files changed, 86 insertions(+), 62 deletions(-) diff --git a/internal/queue/consumer/provider.go b/internal/queue/consumer/provider.go index 384f13a..cb5eab3 100644 --- a/internal/queue/consumer/provider.go +++ b/internal/queue/consumer/provider.go @@ -110,7 +110,7 @@ func (p Provider) Boot(resolver infra.Resolver) { mux.HandleFunc(queue.TypeFromStonCompletion, queue.BuildFromStonCompletionHandler(fromstonClient, uploader, rep)) mux.HandleFunc(queue.TypeDashscopeImageCompletion, queue.BuildDashscopeImageCompletionHandler(dashscopeClient, uploader, rep, translater, openaiClient)) mux.HandleFunc(queue.TypeGetimgAICompletion, queue.BuildGetimgAICompletionHandler(getimgaiClient, translater, uploader, rep, openaiClient)) - mux.HandleFunc(queue.TypeImageDownloader, queue.BuildImageDownloaderHandler(uploader, rep)) + mux.HandleFunc(queue.TypeImageDownloader, queue.BuildImageDownloaderHandler(conf, uploader, rep)) mux.HandleFunc(queue.TypeImageUpscale, queue.BuildImageUpscaleHandler(deepaiClient, stabaiClient, uploader, rep)) mux.HandleFunc(queue.TypeImageColorization, queue.BuildImageColorizationHandler(deepaiClient, uploader, rep)) mux.HandleFunc(queue.TypeGroupChat, queue.BuildGroupChatHandler(conf, ct, rep, userSvc)) diff --git a/internal/queue/fromston.go b/internal/queue/fromston.go index 834cb82..1d9eb94 100644 --- a/internal/queue/fromston.go +++ b/internal/queue/fromston.go @@ -5,11 +5,12 @@ import ( "encoding/json" "errors" "fmt" - fromston2 "github.com/mylxsw/aidea-server/pkg/ai/fromston" + "github.com/mylxsw/aidea-server/config" + "github.com/mylxsw/aidea-server/pkg/ai/fromston" "github.com/mylxsw/aidea-server/pkg/misc" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/pkg/repo" "github.com/mylxsw/aidea-server/pkg/repo/model" - uploader2 "github.com/mylxsw/aidea-server/pkg/uploader" + "github.com/mylxsw/aidea-server/pkg/uploader" "os" "strconv" "strings" @@ -105,11 +106,11 @@ type FromStonResponse interface { GetState() string IsFinished() bool IsProcessing() bool - UploadResources(ctx context.Context, up *uploader2.Uploader, uid int64) ([]string, error) + UploadResources(ctx context.Context, up *uploader.Uploader, uid int64) ([]string, error) GetImages() []string } -func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Uploader, rep *repo2.Repository) TaskHandler { +func BuildFromStonCompletionHandler(client *fromston.Fromston, up *uploader.Uploader, rep *repo.Repository) TaskHandler { return func(ctx context.Context, task *asynq.Task) (err error) { var payload FromStonCompletionPayload if err := json.Unmarshal(task.Payload(), &payload); err != nil { @@ -117,7 +118,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up } if payload.CreatedAt.Add(5 * time.Minute).Before(time.Now()) { - rep.Queue.Update(context.TODO(), payload.GetID(), repo2.QueueTaskStatusFailed, ErrorResult{Errors: []string{"任务处理超时"}}) + rep.Queue.Update(context.TODO(), payload.GetID(), repo.QueueTaskStatusFailed, ErrorResult{Errors: []string{"任务处理超时"}}) log.WithFields(log.Fields{"payload": payload}).Errorf("task expired") return nil } @@ -128,9 +129,9 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up err = err2.(error) // 更新创作岛历史记录 - if err := rep.Creative.UpdateRecordByTaskID(ctx, payload.GetUID(), payload.GetID(), repo2.CreativeRecordUpdateRequest{ + if err := rep.Creative.UpdateRecordByTaskID(ctx, payload.GetUID(), payload.GetID(), repo.CreativeRecordUpdateRequest{ Answer: err.Error(), - Status: repo2.CreativeStatusFailed, + Status: repo.CreativeStatusFailed, }); err != nil { log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) } @@ -140,7 +141,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up if err := rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusFailed, + repo.QueueTaskStatusFailed, ErrorResult{ Errors: []string{err.Error()}, }, @@ -155,7 +156,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up // 如果本地下载失败,则直接发送远程图片地址到 Leap localImagePath := payload.Image if payload.Image != "" { - imagePath, err := uploader2.DownloadRemoteFile(ctx, payload.Image) + imagePath, err := uploader.DownloadRemoteFile(ctx, payload.Image) if err != nil { log.WithFields(log.Fields{ "payload": payload, @@ -203,10 +204,10 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up modelType, modelIdStr := ms[0], ms[1] - var resp *fromston2.GenImageResponseData + var resp *fromston.GenImageResponseData if modelType == "custom" { // 自己训练的模型 - req := fromston2.GenImageCustomRequest{ + req := fromston.GenImageCustomRequest{ Prompt: prompt, FillPrompt: int64(ternary.If(payload.AIRewrite, 1, 0)), Width: payload.Width, @@ -214,7 +215,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up RefImg: localImagePath, ModelID: modelIdStr, Multiply: payload.ImageCount, - Addition: &fromston2.GenImageAddition{ + Addition: &fromston.GenImageAddition{ ImgFmt: "jpg", NegativePrompt: negativePrompt, Strength: payload.ImageStrength, @@ -233,7 +234,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up panic(fmt.Errorf("invalid model: %s", payload.Model)) } - req := fromston2.GenImageRequest{ + req := fromston.GenImageRequest{ Prompt: prompt, FillPrompt: int64(ternary.If(payload.AIRewrite, 1, 0)), Width: payload.Width, @@ -242,7 +243,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up ModelType: ms[0], ModelID: int64(modelId), Multiply: payload.ImageCount, - Addition: &fromston2.GenImageAddition{ + Addition: &fromston.GenImageAddition{ ImgFmt: "jpg", NegativePrompt: negativePrompt, Strength: payload.ImageStrength, @@ -267,7 +268,7 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up } if prompt != payload.Prompt || negativePrompt != payload.NegativePrompt { - argUpdate := repo2.CreativeRecordUpdateExtArgs{} + argUpdate := repo.CreativeRecordUpdateExtArgs{} if prompt != payload.Prompt { argUpdate.RealPrompt = prompt } @@ -281,12 +282,12 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up } } - if err := rep.Queue.CreatePendingTask(ctx, &repo2.PendingTask{ + if err := rep.Queue.CreatePendingTask(ctx, &repo.PendingTask{ TaskID: payload.GetID(), TaskType: TypeFromStonCompletion, NextExecuteAt: time.Now().Add(time.Duration(estimates) * time.Second), DeadlineAt: time.Now().Add(30 * time.Minute), - Status: repo2.PendingTaskStatusProcessing, + Status: repo.PendingTaskStatusProcessing, Payload: FromStonPendingTaskPayload{FromstonTaskIDs: resp.IDs, Payload: payload, ModelType: modelType}, }); err != nil { log.WithFields(log.Fields{"payload": payload}).Errorf("create pending task failed: %s", err) @@ -296,20 +297,20 @@ func BuildFromStonCompletionHandler(client *fromston2.Fromston, up *uploader2.Up return rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusRunning, + repo.QueueTaskStatusRunning, nil, ) } } -func fromStonAsyncJobProcesser(que *Queue, client *fromston2.Fromston, up *uploader2.Uploader, rep *repo2.Repository) PendingTaskHandler { - return func(task *model.QueueTasksPending) (update *repo2.PendingTaskUpdate, err error) { +func fromStonAsyncJobProcesser(conf *config.Config, que *Queue, client *fromston.Fromston, up *uploader.Uploader, rep *repo.Repository) PendingTaskHandler { + return func(task *model.QueueTasksPending) (update *repo.PendingTaskUpdate, err error) { var payload FromStonPendingTaskPayload if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil { return nil, err } - var tasks []fromston2.Task + var tasks []fromston.Task if payload.ModelType == "custom" { // 自己训练的模型任务查询 for _, id := range payload.FromstonTaskIDs { @@ -325,9 +326,9 @@ func fromStonAsyncJobProcesser(que *Queue, client *fromston2.Fromston, up *uploa tasks, err = client.QueryTasks(context.TODO(), payload.FromstonTaskIDs) if err != nil { log.With(payload).Errorf("query fromston job result failed: %v", err) - return &repo2.PendingTaskUpdate{ + return &repo.PendingTaskUpdate{ NextExecuteAt: time.Now().Add(5 * time.Second), - Status: repo2.PendingTaskStatusProcessing, + Status: repo.PendingTaskStatusProcessing, ExecuteTimes: task.ExecuteTimes + 1, }, nil } @@ -339,21 +340,21 @@ func fromStonAsyncJobProcesser(que *Queue, client *fromston2.Fromston, up *uploa err = err2.(error) // 更新创作岛历史记录 - if err := rep.Creative.UpdateRecordByTaskID(context.TODO(), payload.Payload.GetUID(), payload.Payload.GetID(), repo2.CreativeRecordUpdateRequest{ + if err := rep.Creative.UpdateRecordByTaskID(context.TODO(), payload.Payload.GetUID(), payload.Payload.GetID(), repo.CreativeRecordUpdateRequest{ Answer: err.Error(), - Status: repo2.CreativeStatusFailed, + Status: repo.CreativeStatusFailed, }); err != nil { log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) } - update = &repo2.PendingTaskUpdate{Status: repo2.PendingTaskStatusFailed} + update = &repo.PendingTaskUpdate{Status: repo.PendingTaskStatusFailed} } if err != nil { if err := rep.Queue.Update( context.TODO(), payload.Payload.GetID(), - repo2.QueueTaskStatusFailed, + repo.QueueTaskStatusFailed, ErrorResult{ Errors: []string{err.Error()}, }, @@ -363,28 +364,28 @@ func fromStonAsyncJobProcesser(que *Queue, client *fromston2.Fromston, up *uploa } }() - unfinishedTask := array.Filter(tasks, func(item fromston2.Task, _ int) bool { + unfinishedTask := array.Filter(tasks, func(item fromston.Task, _ int) bool { return array.In(item.State, []string{"in_wait", "in_create"}) }) if len(unfinishedTask) > 0 { - return &repo2.PendingTaskUpdate{ + return &repo.PendingTaskUpdate{ NextExecuteAt: time.Now().Add(5 * time.Second), - Status: repo2.PendingTaskStatusProcessing, + Status: repo.PendingTaskStatusProcessing, ExecuteTimes: task.ExecuteTimes + 1, }, nil } // 任务已经完成,开始处理结果 - successTasks := array.Filter(tasks, func(item fromston2.Task, _ int) bool { + successTasks := array.Filter(tasks, func(item fromston.Task, _ int) bool { return item.State == "success" }) if len(successTasks) == 0 { log.WithFields(log.Fields{"payload": payload, "tasks": tasks}).Errorf("no success task found") - failedTasks := array.Filter(tasks, func(item fromston2.Task, _ int) bool { return item.State == "fail" }) + failedTasks := array.Filter(tasks, func(item fromston.Task, _ int) bool { return item.State == "fail" }) if len(failedTasks) > 0 { - panic(errors.New(strings.Join(array.Map(failedTasks, func(t fromston2.Task, _ int) string { + panic(errors.New(strings.Join(array.Map(failedTasks, func(t fromston.Task, _ int) string { switch t.FailReson { case "NSFW": return "检测到违规内容,请修改后重试" @@ -399,12 +400,12 @@ func fromStonAsyncJobProcesser(que *Queue, client *fromston2.Fromston, up *uploa } // 更新创作岛历史记录 - if err := handleFromstonTask(que, payload, successTasks, up, rep); err != nil { + if err := handleFromstonTask(conf, que, payload, successTasks, up, rep); err != nil { log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) return nil, err } - return &repo2.PendingTaskUpdate{Status: repo2.PendingTaskStatusSuccess}, nil + return &repo.PendingTaskUpdate{Status: repo.PendingTaskStatusSuccess}, nil } } @@ -417,14 +418,23 @@ type FromstonTaskPayload interface { } func handleFromstonTask( + conf *config.Config, que *Queue, payload FromstonTaskPayload, - tasks []fromston2.Task, - up *uploader2.Uploader, - rep *repo2.Repository, + tasks []fromston.Task, + up *uploader.Uploader, + rep *repo.Repository, ) error { - resources := array.Map(tasks, func(item fromston2.Task, _ int) string { - return item.GenImg + resources := array.Map(tasks, func(item fromston.Task, _ int) string { + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + defer cancel() + + res, err := up.UploadRemoteFile(ctx, item.GenImg, int(payload.GetUID()), uploader.DefaultUploadExpireAfterDays, "png", false) + if err != nil { + return item.GenImg + } + + return res }) resources = array.Filter(resources, func(item string, _ int) bool { return item != "" }) @@ -446,10 +456,10 @@ func handleFromstonTask( // quotaConsumed := coins.GetFromstonImageCoins(payload.GetModel(), isCsMode, width, height) * int64(len(resources)) quotaConsumed := int64(coins.GetUnifiedImageGenCoins("") * len(resources)) - req := repo2.CreativeRecordUpdateRequest{ + req := repo.CreativeRecordUpdateRequest{ Answer: string(retJson), QuotaUsed: quotaConsumed, - Status: repo2.CreativeStatusSuccess, + Status: repo.CreativeStatusSuccess, } if err := rep.Creative.UpdateRecordByTaskID(context.TODO(), payload.GetUID(), payload.GetID(), req); err != nil { log.WithFields(log.Fields{"payload": payload}).Errorf("update creative failed: %s", err) @@ -462,30 +472,33 @@ func handleFromstonTask( context.TODO(), payload.GetUID(), payload.GetQuota(), - repo2.NewQuotaUsedMeta("fromston", modelUsed...), + repo.NewQuotaUsedMeta("fromston", modelUsed...), ); err != nil { log.Errorf("used quota add failed: %s", err) return err } - // 触发文件下载上传七牛云任务 - downloadPayload := ImageDownloaderPayload{ - CreativeHistoryTaskID: payload.GetID(), - UserID: payload.GetUID(), - CreatedAt: time.Now(), - } - downloadTaskID, err := que.Enqueue(&downloadPayload, NewImageDownloaderTask) - if err != nil { - log.WithFields(log.Fields{"payload": payload}).Errorf("enqueue image downloader task failed: %s", err) - } else { - log.WithFields(log.Fields{"payload": payload, "task_id": downloadTaskID}).Debugf("enqueue image downloader task success") + needDownloadResources := array.Filter(resources, func(item string, _ int) bool { return !strings.HasPrefix(item, conf.StorageDomain) }) + if len(needDownloadResources) > 0 { + // 触发文件下载上传七牛云任务 + downloadPayload := ImageDownloaderPayload{ + CreativeHistoryTaskID: payload.GetID(), + UserID: payload.GetUID(), + CreatedAt: time.Now(), + } + downloadTaskID, err := que.Enqueue(&downloadPayload, NewImageDownloaderTask) + if err != nil { + log.WithFields(log.Fields{"payload": payload}).Errorf("enqueue image downloader task failed: %s", err) + } else { + log.WithFields(log.Fields{"payload": payload, "task_id": downloadTaskID}).Debugf("enqueue image downloader task success") + } } // 更新队列任务状态 return rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusSuccess, + repo.QueueTaskStatusSuccess, CompletionResult{ OriginImage: payload.GetImage(), Resources: resources, diff --git a/internal/queue/image_downloader.go b/internal/queue/image_downloader.go index 62293d9..588b737 100644 --- a/internal/queue/image_downloader.go +++ b/internal/queue/image_downloader.go @@ -4,8 +4,10 @@ import ( "context" "encoding/json" "fmt" - repo2 "github.com/mylxsw/aidea-server/pkg/repo" + "github.com/mylxsw/aidea-server/config" + "github.com/mylxsw/aidea-server/pkg/repo" "github.com/mylxsw/aidea-server/pkg/uploader" + "strings" "time" "github.com/hibiken/asynq" @@ -48,7 +50,7 @@ func NewImageDownloaderTask(payload any) *asynq.Task { return asynq.NewTask(TypeImageDownloader, data) } -func BuildImageDownloaderHandler(up *uploader.Uploader, rep *repo2.Repository) TaskHandler { +func BuildImageDownloaderHandler(conf *config.Config, up *uploader.Uploader, rep *repo.Repository) TaskHandler { return func(ctx context.Context, task *asynq.Task) (err error) { var payload ImageDownloaderPayload if err := json.Unmarshal(task.Payload(), &payload); err != nil { @@ -72,7 +74,7 @@ func BuildImageDownloaderHandler(up *uploader.Uploader, rep *repo2.Repository) T if err := rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusFailed, + repo.QueueTaskStatusFailed, ErrorResult{ Errors: []string{err.Error()}, }, @@ -103,6 +105,10 @@ func BuildImageDownloaderHandler(up *uploader.Uploader, rep *repo2.Repository) T } for i, res := range resources { + if strings.HasPrefix(res, conf.StorageDomain) { + continue + } + ret, err := up.UploadRemoteFile(ctx, res, int(payload.UserID), uploader.DefaultUploadExpireAfterDays, "png", false) if err != nil { log.WithFields(log.Fields{ @@ -124,7 +130,7 @@ func BuildImageDownloaderHandler(up *uploader.Uploader, rep *repo2.Repository) T return rep.Queue.Update( context.TODO(), payload.GetID(), - repo2.QueueTaskStatusSuccess, + repo.QueueTaskStatusSuccess, EmptyResult{}, ) } diff --git a/internal/queue/provider.go b/internal/queue/provider.go index c92fa2d..4602759 100644 --- a/internal/queue/provider.go +++ b/internal/queue/provider.go @@ -44,13 +44,14 @@ func (Provider) Boot(app infra.Resolver) { stabilityClient *stabilityai.StabilityAI, up *uploader.Uploader, queue *Queue, + conf *config.Config, rep *repo.Repository, userSvc *service.UserService, rds *redis.Client, ) { // 注册异步 PendingTask 任务处理器 manager.Register(TypeLeapAICompletion, leapAsyncJobProcesser(leapClient, up, rep)) - manager.Register(TypeFromStonCompletion, fromStonAsyncJobProcesser(queue, fromstonClient, up, rep)) + manager.Register(TypeFromStonCompletion, fromStonAsyncJobProcesser(conf, queue, fromstonClient, up, rep)) manager.Register(TypeDashscopeImageCompletion, dashscopeImageAsyncJobProcesser(queue, dashscopeClient, up, rep)) manager.Register(TypeImageToVideoCompletion, imageToVideoJobProcesser(stabilityClient, up, rep)) diff --git a/server/controllers/v2/creative-island.go b/server/controllers/v2/creative-island.go index 5cfd9b8..8e0bcf3 100644 --- a/server/controllers/v2/creative-island.go +++ b/server/controllers/v2/creative-island.go @@ -488,6 +488,10 @@ func (ctl *CreativeIslandController) Histories(ctx context.Context, webCtx web.C item.IslandTitle = "高清修复" case int64(repo.IslandTypeImageColorization): item.IslandTitle = "图片上色" + case int64(repo.IslandTypeVideo): + item.IslandTitle = "图生视频" + case int64(repo.IslandTypeArtisticText): + item.IslandTitle = "艺术字" } // 客户端目前不支持封禁状态展示,这里转换为失败 From 7d40582e2f3e5551a65972de104aaf2869571944 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Fri, 29 Dec 2023 14:09:09 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=88=9B=E4=BD=9C?= =?UTF-8?q?=E5=B2=9B=E6=96=87=E6=A1=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/controllers/v2/creative-island.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/server/controllers/v2/creative-island.go b/server/controllers/v2/creative-island.go index 8e0bcf3..4491266 100644 --- a/server/controllers/v2/creative-island.go +++ b/server/controllers/v2/creative-island.go @@ -129,7 +129,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-text-to-image.jpeg-thumb1000", RouteURI: "/creative-draw/create?mode=text-to-image&id=text-to-image", - Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果%s。", imageCost, imageModelsCostNote), + Note: fmt.Sprintf("根据你的想法生成图片。生成每张图片将消耗 %d 智慧果%s。", imageCost, imageModelsCostNote), Size: SizeLarge, }, } @@ -141,7 +141,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-to-video-dark.jpg-thumb1000", RouteURI: "/creative-draw/create-video", - Note: fmt.Sprintf("图生视频功能能够将静态的图片转换为动态的视频,转换后的视频时长为 2s。生成每个视频将消耗 %d 智慧果。", videoCost), + Note: fmt.Sprintf("基于上传的图片再创作,生成一个时长为 2s 的短视频。生成每个视频将消耗 %d 智慧果。", videoCost), Size: SizeLarge, }) } @@ -153,7 +153,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/art-text-bg.jpg-thumb1000", RouteURI: "/creative-draw/artistic-text?type=text&id=artistic-text", - Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果。", imageCost), + Note: fmt.Sprintf("根据你的想法生成图片,并且在图片中融入你写的文字内容。生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeLarge, }) items = append(items, CreativeIslandItem{ @@ -162,7 +162,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/art-qr-bg.jpg-thumb1000", RouteURI: "/creative-draw/artistic-text?type=qr&id=artistic-qr", - Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果。", imageCost), + Note: fmt.Sprintf("根据你的想法生成图片,并且将链接地址转换为二维码,把图片和二维码融合到一起。生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) } @@ -174,7 +174,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-image-to-image.jpeg-thumb1000", RouteURI: "/creative-draw/create?mode=image-to-image&id=image-to-image", Tag: ternary.If(client != nil && client.IsIOS(), "", "BETA"), - Note: fmt.Sprintf("生成每张图片将消耗 %d 智慧果。", imageCost), + Note: fmt.Sprintf("基于参考图片的轮廓,为你生成一张整体结构类似的图片。生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) @@ -185,7 +185,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/super-res.jpeg-thumb1000", RouteURI: "/creative-draw/create-upscale", - Note: fmt.Sprintf("图片的高清修复功能能够把低分辨率的照片升级到高分辨率,让图片的清晰度得到明显提升。\n生成每张图片将消耗 %d 智慧果。", imageCost), + Note: fmt.Sprintf("将低分辨率的照片升级到高分辨率,让图片的清晰度得到明显提升。\n生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) @@ -195,7 +195,7 @@ func (ctl *CreativeIslandController) Items(ctx context.Context, webCtx web.Conte TitleColor: "FFFFFFFF", PreviewImage: "https://ssl.aicode.cc/ai-server/assets/background/image-colorizev2.jpeg-thumb1000", RouteURI: "/creative-draw/create-colorize", - Note: fmt.Sprintf("图片上色功能能够把黑白照片变成彩色照片,让照片的色彩更加丰富。\n生成每张图片将消耗 %d 智慧果。", imageCost), + Note: fmt.Sprintf("将黑白照片变成彩色照片,让照片的色彩更加丰富。\n生成每张图片将消耗 %d 智慧果。", imageCost), Size: SizeMedium, }) } From 1d550f264ad9ebd17bab59aae7e4172573a272b0 Mon Sep 17 00:00:00 2001 From: mylxsw Date: Thu, 4 Jan 2024 11:49:46 +0800 Subject: [PATCH 13/13] =?UTF-8?q?bugfix=20=E5=88=9D=E5=A7=8B=E5=8C=96?= =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E5=90=8E=EF=BC=8C=E5=88=9B=E5=BB=BA=E7=9A=84?= =?UTF-8?q?=E7=AC=AC=E4=B8=80=E4=B8=AA=E6=95=B0=E5=AD=97=E4=BA=BA=E7=94=B1?= =?UTF-8?q?=E4=BA=8E=20ID=3D1=EF=BC=8C=E8=B7=9F=E8=81=8A=E4=B8=80=E8=81=8A?= =?UTF-8?q?=E9=87=8D=E5=90=88=E4=BA=86=EF=BC=8C=E5=AF=BC=E8=87=B4=E5=AE=A2?= =?UTF-8?q?=E6=88=B7=E7=AB=AF=E5=86=85=E5=AE=B9=E6=98=BE=E7=A4=BA=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- migrate/data/20231129_ddl.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrate/data/20231129_ddl.go b/migrate/data/20231129_ddl.go index 5dd445e..906b53b 100644 --- a/migrate/data/20231129_ddl.go +++ b/migrate/data/20231129_ddl.go @@ -468,7 +468,7 @@ func Migrate20231129DDL(m *migrate.Manager) { room_type TINYINT NULL COMMENT '房间类型:1-系统预设 2-自定义', init_message TEXT NULL COMMENT '初始化消息', avatar_url VARCHAR(255) NULL COMMENT ' 头像地址' -) CHARSET = utf8mb4 COLLATE = utf8mb4_general_ci`, +) CHARSET = utf8mb4 COLLATE = utf8mb4_general_ci AUTO_INCREMENT = 2`, `CREATE INDEX rooms_user_id ON rooms (user_id)`, } })