diff --git a/middleware_test.go b/middleware_test.go index c2cb15a..746909b 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -67,7 +67,6 @@ func TestMiddlewarePlain(t *testing.T) { } func TestMiddlewareContext(t *testing.T) { - t.Skip() expect := "3" r := NewRouter() _ = r.Add(&Sample{}, contextMiddleware(0), contextMiddleware(1), contextMiddleware(2)) @@ -88,7 +87,6 @@ func TestMiddlewareContext(t *testing.T) { } func TestMiddlewareMixed(t *testing.T) { - t.Skip() expect := "6" r := NewRouter() diff --git a/routes.go b/routes.go index 1c08b5b..dadddd2 100644 --- a/routes.go +++ b/routes.go @@ -279,61 +279,89 @@ func splitRoutes(routeStr string) (*route, error) { return nil, ErrRouteStringFormat } +type middlewareTyp int + +const ( + plainMiddleware middlewareTyp = iota + ctxMiddleware +) + +type middleware struct { + typ middlewareTyp + value interface{} +} + +func (m *middleware) ToHandler(ctx *Context) func(http.Handler) http.Handler { + if m.typ == plainMiddleware { + return m.value.(func(http.Handler) http.Handler) + } + fn := m.value.(func(*Context) error) + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := fn(ctx) + if err != nil { + return + } + h.ServeHTTP(w, r) + }) + } +} + // add registers controller ctrl, using activeRoute. If middlewares are provided, utron uses // alice package to chain middlewares. func (r *Router) add(activeRoute *route, ctrl Controller, middlewares ...interface{}) error { - chain := alice.New() // alice on chains + var m []*middleware if len(middlewares) > 0 { - var m []alice.Constructor for _, v := range middlewares { switch v.(type) { case func(http.Handler) http.Handler: - m = append(m, v.(func(http.Handler) http.Handler)) + m = append(m, &middleware{ + typ: plainMiddleware, + value: v, + }) case func(*Context) error: - - // wrap func(*Context)error to a func(http.Handler)http.Handler - // - //TODO put this into a separate function? - ctxMiddleware := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ctx := NewContext(w, req) - r.prepareContext(ctx) - err := v.(func(*Context) error)(ctx) - if err != nil { - cerr := ctx.Commit() - if cerr != nil { - logThis.Errors(req.URL.Path, " ", cerr.Error()) - } - return - } - h.ServeHTTP(ctx.Response(), ctx.Request()) - }) - } - - m = append(m, ctxMiddleware) + m = append(m, &middleware{ + typ: ctxMiddleware, + value: v, + }) default: return fmt.Errorf("unsupported middleware %v", v) } } - chain = alice.New(m...) } - // register methods if any if len(activeRoute.methods) > 0 { r.HandleFunc(activeRoute.pattern, func(w http.ResponseWriter, req *http.Request) { - chain.ThenFunc(r.wrapController(activeRoute.fn, ctrl)).ServeHTTP(w, req) + ctx := NewContext(w, req) + r.prepareContext(ctx) + chain := chainMiddleware(ctx, m...) + chain.ThenFunc(r.wrapController(ctx, activeRoute.fn, ctrl)).ServeHTTP(w, req) }).Methods(activeRoute.methods...) return nil } - r.HandleFunc(activeRoute.pattern, func(w http.ResponseWriter, req *http.Request) { - chain.ThenFunc(r.wrapController(activeRoute.fn, ctrl)).ServeHTTP(w, req) + ctx := NewContext(w, req) + r.prepareContext(ctx) + chain := chainMiddleware(ctx, m...) + chain.ThenFunc(r.wrapController(ctx, activeRoute.fn, ctrl)).ServeHTTP(w, req) }) return nil } +func chainMiddleware(ctx *Context, wares ...*middleware) alice.Chain { + if len(wares) > 0 { + var m []alice.Constructor + for _, v := range wares { + m = append(m, v.ToHandler(ctx)) + } + return alice.New(m...) + } + return alice.New() + +} + // prepareContext sets view,config and model on the ctx. func (r *Router) prepareContext(ctx *Context) { if r.app != nil { @@ -350,9 +378,7 @@ func (r *Router) prepareContext(ctx *Context) { } // executes the method fn on Controller ctrl, it sets context. -func (r *Router) handleController(w http.ResponseWriter, req *http.Request, fn string, ctrl Controller) { - ctx := NewContext(w, req) - r.prepareContext(ctx) +func (r *Router) handleController(ctx *Context, fn string, ctrl Controller) { ctrl.New(ctx) // execute the method @@ -371,9 +397,9 @@ func (r *Router) handleController(w http.ResponseWriter, req *http.Request, fn s } // wrapController wraps a controller ctrl with method fn, and returns http.HandleFunc -func (r *Router) wrapController(fn string, ctrl Controller) func(http.ResponseWriter, *http.Request) { +func (r *Router) wrapController(ctx *Context, fn string, ctrl Controller) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, req *http.Request) { - r.handleController(w, req, fn, ctrl) + r.handleController(ctx, fn, ctrl) } }