Skip to content

Commit 20677ba

Browse files
committed
refactor for rag
1 parent 28ce7ab commit 20677ba

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# Output of the go coverage tool, specifically when used with LiteIDE
1919
*.out
20+
*.out.*
2021

2122
# Dependency directories (remove the comment below to include it)
2223
# vendor/

rag.go

+26-13
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,39 @@ import (
44
"context"
55
_ "embed"
66
"fmt"
7+
"os"
78
"time"
89

910
"github.com/philippgille/chromem-go"
1011
"github.com/sashabaranov/go-openai"
1112
)
1213

14+
type OpenAIConfig struct {
15+
EmbedModel string
16+
Endpoint string
17+
LLMModel string
18+
Token string
19+
}
20+
1321
type RAG struct {
1422
db *chromem.DB
1523

16-
embedModel string
17-
endpoint string
18-
llmModel string
19-
token string
24+
config OpenAIConfig
2025
embeddingFunc chromem.EmbeddingFunc
2126
}
2227

23-
func NewRAG(filename string, embedModel string, llmModel string, endpoint string, token string) (*RAG, error) {
28+
func NewRAG(filename string, config *OpenAIConfig) (*RAG, error) {
29+
if config == nil {
30+
config = &OpenAIConfig{
31+
// https://platform.openai.com/docs/guides/embeddings#embedding-models
32+
EmbedModel: "text-embedding-3-small",
33+
Endpoint: "https://api.openai.com/v1",
34+
// https://platform.openai.com/docs/models
35+
LLMModel: "gpt-4o-mini",
36+
Token: os.Getenv("OPENAI_API_KEY"),
37+
}
38+
}
39+
2440
db := chromem.NewDB()
2541

2642
if filename != ":memory:" {
@@ -35,11 +51,8 @@ func NewRAG(filename string, embedModel string, llmModel string, endpoint string
3551
return &RAG{
3652
db: db,
3753

38-
embedModel: embedModel,
39-
endpoint: endpoint,
40-
llmModel: llmModel,
41-
token: token,
42-
embeddingFunc: chromem.NewEmbeddingFuncOpenAICompat(endpoint, token, embedModel, nil),
54+
config: *config,
55+
embeddingFunc: chromem.NewEmbeddingFuncOpenAICompat(config.Endpoint, config.Token, config.EmbedModel, nil),
4356
}, nil
4457
}
4558

@@ -89,8 +102,8 @@ func (r *RAG) Ask(query string) (string, error) {
89102
return "", fmt.Errorf("failed to search: %w", err)
90103
}
91104

92-
config := openai.DefaultConfig(r.token)
93-
config.BaseURL = r.endpoint
105+
config := openai.DefaultConfig(r.config.Token)
106+
config.BaseURL = r.config.Endpoint
94107
client := openai.NewClientWithConfig(config)
95108

96109
userPrompt := "Query: " + query + "\n\nDocuments:\n\n"
@@ -103,7 +116,7 @@ func (r *RAG) Ask(query string) (string, error) {
103116
response, err := client.CreateChatCompletion(
104117
context.Background(),
105118
openai.ChatCompletionRequest{
106-
Model: r.llmModel,
119+
Model: r.config.LLMModel,
107120
Messages: []openai.ChatCompletionMessage{
108121
{
109122
Role: "system",

rag_test.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@ The Internet has transformed:
7272

7373
var _ = FDescribe("RAG", func() {
7474
It("adds documents", func() {
75-
rag, err := builder.NewRAG(":memory:", "nomic-embed-text", "llama3.2", "http://localhost:11434/v1", "")
75+
config := &builder.OpenAIConfig{
76+
EmbedModel: "nomic-embed-text",
77+
Endpoint: "http://localhost:11434/v1",
78+
LLMModel: "llama3.2",
79+
Token: "",
80+
}
81+
rag, err := builder.NewRAG(":memory:", config)
7682
Expect(err).NotTo(HaveOccurred())
7783

7884
for index, doc := range []string{doc1, doc2, doc3} {
@@ -86,6 +92,6 @@ var _ = FDescribe("RAG", func() {
8692

8793
answer, err := rag.Ask("What is the largest planet?")
8894
Expect(err).NotTo(HaveOccurred())
89-
fmt.Println(answer)
95+
fmt.Println("answer:" + answer)
9096
})
9197
})

0 commit comments

Comments
 (0)