Skip to content

Commit 01eb9cc

Browse files
authored
[Fleet] Replace findByApiKeyID with a get call when possible (#3611)
1 parent 648d40b commit 01eb9cc

File tree

8 files changed

+75
-7
lines changed

8 files changed

+75
-7
lines changed

internal/pkg/api/auth.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,13 @@ func authAgent(r *http.Request, id *string, bulker bulk.Bulk, c cache.Cache) (*m
116116
Msg("authApiKey slow")
117117
}
118118

119-
agent, err := findAgentByAPIKeyID(ctx, bulker, key.ID)
119+
var agent *model.Agent
120+
// If we have the agentID retrieve the agent document with a get (more performant) instead of triggering a search
121+
if id != nil {
122+
agent, err = getAgentAndVerifyAPIKeyID(ctx, bulker, *id, key.ID)
123+
} else {
124+
agent, err = findAgentByAPIKeyID(ctx, bulker, key.ID)
125+
}
120126
if err != nil {
121127
return nil, err
122128
}

internal/pkg/api/handleCheckin.go

+19
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,25 @@ func processPolicy(ctx context.Context, zlog zerolog.Logger, bulker bulk.Bulk, a
855855
return &resp, nil
856856
}
857857

858+
func getAgentAndVerifyAPIKeyID(ctx context.Context, bulker bulk.Bulk, agentID string, apiKeyID string) (*model.Agent, error) {
859+
span, ctx := apm.StartSpan(ctx, "getAgentAndVerifyAPIKeyID", "read")
860+
defer span.End()
861+
agent, err := dl.GetAgent(ctx, bulker, agentID)
862+
if err != nil {
863+
if errors.Is(err, dl.ErrNotFound) {
864+
err = ErrAgentNotFound
865+
} else {
866+
err = fmt.Errorf("GetAgent: %w", err)
867+
}
868+
}
869+
870+
if agent.AccessAPIKeyID != apiKeyID {
871+
err = fmt.Errorf("invalid API Key ID %w", ErrAgentIdentity)
872+
}
873+
874+
return &agent, err
875+
}
876+
858877
func findAgentByAPIKeyID(ctx context.Context, bulker bulk.Bulk, id string) (*model.Agent, error) {
859878
span, ctx := apm.StartSpan(ctx, "findAgentByID", "search")
860879
defer span.End()

internal/pkg/bulk/engine.go

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type Bulk interface {
4747
// Synchronous operations run in the bulk engine
4848
Create(ctx context.Context, index, id string, body []byte, opts ...Opt) (string, error)
4949
Read(ctx context.Context, index, id string, opts ...Opt) ([]byte, error)
50+
ReadRaw(ctx context.Context, index, id string, opts ...Opt) (*MgetResponseItem, error)
5051
Update(ctx context.Context, index, id string, body []byte, opts ...Opt) error
5152
Delete(ctx context.Context, index, id string, opts ...Opt) error
5253
Index(ctx context.Context, index, id string, body []byte, opts ...Opt) (string, error)

internal/pkg/bulk/opRead.go

+13-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ const (
2121
rSuffix = "]}"
2222
)
2323

24-
func (b *Bulker) Read(ctx context.Context, index, id string, opts ...Opt) ([]byte, error) {
25-
span, ctx := apm.StartSpan(ctx, "Bulker: read", "bulker")
24+
func (b *Bulker) ReadRaw(ctx context.Context, index, id string, opts ...Opt) (*MgetResponseItem, error) {
25+
span, ctx := apm.StartSpan(ctx, "Bulker: readRaw", "bulker")
2626
defer span.End()
2727
opt := b.parseOpts(append(opts, withAPMLinkedContext(ctx))...)
2828
blk := b.newBlk(ActionRead, opt)
@@ -49,6 +49,17 @@ func (b *Bulker) Read(ctx context.Context, index, id string, opts ...Opt) ([]byt
4949
if !ok {
5050
return nil, fmt.Errorf("unable to cast response to *MgetResponseItem, detected type: %T", resp.data)
5151
}
52+
return r, nil
53+
}
54+
55+
func (b *Bulker) Read(ctx context.Context, index, id string, opts ...Opt) ([]byte, error) {
56+
span, ctx := apm.StartSpan(ctx, "Bulker: read", "bulker")
57+
defer span.End()
58+
r, err := b.ReadRaw(ctx, index, id, opts...)
59+
if err != nil {
60+
return nil, err
61+
}
62+
5263
return r.Source, nil
5364
}
5465

internal/pkg/bulk/schema.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ type MgetResponse struct {
8080
type MgetResponseItem struct {
8181
// Index string `json:"_index"`
8282
// Type string `json:"_type"`
83-
// DocumentID string `json:"_id"`
84-
// Version int64 `json:"_version"`
85-
// SeqNo int64 `json:"_seq_no"`
83+
DocumentID string `json:"_id"`
84+
Version int64 `json:"_version"`
85+
SeqNo int64 `json:"_seq_no"`
8686
// PrimTerm int64 `json:"_primary_term"`
8787
Found bool `json:"found"`
8888
// Routing string `json:"_routing"`

internal/pkg/dl/agent.go

+26
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@ package dl
66

77
import (
88
"context"
9+
"encoding/json"
10+
"errors"
911
"fmt"
1012

1113
"github.com/elastic/fleet-server/v7/internal/pkg/bulk"
1214
"github.com/elastic/fleet-server/v7/internal/pkg/dsl"
15+
"github.com/elastic/fleet-server/v7/internal/pkg/es"
1316
"github.com/elastic/fleet-server/v7/internal/pkg/model"
1417
)
1518

@@ -39,6 +42,29 @@ func prepareAgentFindByField(field string) *dsl.Tmpl {
3942
return prepareFindByField(field, map[string]interface{}{"version": true})
4043
}
4144

45+
func GetAgent(ctx context.Context, bulker bulk.Bulk, agentID string, opt ...Option) (model.Agent, error) {
46+
o := newOption(FleetAgents, opt...)
47+
var agent model.Agent
48+
data, err := bulker.ReadRaw(ctx, o.indexName, agentID)
49+
if err != nil {
50+
if errors.Is(err, es.ErrElasticNotFound) {
51+
return model.Agent{}, ErrNotFound
52+
} else {
53+
return model.Agent{}, err
54+
}
55+
}
56+
err = json.Unmarshal(data.Source, &agent)
57+
if err != nil {
58+
return model.Agent{}, err
59+
}
60+
61+
agent.Id = agentID
62+
agent.SeqNo = data.SeqNo
63+
agent.Version = data.Version
64+
65+
return agent, err
66+
}
67+
4268
func FindAgent(ctx context.Context, bulker bulk.Bulk, tmpl *dsl.Tmpl, name string, v interface{}, opt ...Option) (model.Agent, error) {
4369
o := newOption(FleetAgents, opt...)
4470
res, err := SearchWithOneParam(ctx, bulker, tmpl, o.indexName, name, v)

internal/pkg/server/fleet_integration_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ func Test_Agent_Auth_errors(t *testing.T) {
897897
res, err := cli.Do(req)
898898
require.NoError(t, err)
899899
res.Body.Close()
900-
require.Equal(t, http.StatusNotFound, res.StatusCode) // NOTE this is a 404 and not a 400
900+
require.Equal(t, http.StatusForbidden, res.StatusCode)
901901
})
902902
t.Run("wrong agent ID", func(t *testing.T) {
903903
ctx := testlog.SetLogger(t).WithContext(ctx)

internal/pkg/testing/bulk.go

+5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ func (m *MockBulk) Read(ctx context.Context, index, id string, opts ...bulk.Opt)
4545
return args.Get(0).([]byte), args.Error(1)
4646
}
4747

48+
func (m *MockBulk) ReadRaw(ctx context.Context, index, id string, opts ...bulk.Opt) (*bulk.MgetResponseItem, error) {
49+
args := m.Called(ctx, index, id, opts)
50+
return args.Get(0).(*bulk.MgetResponseItem), args.Error(1)
51+
}
52+
4853
func (m *MockBulk) Delete(ctx context.Context, index, id string, opts ...bulk.Opt) error {
4954
args := m.Called(ctx, index, id, opts)
5055
return args.Error(0)

0 commit comments

Comments
 (0)