Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion controller/codex_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
return
}

channelProxy := ""
if channelID > 0 {
ch, err := model.GetChannelById(channelID, false)
if err != nil {
Expand All @@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
return
}
channelProxy = ch.GetSetting().Proxy
}

session := sessions.Default(c)
Expand All @@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
defer cancel()

tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy)
if err != nil {
common.SysError("failed to exchange codex authorization code: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
Expand Down
5 changes: 2 additions & 3 deletions controller/codex_usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controller

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) {
refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
defer refreshCancel()

res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
if refreshErr == nil {
oauthKey.AccessToken = res.AccessToken
oauthKey.RefreshToken = res.RefreshToken
Expand Down Expand Up @@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) {
}

var payload any
if json.Unmarshal(body, &payload) != nil {
if common.Unmarshal(body, &payload) != nil {
payload = string(body)
}

Expand Down
2 changes: 1 addition & 1 deletion service/codex_credential_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code
refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
if err != nil {
return nil, nil, err
}
Expand Down
37 changes: 33 additions & 4 deletions service/codex_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"net/url"
"strings"
"time"

"github.com/QuantumNous/new-api/common"
)

const (
Expand All @@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct {
}

func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
client := &http.Client{Timeout: defaultHTTPTimeout}
return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "")
}

func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) {
client, err := getCodexOAuthHTTPClient(proxyURL)
if err != nil {
return nil, err
}
return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
}

func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
client := &http.Client{Timeout: defaultHTTPTimeout}
return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "")
}

func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) {
client, err := getCodexOAuthHTTPClient(proxyURL)
if err != nil {
return nil, err
}
return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
}

Expand Down Expand Up @@ -104,7 +120,7 @@ func refreshCodexOAuthToken(
ExpiresIn int `json:"expires_in"`
}

if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
if err := common.DecodeJson(resp.Body, &payload); err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
Expand Down Expand Up @@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode(
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
if err := common.DecodeJson(resp.Body, &payload); err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
Expand All @@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode(
}, nil
}

func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) {
baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL))
if err != nil {
return nil, err
}
if baseClient == nil {
return &http.Client{Timeout: defaultHTTPTimeout}, nil
}
clientCopy := *baseClient
clientCopy.Timeout = defaultHTTPTimeout
return &clientCopy, nil
}

func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
u, err := url.Parse(codexOAuthAuthorizeURL)
if err != nil {
Expand Down