diff --git a/middleware.go b/middleware.go index 858749f..2fa54cc 100644 --- a/middleware.go +++ b/middleware.go @@ -27,7 +27,7 @@ type ( func WithDevel(isDevel bool) zenrpc.MiddlewareFunc { return func(h zenrpc.InvokeFunc) zenrpc.InvokeFunc { return func(ctx context.Context, method string, params json.RawMessage) zenrpc.Response { - ctx = context.WithValue(ctx, isDevelCtx, isDevel) + ctx = NewIsDevelContext(ctx, isDevel) return h(ctx, method, params) } } @@ -70,6 +70,11 @@ func UserAgentFromContext(ctx context.Context) string { return r } +// NewIsDevelContext creates new context with isDevel flag. +func NewIsDevelContext(ctx context.Context, isDevel bool) context.Context { + return context.WithValue(ctx, isDevelCtx, isDevel) +} + func IsDevelFromContext(ctx context.Context) bool { if isDevel, ok := ctx.Value(isDevelCtx).(bool); ok { return isDevel diff --git a/sql_logger.go b/sql_logger.go index cd9ea41..d77932e 100644 --- a/sql_logger.go +++ b/sql_logger.go @@ -24,7 +24,7 @@ const ( type AllowDebugFunc func(*http.Request) bool -func debugIDFromContext(ctx context.Context) uint64 { +func DebugIDFromContext(ctx context.Context) uint64 { if ctx == nil { return emptyDebugID } @@ -36,6 +36,11 @@ func debugIDFromContext(ctx context.Context) uint64 { return emptyDebugID } +// NewDebugIDContext creates new context with debug ID. +func NewDebugIDContext(ctx context.Context, debugID uint64) context.Context { + return context.WithValue(ctx, debugIDCtx, debugID) +} + // NewSqlGroupContext creates new context with SQL Group for debug SQL logging. func NewSqlGroupContext(ctx context.Context, group string) context.Context { groups, _ := ctx.Value(sqlGroupCtx).(string) @@ -46,6 +51,12 @@ func NewSqlGroupContext(ctx context.Context, group string) context.Context { return context.WithValue(ctx, sqlGroupCtx, groups) } +// SqlGroupFromContext returns sql group from context. +func SqlGroupFromContext(ctx context.Context) string { + r, _ := ctx.Value(sqlGroupCtx).(string) + return r +} + // WithTiming adds timings in JSON-RPC 2.0 Response via `extensions` field (not in spec). Middleware is active // when `isDevel=true` or AllowDebugFunc returns `true`. // `DurationLocal` – total method execution time in ms. @@ -91,7 +102,7 @@ func WithTiming(isDevel bool, allowDebugFunc AllowDebugFunc) zenrpc.MiddlewareFu // `SQL` field is set then `isDevel=true` or AllowDebugFunc(allowDebugFunc, allowSqlDebugFunc) returns `true`. func WithSQLLogger(db *pg.DB, isDevel bool, allowDebugFunc, allowSqlDebugFunc AllowDebugFunc) zenrpc.MiddlewareFunc { // init sql logger - ql := newSqlQueryLogger() + ql := NewSqlQueryLogger() db.AddQueryHook(ql) return func(h zenrpc.InvokeFunc) zenrpc.InvokeFunc { @@ -111,7 +122,7 @@ func WithSQLLogger(db *pg.DB, isDevel bool, allowDebugFunc, allowSqlDebugFunc Al } debugID := ql.NextID() - ctx = context.WithValue(ctx, debugIDCtx, debugID) + ctx = NewDebugIDContext(ctx, debugID) ql.Push(debugID) resp = h(ctx, method, params) @@ -159,7 +170,7 @@ func (d Duration) MarshalJSON() (b []byte, err error) { return []byte(fmt.Sprintf(`"%s"`, d.String())), nil } -func newSqlQueryLogger() *sqlQueryLogger { +func NewSqlQueryLogger() *sqlQueryLogger { return &sqlQueryLogger{ data: make(map[uint64][]sqlQuery), dataMu: &sync.Mutex{}, @@ -171,7 +182,7 @@ func (ql sqlQueryLogger) BeforeQuery(ctx context.Context, event *pg.QueryEvent) event.Stash = make(map[interface{}]interface{}) } - if debugIDFromContext(ctx) != emptyDebugID { + if DebugIDFromContext(ctx) != emptyDebugID { event.Stash[eventStartedAt] = time.Now() } @@ -179,7 +190,7 @@ func (ql sqlQueryLogger) BeforeQuery(ctx context.Context, event *pg.QueryEvent) } func (ql sqlQueryLogger) AfterQuery(ctx context.Context, event *pg.QueryEvent) error { - debugID := debugIDFromContext(ctx) + debugID := DebugIDFromContext(ctx) if debugID == emptyDebugID { return nil } @@ -200,8 +211,7 @@ func (ql sqlQueryLogger) AfterQuery(ctx context.Context, event *pg.QueryEvent) e } } - sqlGroup, _ := ctx.Value(sqlGroupCtx).(string) - sq.Group = strings.Trim(sqlGroup, ">") + sq.Group = strings.Trim(SqlGroupFromContext(ctx), ">") ql.Store(debugID, sq)