Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions internal/oauth/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oauth

import (
"context"
"database/sql"
"time"

Expand All @@ -12,9 +13,15 @@ type User struct {
DID string
}

// SessionStore defines the session operations needed by the middleware.
type SessionStore interface {
GetSessionByID(ctx context.Context, id string) (*OAuthSession, error)
DeleteSession(ctx context.Context, id string) error
}

// SessionMiddleware creates middleware that reads the session cookie
// and adds the user to the context if the session is valid
func SessionMiddleware(storage *Storage) echo.MiddlewareFunc {
func SessionMiddleware(storage SessionStore) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Try to get session cookie
Expand All @@ -38,8 +45,10 @@ func SessionMiddleware(storage *Storage) echo.MiddlewareFunc {

// Check if session is expired
if session.ExpiresAt.Before(time.Now()) {
// Expired session - continue without user
// TODO: Consider cleaning up expired session here
// Clean up expired session from database
if err := storage.DeleteSession(c.Request().Context(), cookie.Value); err != nil {
c.Logger().Errorf("Failed to delete expired session: %v", err)
}
return next(c)
}

Expand Down
12 changes: 10 additions & 2 deletions internal/oauth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,25 @@ func TestSessionMiddleware(t *testing.T) {
assert.Equal(t, "did:plc:test123", capturedUser.DID)
})

t.Run("expired session - sets nil user in context", func(t *testing.T) {
t.Run("expired session - sets nil user in context and deletes session", func(t *testing.T) {
// Create an expired session
e := echo.New()
setupReq := httptest.NewRequest(http.MethodGet, "/", nil)
setupRec := httptest.NewRecorder()
setupCtx := e.NewContext(setupReq, setupRec)

session := OAuthSession{
ID: "expired-session",
ID: "expired-session-cleanup",
DID: "did:plc:expired",
ExpiresAt: time.Now().Add(-1 * time.Hour),
}
err := storage.CreateSession(setupCtx.Request().Context(), session)
require.NoError(t, err)

// Verify session exists before middleware call
_, err = storage.GetSessionByID(setupCtx.Request().Context(), session.ID)
require.NoError(t, err, "Session should exist before middleware call")

// Make request with expired session cookie
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{
Expand All @@ -128,6 +132,10 @@ func TestSessionMiddleware(t *testing.T) {
err = handler(c)
require.NoError(t, err)
assert.Nil(t, capturedUser)

// Verify session was deleted from database
_, err = storage.GetSessionByID(setupCtx.Request().Context(), session.ID)
assert.Error(t, err, "Expired session should be deleted from database")
})
}

Expand Down
137 changes: 137 additions & 0 deletions internal/oauth/middleware_unit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package oauth

import (
"context"
"database/sql"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type stubSessionStore struct {
sessions map[string]*OAuthSession
deleteErr error
deleteCalls []string
}

func (s *stubSessionStore) GetSessionByID(ctx context.Context, id string) (*OAuthSession, error) {
session, ok := s.sessions[id]
if !ok {
return nil, sql.ErrNoRows
}
return session, nil
}

func (s *stubSessionStore) DeleteSession(ctx context.Context, id string) error {
s.deleteCalls = append(s.deleteCalls, id)
if s.deleteErr != nil {
return s.deleteErr
}
delete(s.sessions, id)
return nil
}

func TestSessionMiddlewareExpiredSessionDeletes(t *testing.T) {
store := &stubSessionStore{
sessions: map[string]*OAuthSession{
"expired-session": {
ID: "expired-session",
DID: "did:plc:expired",
ExpiresAt: time.Now().Add(-1 * time.Minute),
},
},
}

e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "session", Value: "expired-session"})
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

var capturedUser *User
nextCalled := false
handler := SessionMiddleware(store)(func(c echo.Context) error {
nextCalled = true
capturedUser = GetUser(c)
return c.String(http.StatusOK, "ok")
})

err := handler(c)
require.NoError(t, err)
assert.True(t, nextCalled)
assert.Nil(t, capturedUser)
assert.Equal(t, []string{"expired-session"}, store.deleteCalls)
_, exists := store.sessions["expired-session"]
assert.False(t, exists)
}

func TestSessionMiddlewareDeleteErrorDoesNotBlock(t *testing.T) {
store := &stubSessionStore{
sessions: map[string]*OAuthSession{
"expired-session": {
ID: "expired-session",
DID: "did:plc:expired",
ExpiresAt: time.Now().Add(-1 * time.Minute),
},
},
deleteErr: errors.New("delete failed"),
}

e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "session", Value: "expired-session"})
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

var capturedUser *User
nextCalled := false
handler := SessionMiddleware(store)(func(c echo.Context) error {
nextCalled = true
capturedUser = GetUser(c)
return c.String(http.StatusOK, "ok")
})

err := handler(c)
require.NoError(t, err)
assert.True(t, nextCalled)
assert.Nil(t, capturedUser)
assert.Equal(t, []string{"expired-session"}, store.deleteCalls)
_, exists := store.sessions["expired-session"]
assert.True(t, exists)
}

func TestSessionMiddlewareValidSessionDoesNotDelete(t *testing.T) {
store := &stubSessionStore{
sessions: map[string]*OAuthSession{
"valid-session": {
ID: "valid-session",
DID: "did:plc:valid",
ExpiresAt: time.Now().Add(1 * time.Hour),
},
},
}

e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "session", Value: "valid-session"})
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

var capturedUser *User
handler := SessionMiddleware(store)(func(c echo.Context) error {
capturedUser = GetUser(c)
return c.String(http.StatusOK, "ok")
})

err := handler(c)
require.NoError(t, err)
require.NotNil(t, capturedUser)
assert.Equal(t, "did:plc:valid", capturedUser.DID)
assert.Empty(t, store.deleteCalls)
}