Skip to content

Commit

Permalink
Feature: cache endpoint responses
Browse files Browse the repository at this point in the history
  • Loading branch information
aopoltorzhicky committed Nov 6, 2023
1 parent 5969eec commit a668136
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 6 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ API_PORT=9876
API_RATE_LIMIT=20
API_PROMETHEUS_ENABLED=false
API_REQUEST_TIMEOUT=10
API_WEBSOCKET_ENABLED=true
SENTRY_DSN=<TODO_INSERT_SENTRY_DSN>
CELENIUM_ENV=production
1 change: 1 addition & 0 deletions build/dipdup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ api:
request_timeout: ${API_REQUEST_TIMEOUT:-30}
blob_receiver: dal_api
sentry_dsn: ${SENTRY_DSN}
websocket: ${API_WEBSOCKET_ENABLED:-true}

environment: ${CELENIUM_ENV:-production}

Expand Down
119 changes: 119 additions & 0 deletions cmd/api/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package cache

import (
"net/http"
"sync"

"github.com/labstack/echo/v4"
"github.com/pkg/errors"
)

type Cache struct {
maxEntitiesCount int

m map[string][]byte
queue []string
mx *sync.RWMutex
}

type Config struct {
MaxEntitiesCount int
}

func NewCache(cfg Config) *Cache {
return &Cache{
maxEntitiesCount: cfg.MaxEntitiesCount,
m: make(map[string][]byte),
queue: make([]string, cfg.MaxEntitiesCount),
mx: new(sync.RWMutex),
}
}

func (c *Cache) Get(key string) ([]byte, bool) {
c.mx.RLock()
data, ok := c.m[key]
c.mx.RUnlock()
return data, ok
}

func (c *Cache) Set(key string, data []byte) {
c.mx.Lock()
queueIdx := len(c.m)
c.m[key] = data
if queueIdx == c.maxEntitiesCount {
keyForRemove := c.queue[queueIdx-1]
c.queue = append([]string{key}, c.queue[:queueIdx-1]...)
delete(c.m, keyForRemove)
} else {
c.queue[queueIdx] = key
}
c.mx.Unlock()
}

func (c *Cache) Clear() {
c.mx.Lock()
for key := range c.m {
delete(c.m, key)
}
c.queue = make([]string, c.maxEntitiesCount)
c.mx.Unlock()
}

type CacheMiddleware struct {
cache *Cache
}

func Middleware(cache *Cache) echo.MiddlewareFunc {
mdlwr := CacheMiddleware{
cache: cache,
}
return mdlwr.Handler
}

func (m *CacheMiddleware) isCacheable(req *http.Request) bool {
return req.Method == http.MethodGet
}

func (m *CacheMiddleware) Handler(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !m.isCacheable(c.Request()) {
return next(c)
}
path := c.Request().URL.String()

if data, ok := m.cache.Get(path); ok {
entry := new(CacheEntry)
if err := entry.Decode(data); err != nil {
return err
}
return entry.Replay(c.Response())
}

recorder := NewResponseRecorder(c.Response().Writer)
c.Response().Writer = recorder

if err := next(c); err != nil {
return err
}
return m.cacheResult(path, recorder)
}
}

func (m *CacheMiddleware) cacheResult(key string, r *ResponseRecorder) error {
result := r.Result()
if !m.isStatusCacheable(result) {
return nil
}

data, err := result.Encode()
if err != nil {
return errors.Wrap(err, "unable to read recorded response")
}

m.cache.Set(key, data)
return nil
}

func (m *CacheMiddleware) isStatusCacheable(e *CacheEntry) bool {
return e.StatusCode == http.StatusOK || e.StatusCode == http.StatusNoContent
}
59 changes: 59 additions & 0 deletions cmd/api/cache/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package cache

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
)

func TestCache_SetGet(t *testing.T) {
t.Run("set and get key from cache", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 2})
c.Set("test", []byte{0, 1, 2, 3})

got, ok := c.Get("test")
require.True(t, ok)
require.Equal(t, []byte{0, 1, 2, 3}, got)

_, exists := c.Get("unknown")
require.False(t, exists)
})

t.Run("overflow set queue", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 2})
for i := 0; i < 100; i++ {
c.Set(fmt.Sprintf("%d", i), []byte{byte(i)})
}

require.Len(t, c.queue, 2)
require.Len(t, c.m, 2)

got, ok := c.Get("99")
require.True(t, ok)
require.Equal(t, []byte{99}, got)

got1, ok1 := c.Get("98")
require.True(t, ok1)
require.Equal(t, []byte{98}, got1)

_, exists := c.Get("0")
require.False(t, exists)
})
}

func TestCache_Clear(t *testing.T) {
t.Run("set and get key from cache", func(t *testing.T) {
c := NewCache(Config{MaxEntitiesCount: 100})
for i := 0; i < 100; i++ {
c.Set(fmt.Sprintf("%d", i), []byte{byte(i)})
}
c.Clear()

require.Len(t, c.queue, 100)
for i := 0; i < 100; i++ {
require.EqualValues(t, c.queue[i], "")
}
require.Len(t, c.m, 0)
})
}
101 changes: 101 additions & 0 deletions cmd/api/cache/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package cache

import (
"bytes"
"encoding/gob"
"net/http"
)

type ResponseRecorder struct {
http.ResponseWriter

status int
body bytes.Buffer
headers http.Header
headerCopied bool
}

func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{
ResponseWriter: w,
headers: make(http.Header),
}
}

func (w *ResponseRecorder) Write(b []byte) (int, error) {
w.copyHeaders()
i, err := w.ResponseWriter.Write(b)
if err != nil {
return i, err
}

return w.body.Write(b[:i])
}

func (r *ResponseRecorder) copyHeaders() {
if r.headerCopied {
return
}

r.headerCopied = true
copyHeaders(r.ResponseWriter.Header(), r.headers)
}

func (w *ResponseRecorder) WriteHeader(statusCode int) {
w.copyHeaders()

w.status = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}

func (r *ResponseRecorder) Result() *CacheEntry {
r.copyHeaders()

return &CacheEntry{
Header: r.headers,
StatusCode: r.status,
Body: r.body.Bytes(),
}
}

func copyHeaders(src, dst http.Header) {
for k, v := range src {
for _, v := range v {
dst.Set(k, v)
}
}
}

type CacheEntry struct {
Header http.Header
StatusCode int
Body []byte
}

func (c *CacheEntry) Encode() ([]byte, error) {
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(c); err != nil {
return nil, err
}

return buf.Bytes(), nil
}

func (c *CacheEntry) Decode(b []byte) error {
dec := gob.NewDecoder(bytes.NewReader(b))
return dec.Decode(c)
}

func (c *CacheEntry) Replay(w http.ResponseWriter) error {
copyHeaders(c.Header, w.Header())
if c.StatusCode != 0 {
w.WriteHeader(c.StatusCode)
}

if len(c.Body) == 0 {
return nil
}

_, err := w.Write(c.Body)
return err
}
1 change: 1 addition & 0 deletions cmd/api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ type ApiConfig struct {
RequestTimeout int `validate:"omitempty,min=1" yaml:"request_timeout"`
BlobReceiver string `validate:"required" yaml:"blob_receiver"`
SentryDsn string `validate:"omitempty" yaml:"sentry_dsn"`
Websocket bool `validate:"omitempty" yaml:"websocket"`
}
21 changes: 17 additions & 4 deletions cmd/api/handler/websocket/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type Channel[I, M any] struct {
filters Filterable[M]
repo identifiable[I]

eventHandler func(ctx context.Context, event M) error

g workerpool.Group
}

Expand Down Expand Up @@ -102,10 +104,6 @@ func (channel *Channel[I, M]) waitMessage(ctx context.Context) {
continue
}

if channel.clients.Len() == 0 {
continue
}

if err := channel.processMessage(ctx, msg); err != nil {
log.Err(err).
Str("msg", msg.Channel).
Expand All @@ -128,6 +126,14 @@ func (channel *Channel[I, M]) processMessage(ctx context.Context, msg *pq.Notifi
return errors.Wrap(err, "processing channel message")
}

if err := channel.onEvent(ctx, notification); err != nil {
return err
}

if channel.clients.Len() == 0 {
return nil
}

if err := channel.clients.Range(func(_ uint64, value client) (error, bool) {
if channel.filters.Filter(value, notification) {
value.Notify(notification)
Expand All @@ -150,3 +156,10 @@ func (channel *Channel[I, M]) Close() error {
}
return nil
}

func (channel *Channel[I, M]) onEvent(ctx context.Context, event M) error {
if channel.eventHandler != nil {
return channel.eventHandler(ctx, event)
}
return nil
}
7 changes: 7 additions & 0 deletions cmd/api/handler/websocket/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Manager struct {
head *Channel[storage.Block, *responses.Block]
tx *Channel[storage.Tx, *responses.Tx]
factory storage.ListenerFactory

onBlockReceived func(ctx context.Context, block *responses.Block) error
}

func NewManager(factory storage.ListenerFactory, blockRepo storage.IBlock, txRepo storage.ITx) *Manager {
Expand Down Expand Up @@ -92,6 +94,7 @@ func (manager *Manager) Handle(c echo.Context) error {
}

func (manager *Manager) Start(ctx context.Context) {
manager.head.eventHandler = manager.onBlockReceived
manager.head.Start(ctx, manager.factory)
manager.tx.Start(ctx, manager.factory)
}
Expand Down Expand Up @@ -136,3 +139,7 @@ func (manager *Manager) RemoveClientFromChannel(channel string, client *Client)
log.Error().Str("channel", channel).Msg("unknown channel name")
}
}

func (manager *Manager) SetOnBlockReceived(handler func(ctx context.Context, block *responses.Block) error) {
manager.onBlockReceived = handler
}
Loading

0 comments on commit a668136

Please sign in to comment.