Skip to content

Commit 71afc27

Browse files
committed
feat: implement refresh token
1 parent 8f5b4ca commit 71afc27

File tree

8 files changed

+176
-7
lines changed

8 files changed

+176
-7
lines changed

docs/api.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,22 @@ Logs in a user, returns **JWT access_token** in JSON and sets a **refresh_token*
7878
**Errors**:
7979

8080
- `401 Unauthorized` → invalid credentials
81+
82+
### 3. Refresh Token
83+
84+
`POST /auth/refresh`
85+
86+
Uses the **refresh token cookie** to issue a new access token.
87+
88+
**Response** `200 OK`:
89+
90+
```json
91+
{
92+
"access_token": "<NEW_JWT_TOKEN>",
93+
"expires_in": 3600
94+
}
95+
```
96+
97+
**Errors**:
98+
99+
- `401 Unauthorized` → invalid_request/invalid_grant

internal/db/token_repository.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77

8+
"github.com/jackc/pgx/v5"
89
"github.com/jackc/pgx/v5/pgconn"
910
"github.com/jackc/pgx/v5/pgxpool"
1011
"github.com/raphico/go-device-telemetry-api/internal/domain/token"
@@ -54,3 +55,71 @@ func (r *TokenRepository) Create(ctx context.Context, t *token.Token) error {
5455

5556
return nil
5657
}
58+
59+
func (r *TokenRepository) FindValidTokenByHash(ctx context.Context, hash []byte, scope string) (*token.Token, error) {
60+
t := &token.Token{}
61+
query := `
62+
SELECT id, token_hash, user_id, scope, revoked, expires_at, last_used_at, created_at
63+
FROM tokens
64+
WHERE token_hash = $1
65+
AND scope = $2
66+
AND revoked = false
67+
AND expires_at > now()
68+
`
69+
70+
err := r.db.QueryRow(ctx, query, hash, scope).Scan(
71+
&t.ID,
72+
&t.Hash,
73+
&t.UserID,
74+
&t.Scope,
75+
&t.Revoked,
76+
&t.ExpiresAt,
77+
&t.LastUsedAt,
78+
&t.CreatedAt,
79+
)
80+
81+
if err != nil {
82+
if errors.Is(err, pgx.ErrNoRows) {
83+
return nil, token.ErrTokenNotFound
84+
}
85+
86+
return nil, fmt.Errorf("failed to find token: %w", err)
87+
}
88+
89+
return t, nil
90+
}
91+
92+
func (r *TokenRepository) Revoke(ctx context.Context, scope string, hash []byte) error {
93+
query := `
94+
UPDATE tokens
95+
SET revoked = true
96+
WHERE token_hash = $1
97+
AND scope = $2
98+
AND revoked = false
99+
`
100+
101+
tag, err := r.db.Exec(ctx, query, hash, scope)
102+
if err != nil {
103+
return fmt.Errorf("failed to revoke token: %w", err)
104+
}
105+
106+
if tag.RowsAffected() == 0 {
107+
return token.ErrTokenNotFound
108+
}
109+
110+
return nil
111+
}
112+
113+
func (r *TokenRepository) UpdateLastUsed(ctx context.Context, hash []byte) error {
114+
query := `UPDATE tokens SET last_used_at = now() WHERE token_hash = $1 AND revoked = false`
115+
tag, err := r.db.Exec(ctx, query, hash)
116+
if err != nil {
117+
return fmt.Errorf("failed to revoke token: %w", err)
118+
}
119+
120+
if tag.RowsAffected() == 0 {
121+
return token.ErrTokenNotFound
122+
}
123+
124+
return nil
125+
}

internal/domain/token/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import "errors"
55
var (
66
ErrTokenGenerationFailed = errors.New("failed to generate token")
77
ErrUserNotFound = errors.New("user not found")
8+
ErrTokenNotFound = errors.New("token not found")
89
ErrTokenAlreadyExists = errors.New("token already exists")
910
ErrInvalidToken = errors.New("invalid token")
1011
ErrExpiredToken = errors.New("token expired")
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package token
22

3-
import "context"
3+
import (
4+
"context"
5+
)
46

57
type Repository interface {
68
Create(ctx context.Context, t *Token) error
9+
FindValidTokenByHash(ctx context.Context, hash []byte, scope string) (*Token, error)
10+
Revoke(ctx context.Context, scope string, hash []byte) error
11+
UpdateLastUsed(ctx context.Context, hash []byte) error
712
}

internal/domain/token/service.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,33 @@ func (s *Service) CreateRefreshToken(ctx context.Context, userId user.UserID) (*
6464

6565
return token, nil
6666
}
67+
68+
func (s *Service) RotateTokens(ctx context.Context, refreshTok string) (string, *Token, error) {
69+
hash := HashPlaintext(refreshTok)
70+
71+
tokenRecord, err := s.repo.FindValidTokenByHash(ctx, hash, "auth")
72+
if err != nil {
73+
fmt.Println(err)
74+
return "", nil, err
75+
}
76+
77+
if err := s.repo.UpdateLastUsed(ctx, hash); err != nil {
78+
return "", nil, err
79+
}
80+
81+
if err := s.repo.Revoke(ctx, "auth", tokenRecord.Hash); err != nil {
82+
return "", nil, err
83+
}
84+
85+
accessToken, err := s.GenerateAccessToken(tokenRecord.UserID)
86+
if err != nil {
87+
return "", nil, err
88+
}
89+
90+
token, err := s.CreateRefreshToken(ctx, tokenRecord.UserID)
91+
if err != nil {
92+
return "", nil, err
93+
}
94+
95+
return accessToken, token, nil
96+
}

internal/domain/token/token.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ type Token struct {
2121
Scope string
2222
Revoked bool
2323
ExpiresAt time.Time
24-
LastUsedAt time.Time
24+
LastUsedAt *time.Time
25+
CreatedAt time.Time
2526
}
2627

2728
func NewToken(userId user.UserID, ttl time.Duration, scope string) (*Token, error) {
@@ -32,7 +33,7 @@ func NewToken(userId user.UserID, ttl time.Duration, scope string) (*Token, erro
3233

3334
plaintext := base64.RawURLEncoding.EncodeToString(b)
3435

35-
hash := sha256.Sum256([]byte(plaintext))
36+
hash := HashPlaintext(plaintext)
3637

3738
return &Token{
3839
UserID: userId,
@@ -42,3 +43,8 @@ func NewToken(userId user.UserID, ttl time.Duration, scope string) (*Token, erro
4243
Scope: scope,
4344
}, nil
4445
}
46+
47+
func HashPlaintext(plaintext string) []byte {
48+
hash := sha256.Sum256([]byte(plaintext))
49+
return hash[:]
50+
}

internal/transport/http/router.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func NewRouter(log *logger.Logger, userHandler *UserHandler) http.Handler {
3030
r.Route("/auth", func(r chi.Router) {
3131
r.Post("/register", userHandler.RegisterUser)
3232
r.Post("/login", userHandler.LoginUser)
33+
r.Post("/refresh", userHandler.RefreshAccessToken)
3334
})
3435
})
3536

internal/transport/http/user_handler.go

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ type loginUserRequest struct {
9393
Password string `json:"password"`
9494
}
9595

96-
type loginUserResponse struct {
96+
type tokenResponse struct {
9797
AccessToken string `json:"access_token"`
9898
ExpiresIn int `json:"expires_in"`
9999
}
@@ -113,13 +113,13 @@ func (h *UserHandler) LoginUser(w http.ResponseWriter, r *http.Request) {
113113

114114
accessToken, err := h.tokenService.GenerateAccessToken(user.ID)
115115
if err != nil {
116-
WriteJSONError(w, http.StatusInternalServerError, "server_error", "failed to generate access token")
116+
WriteJSONError(w, http.StatusInternalServerError, "internal_error", "failed to generate access token")
117117
return
118118
}
119119

120120
refreshToken, err := h.tokenService.CreateRefreshToken(r.Context(), user.ID)
121121
if err != nil {
122-
WriteJSONError(w, http.StatusInternalServerError, "server_error", "failed to create refresh token")
122+
WriteJSONError(w, http.StatusInternalServerError, "internal_error", "failed to create refresh token")
123123
return
124124
}
125125

@@ -138,7 +138,45 @@ func (h *UserHandler) LoginUser(w http.ResponseWriter, r *http.Request) {
138138
}
139139
http.SetCookie(w, cookie)
140140

141-
resp := &loginUserResponse{
141+
resp := tokenResponse{
142+
AccessToken: accessToken,
143+
ExpiresIn: int(h.cfg.AccessTokenTTL.Seconds()),
144+
}
145+
146+
WriteJSON(w, http.StatusOK, resp)
147+
}
148+
149+
func (h *UserHandler) RefreshAccessToken(w http.ResponseWriter, r *http.Request) {
150+
cookie, err := r.Cookie("refresh_token")
151+
if err != nil {
152+
WriteJSONError(w, http.StatusUnauthorized, "invalid_request", "refresh token missing")
153+
return
154+
}
155+
156+
refreshTok := cookie.Value
157+
158+
accessToken, refreshToken, err := h.tokenService.RotateTokens(r.Context(), refreshTok)
159+
if err != nil {
160+
WriteJSONError(w, http.StatusUnauthorized, "invalid_grant", "invalid or expired refresh token")
161+
return
162+
}
163+
164+
cookie = &http.Cookie{
165+
Name: "refresh_token",
166+
Value: refreshToken.Plaintext,
167+
Path: "/",
168+
HttpOnly: true,
169+
Expires: refreshToken.ExpiresAt,
170+
SameSite: http.SameSiteLaxMode,
171+
}
172+
if h.cfg.Env == "production" {
173+
cookie.Secure = true
174+
} else {
175+
cookie.Secure = false
176+
}
177+
http.SetCookie(w, cookie)
178+
179+
resp := tokenResponse{
142180
AccessToken: accessToken,
143181
ExpiresIn: int(h.cfg.AccessTokenTTL.Seconds()),
144182
}

0 commit comments

Comments
 (0)