Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 143 additions & 2 deletions src/core/openai/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,128 @@ def _to_int(v: Any) -> int:
return 0


def _looks_like_organization_id(value: Any) -> bool:
"""判断是否为 OpenAI organization/workspace ID。"""
text = str(value or "").strip()
return text.startswith("org-")


def _extract_organization_id_from_items(items: Any) -> str:
"""从 organizations/workspaces 列表中提取 organization_id。"""
if not isinstance(items, list):
return ""
for item in items:
if not isinstance(item, dict):
continue
organization_id = str(item.get("id") or "").strip()
if _looks_like_organization_id(organization_id):
return organization_id
return ""


def _extract_default_organization_id(*payloads: Any) -> str:
"""
从多个 payload 中提取 organization_id。

某些 OAuth 响应会把默认组织混在 workspace/default_workspace 字段里,
这里统一做一层归一化,避免后续把 account_id 错当 workspace_id。
"""
for payload in payloads:
if not isinstance(payload, dict):
continue
for key in (
"organization_id",
"organizationId",
"default_organization_id",
"defaultOrganizationId",
"default_workspace_id",
"defaultWorkspaceId",
"workspace_id",
"workspaceId",
"id",
):
organization_id = str(payload.get(key) or "").strip()
if _looks_like_organization_id(organization_id):
return organization_id

organization_id = _extract_organization_id_from_items(payload.get("organizations"))
if organization_id:
return organization_id

account_payload = payload.get("account")
if isinstance(account_payload, dict):
organization_id = _extract_default_organization_id(account_payload)
if organization_id:
return organization_id

auth_payload = payload.get("https://api.openai.com/auth")
if isinstance(auth_payload, dict):
organization_id = _extract_default_organization_id(auth_payload)
if organization_id:
return organization_id
return ""


def _resolve_workspace_id(
*payloads: Any,
account_id: str = "",
organization_id: str = "",
) -> str:
"""
从多个 payload 中解析 workspace_id,并修复 workspace_id 被 account_id 污染的情况。
"""
normalized_account_id = str(account_id or "").strip()
normalized_organization_id = str(organization_id or "").strip()

def _iter_workspace_candidates(payload: Any):
if not isinstance(payload, dict):
return
for key in (
"workspace_id",
"workspaceId",
"default_workspace_id",
"defaultWorkspaceId",
"id",
):
value = str(payload.get(key) or "").strip()
if value:
yield value

workspace_payload = payload.get("workspace")
if isinstance(workspace_payload, dict):
workspace_id = str(workspace_payload.get("id") or "").strip()
if workspace_id:
yield workspace_id

workspaces = payload.get("workspaces")
if isinstance(workspaces, list):
for item in workspaces:
if not isinstance(item, dict):
continue
workspace_id = str(item.get("id") or "").strip()
if workspace_id:
yield workspace_id

account_payload = payload.get("account")
if isinstance(account_payload, dict):
yield from _iter_workspace_candidates(account_payload)

for payload in payloads:
for candidate in _iter_workspace_candidates(payload):
if (
normalized_organization_id
and normalized_account_id
and candidate == normalized_account_id
and normalized_organization_id != normalized_account_id
):
continue
return candidate

if normalized_organization_id and normalized_organization_id != normalized_account_id:
return normalized_organization_id
return ""


def _post_form(
url: str,
data: Dict[str, str],
Expand Down Expand Up @@ -290,6 +412,14 @@ def submit_callback_url(
email = str(claims.get("email") or "").strip()
auth_claims = claims.get("https://api.openai.com/auth") or {}
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
organization_id = _extract_default_organization_id(token_resp, claims, auth_claims)
workspace_id = _resolve_workspace_id(
token_resp,
claims,
auth_claims,
account_id=account_id,
organization_id=organization_id,
)

now = int(time.time())
expired_rfc3339 = time.strftime(
Expand All @@ -302,6 +432,8 @@ def submit_callback_url(
"access_token": access_token,
"refresh_token": refresh_token,
"account_id": account_id,
"organization_id": organization_id,
"workspace_id": workspace_id,
"last_refresh": now_rfc3339,
"email": email,
"type": "codex",
Expand Down Expand Up @@ -357,14 +489,23 @@ def handle_callback(
return json.loads(result_json)

def extract_account_info(self, id_token: str) -> Dict[str, Any]:
"""从 ID Token 中提取账户信息"""
"""从 ID Token 中提取账户和组织信息"""
claims = _jwt_claims_no_verify(id_token)
email = str(claims.get("email") or "").strip()
auth_claims = claims.get("https://api.openai.com/auth") or {}
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
organization_id = _extract_default_organization_id(claims, auth_claims)
workspace_id = _resolve_workspace_id(
claims,
auth_claims,
account_id=account_id,
organization_id=organization_id,
)

return {
"email": email,
"account_id": account_id,
"organization_id": organization_id,
"workspace_id": workspace_id,
"claims": claims
}
}
28 changes: 12 additions & 16 deletions src/core/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,10 +1309,7 @@ def _complete_token_exchange(self, result: RegistrationResult, require_login_otp
self._log("处理 OAuth 回调,准备把 token 请出来...")
token_info = self._handle_oauth_callback(callback_url)
if token_info:
result.account_id = token_info.get("account_id", "")
result.access_token = token_info.get("access_token", "") or result.access_token
result.refresh_token = token_info.get("refresh_token", "")
result.id_token = token_info.get("id_token", "")
self._apply_oauth_token_info(result, token_info)
elif captured:
self._log("OAuth 回调失败,但 session/access 已拿到,继续后续流程", "warning")
else:
Expand Down Expand Up @@ -1517,10 +1514,7 @@ def _is_registration_gate_url(url: str) -> bool:
result.error_message = "处理 OAuth 回调失败"
return False

result.account_id = token_info.get("account_id", "")
result.access_token = token_info.get("access_token", "")
result.refresh_token = token_info.get("refresh_token", "")
result.id_token = token_info.get("id_token", "")
self._apply_oauth_token_info(result, token_info)
result.password = self.password or ""
result.source = "login" if self._is_existing_account else "register"
result.device_id = result.device_id or str(self.device_id or "")
Expand Down Expand Up @@ -1621,10 +1615,7 @@ def _complete_token_exchange_outlook(self, result: RegistrationResult) -> bool:
result.error_message = "处理 OAuth 回调失败"
return False

result.account_id = str(token_info.get("account_id") or result.account_id or "").strip()
result.access_token = str(token_info.get("access_token") or result.access_token or "").strip()
result.refresh_token = str(token_info.get("refresh_token") or result.refresh_token or "").strip()
result.id_token = str(token_info.get("id_token") or result.id_token or "").strip()
self._apply_oauth_token_info(result, token_info)
result.password = self.password or ""
result.source = "login" if self._is_existing_account else "register"
result.device_id = result.device_id or str(self.device_id or "")
Expand Down Expand Up @@ -1740,10 +1731,7 @@ def _capture_native_core_tokens(self, result: RegistrationResult) -> bool:
if callback_url and (not callback_has_error):
token_info = self._handle_oauth_callback(callback_url)
if token_info:
result.account_id = str(token_info.get("account_id") or result.account_id or "").strip()
result.access_token = str(token_info.get("access_token") or result.access_token or "").strip()
result.refresh_token = str(token_info.get("refresh_token") or result.refresh_token or "").strip()
result.id_token = str(token_info.get("id_token") or result.id_token or "").strip()
self._apply_oauth_token_info(result, token_info)
self._log(
"原生入口 token 抓取结果: "
f"account_id={'有' if bool(result.account_id) else '无'}, "
Expand Down Expand Up @@ -2614,6 +2602,14 @@ def _handle_oauth_callback(self, callback_url: str) -> Optional[Dict[str, Any]]:
self._log(f"处理 OAuth 回调失败: {e}", "error")
return None

def _apply_oauth_token_info(self, result: RegistrationResult, token_info: Dict[str, Any]) -> None:
"""统一应用 OAuth callback 返回的 token/account/workspace 信息。"""
result.account_id = str(token_info.get("account_id") or result.account_id or "").strip()
result.workspace_id = str(token_info.get("workspace_id") or result.workspace_id or "").strip()
result.access_token = str(token_info.get("access_token") or result.access_token or "").strip()
result.refresh_token = str(token_info.get("refresh_token") or result.refresh_token or "").strip()
result.id_token = str(token_info.get("id_token") or result.id_token or "").strip()

def run(self) -> RegistrationResult:
"""
执行完整的注册流程
Expand Down
108 changes: 108 additions & 0 deletions tests/test_oauth_token_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import base64
import json

from src.config.constants import EmailServiceType
from src.core.openai.oauth import OAuthManager, submit_callback_url
from src.core.register import RegistrationEngine, RegistrationResult
from src.services.base import BaseEmailService


def _jwt(payload):
header = base64.urlsafe_b64encode(json.dumps({"alg": "none", "typ": "JWT"}).encode("utf-8")).decode("ascii").rstrip("=")
body = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")).decode("ascii").rstrip("=")
return f"{header}.{body}.sig"


class DummyEmailService(BaseEmailService):
def __init__(self):
super().__init__(EmailServiceType.TEMPMAIL)

def create_email(self, config=None):
return {"email": "tester@example.com", "service_id": "mailbox-1"}

def get_verification_code(self, email, email_id=None, timeout=120, pattern=r"(?<!\d)(\d{6})(?!\d)", otp_sent_at=None):
return "123456"

def list_emails(self, **kwargs):
return []

def delete_email(self, email_id):
return True

def check_health(self):
return True


def test_extract_account_info_repairs_workspace_id_with_organization_claim():
id_token = _jwt(
{
"email": "tester@example.com",
"https://api.openai.com/auth": {
"chatgpt_account_id": "acct-123",
"workspace_id": "acct-123",
"organizations": [{"id": "org-real-123"}],
},
}
)

info = OAuthManager().extract_account_info(id_token)

assert info["account_id"] == "acct-123"
assert info["organization_id"] == "org-real-123"
assert info["workspace_id"] == "org-real-123"


def test_submit_callback_url_returns_organization_and_normalized_workspace(monkeypatch):
id_token = _jwt(
{
"email": "tester@example.com",
"https://api.openai.com/auth": {
"chatgpt_account_id": "acct-234",
"workspace_id": "acct-234",
"organizations": [{"id": "org-real-234"}],
},
}
)

monkeypatch.setattr(
"src.core.openai.oauth._post_form",
lambda *_args, **_kwargs: {
"access_token": "access-token",
"refresh_token": "refresh-token",
"id_token": id_token,
"expires_in": 3600,
},
)

result = json.loads(
submit_callback_url(
callback_url="http://localhost:1455/auth/callback?code=code-1&state=state-1",
expected_state="state-1",
code_verifier="verifier-1",
)
)

assert result["account_id"] == "acct-234"
assert result["organization_id"] == "org-real-234"
assert result["workspace_id"] == "org-real-234"


def test_registration_engine_applies_workspace_id_from_oauth_callback():
engine = RegistrationEngine(DummyEmailService())
result = RegistrationResult(success=False, workspace_id="", logs=[])

engine._apply_oauth_token_info(
result,
{
"account_id": "acct-345",
"workspace_id": "org-real-345",
"access_token": "access-token",
"refresh_token": "refresh-token",
"id_token": "id-token",
},
)

assert result.account_id == "acct-345"
assert result.workspace_id == "org-real-345"
assert result.access_token == "access-token"
assert result.refresh_token == "refresh-token"