Skip to content

Commit

Permalink
Merge pull request #3 from sas1024/master
Browse files Browse the repository at this point in the history
Make sql query logger public, add context getters and setters
  • Loading branch information
sas1024 authored Sep 6, 2022
2 parents 09333e6 + ab324f0 commit 43d55e7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
7 changes: 6 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type (
func WithDevel(isDevel bool) zenrpc.MiddlewareFunc {
return func(h zenrpc.InvokeFunc) zenrpc.InvokeFunc {
return func(ctx context.Context, method string, params json.RawMessage) zenrpc.Response {
ctx = context.WithValue(ctx, isDevelCtx, isDevel)
ctx = NewIsDevelContext(ctx, isDevel)
return h(ctx, method, params)
}
}
Expand Down Expand Up @@ -70,6 +70,11 @@ func UserAgentFromContext(ctx context.Context) string {
return r
}

// NewIsDevelContext creates new context with isDevel flag.
func NewIsDevelContext(ctx context.Context, isDevel bool) context.Context {
return context.WithValue(ctx, isDevelCtx, isDevel)
}

func IsDevelFromContext(ctx context.Context) bool {
if isDevel, ok := ctx.Value(isDevelCtx).(bool); ok {
return isDevel
Expand Down
26 changes: 18 additions & 8 deletions sql_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (

type AllowDebugFunc func(*http.Request) bool

func debugIDFromContext(ctx context.Context) uint64 {
func DebugIDFromContext(ctx context.Context) uint64 {
if ctx == nil {
return emptyDebugID
}
Expand All @@ -36,6 +36,11 @@ func debugIDFromContext(ctx context.Context) uint64 {
return emptyDebugID
}

// NewDebugIDContext creates new context with debug ID.
func NewDebugIDContext(ctx context.Context, debugID uint64) context.Context {
return context.WithValue(ctx, debugIDCtx, debugID)
}

// NewSqlGroupContext creates new context with SQL Group for debug SQL logging.
func NewSqlGroupContext(ctx context.Context, group string) context.Context {
groups, _ := ctx.Value(sqlGroupCtx).(string)
Expand All @@ -46,6 +51,12 @@ func NewSqlGroupContext(ctx context.Context, group string) context.Context {
return context.WithValue(ctx, sqlGroupCtx, groups)
}

// SqlGroupFromContext returns sql group from context.
func SqlGroupFromContext(ctx context.Context) string {
r, _ := ctx.Value(sqlGroupCtx).(string)
return r
}

// WithTiming adds timings in JSON-RPC 2.0 Response via `extensions` field (not in spec). Middleware is active
// when `isDevel=true` or AllowDebugFunc returns `true`.
// `DurationLocal` – total method execution time in ms.
Expand Down Expand Up @@ -91,7 +102,7 @@ func WithTiming(isDevel bool, allowDebugFunc AllowDebugFunc) zenrpc.MiddlewareFu
// `SQL` field is set then `isDevel=true` or AllowDebugFunc(allowDebugFunc, allowSqlDebugFunc) returns `true`.
func WithSQLLogger(db *pg.DB, isDevel bool, allowDebugFunc, allowSqlDebugFunc AllowDebugFunc) zenrpc.MiddlewareFunc {
// init sql logger
ql := newSqlQueryLogger()
ql := NewSqlQueryLogger()
db.AddQueryHook(ql)

return func(h zenrpc.InvokeFunc) zenrpc.InvokeFunc {
Expand All @@ -111,7 +122,7 @@ func WithSQLLogger(db *pg.DB, isDevel bool, allowDebugFunc, allowSqlDebugFunc Al
}

debugID := ql.NextID()
ctx = context.WithValue(ctx, debugIDCtx, debugID)
ctx = NewDebugIDContext(ctx, debugID)
ql.Push(debugID)

resp = h(ctx, method, params)
Expand Down Expand Up @@ -159,7 +170,7 @@ func (d Duration) MarshalJSON() (b []byte, err error) {
return []byte(fmt.Sprintf(`"%s"`, d.String())), nil
}

func newSqlQueryLogger() *sqlQueryLogger {
func NewSqlQueryLogger() *sqlQueryLogger {
return &sqlQueryLogger{
data: make(map[uint64][]sqlQuery),
dataMu: &sync.Mutex{},
Expand All @@ -171,15 +182,15 @@ func (ql sqlQueryLogger) BeforeQuery(ctx context.Context, event *pg.QueryEvent)
event.Stash = make(map[interface{}]interface{})
}

if debugIDFromContext(ctx) != emptyDebugID {
if DebugIDFromContext(ctx) != emptyDebugID {
event.Stash[eventStartedAt] = time.Now()
}

return ctx, nil
}

func (ql sqlQueryLogger) AfterQuery(ctx context.Context, event *pg.QueryEvent) error {
debugID := debugIDFromContext(ctx)
debugID := DebugIDFromContext(ctx)
if debugID == emptyDebugID {
return nil
}
Expand All @@ -200,8 +211,7 @@ func (ql sqlQueryLogger) AfterQuery(ctx context.Context, event *pg.QueryEvent) e
}
}

sqlGroup, _ := ctx.Value(sqlGroupCtx).(string)
sq.Group = strings.Trim(sqlGroup, ">")
sq.Group = strings.Trim(SqlGroupFromContext(ctx), ">")

ql.Store(debugID, sq)

Expand Down

0 comments on commit 43d55e7

Please sign in to comment.