# Prefix the helpers below use to recognise an organization ID.
_ORG_ID_PREFIX = "org-"

# Claim namespace under which the auth payload is nested.
_AUTH_CLAIMS_KEY = "https://api.openai.com/auth"

# Keys probed, in priority order, when looking for an organization ID.
_ORGANIZATION_ID_KEYS = (
    "organization_id",
    "organizationId",
    "default_organization_id",
    "defaultOrganizationId",
    "default_workspace_id",
    "defaultWorkspaceId",
    "workspace_id",
    "workspaceId",
    "id",
)

# Keys probed, in priority order, when looking for a workspace ID.
_WORKSPACE_ID_KEYS = (
    "workspace_id",
    "workspaceId",
    "default_workspace_id",
    "defaultWorkspaceId",
    "id",
)


def _normalize_id_text(value: Any) -> str:
    """Return *value* as a stripped ID string, or ``""`` if it cannot be an ID.

    Only ``str`` and non-bool ``int`` values are accepted. Containers, bools,
    floats and ``None`` yield ``""`` so that a malformed payload such as
    ``{"workspace_id": {...}}`` never turns into a stringified-repr "ID".
    """
    if isinstance(value, bool) or not isinstance(value, (str, int)):
        return ""
    return str(value or "").strip()


def _looks_like_organization_id(value: Any) -> bool:
    """Return True when *value* has the ``org-`` prefix followed by a suffix."""
    text = _normalize_id_text(value)
    # A bare "org-" has no identifier part and is rejected.
    return len(text) > len(_ORG_ID_PREFIX) and text.startswith(_ORG_ID_PREFIX)


def _extract_organization_id_from_items(items: Any) -> str:
    """Return the first organization-looking ``id`` in a list of dicts.

    Non-list input and non-dict entries are ignored; ``""`` means not found.
    """
    if not isinstance(items, list):
        return ""
    for entry in items:
        if not isinstance(entry, dict):
            continue
        candidate = _normalize_id_text(entry.get("id"))
        if _looks_like_organization_id(candidate):
            return candidate
    return ""


def _extract_default_organization_id(*payloads: Any) -> str:
    """Extract an organization ID from the first payload that carries one.

    Payloads are searched in argument order. Within each dict payload the
    lookup order is: the flat keys in ``_ORGANIZATION_ID_KEYS`` (which include
    workspace-named keys, because an organization ID may be stored under
    them), the ``organizations`` list, the nested ``account`` dict, and the
    nested auth-claims dict. Only values that look like organization IDs are
    returned, so an account ID stored under one of these keys is ignored.

    Returns:
        The organization ID, or ``""`` when none of the payloads has one.
    """
    for payload in payloads:
        if not isinstance(payload, dict):
            continue

        for key in _ORGANIZATION_ID_KEYS:
            candidate = _normalize_id_text(payload.get(key))
            if _looks_like_organization_id(candidate):
                return candidate

        candidate = _extract_organization_id_from_items(payload.get("organizations"))
        if candidate:
            return candidate

        # Nested containers are searched with the same rules, in this order.
        for nested_key in ("account", _AUTH_CLAIMS_KEY):
            nested = payload.get(nested_key)
            if isinstance(nested, dict):
                candidate = _extract_default_organization_id(nested)
                if candidate:
                    return candidate
    return ""


def _resolve_workspace_id(
    *payloads: Any,
    account_id: str = "",
    organization_id: str = "",
) -> str:
    """Resolve a workspace ID, skipping candidates that are really the account ID.

    Candidates are taken from each payload in argument order: the flat keys in
    ``_WORKSPACE_ID_KEYS``, then ``workspace["id"]``, then each
    ``workspaces[i]["id"]``, then the same search inside ``account``.

    A candidate equal to ``account_id`` is skipped only when a different,
    non-empty ``organization_id`` is known. In that case the organization ID
    is the fallback if no other candidate is found.

    Args:
        *payloads: Objects to search; non-dict values are ignored.
        account_id: Known account ID, used to detect polluted candidates.
        organization_id: Known organization ID, used as the fallback.

    Returns:
        The first acceptable candidate, else ``organization_id`` when it is
        non-empty and differs from ``account_id``, else ``""``.
    """
    normalized_account_id = str(account_id or "").strip()
    normalized_organization_id = str(organization_id or "").strip()

    def _iter_workspace_candidates(payload: Any):
        """Yield every non-empty workspace ID candidate found in *payload*."""
        if not isinstance(payload, dict):
            return

        for key in _WORKSPACE_ID_KEYS:
            value = _normalize_id_text(payload.get(key))
            if value:
                yield value

        workspace_payload = payload.get("workspace")
        if isinstance(workspace_payload, dict):
            value = _normalize_id_text(workspace_payload.get("id"))
            if value:
                yield value

        workspaces = payload.get("workspaces")
        if isinstance(workspaces, list):
            for entry in workspaces:
                if not isinstance(entry, dict):
                    continue
                value = _normalize_id_text(entry.get("id"))
                if value:
                    yield value

        account_payload = payload.get("account")
        if isinstance(account_payload, dict):
            yield from _iter_workspace_candidates(account_payload)

    # Account-ID candidates are only rejected when there is a distinct
    # organization ID to fall back on.
    can_detect_pollution = bool(
        normalized_organization_id
        and normalized_account_id
        and normalized_organization_id != normalized_account_id
    )

    for payload in payloads:
        for candidate in _iter_workspace_candidates(payload):
            if can_detect_pollution and candidate == normalized_account_id:
                continue
            return candidate

    if normalized_organization_id and normalized_organization_id != normalized_account_id:
        return normalized_organization_id
    return ""
def extract_account_info(self, id_token: str) -> Dict[str, Any]:
    """Extract account, organization and workspace info from an ID token.

    Args:
        id_token: The ID token whose claims are read via
            ``_jwt_claims_no_verify``.
            NOTE(review): the helper name suggests the signature is not
            verified, so the claims are treated as untrusted — confirm.

    Returns:
        A dict with ``email``, ``account_id``, ``organization_id``,
        ``workspace_id`` (each ``""`` when absent) and the raw ``claims``.
    """
    claims = _jwt_claims_no_verify(id_token)
    email = str(claims.get("email") or "").strip()

    # A truthy non-dict value under the auth key would make ``.get`` raise
    # below, so anything that is not a dict is treated as "no auth claims".
    auth_claims = claims.get("https://api.openai.com/auth")
    if not isinstance(auth_claims, dict):
        auth_claims = {}

    account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
    organization_id = _extract_default_organization_id(claims, auth_claims)
    workspace_id = _resolve_workspace_id(
        claims,
        auth_claims,
        account_id=account_id,
        organization_id=organization_id,
    )
    return {
        "email": email,
        "account_id": account_id,
        "organization_id": organization_id,
        "workspace_id": workspace_id,
        "claims": claims,
    }
bool: result.error_message = "处理 OAuth 回调失败" return False - result.account_id = token_info.get("account_id", "") - result.access_token = token_info.get("access_token", "") - result.refresh_token = token_info.get("refresh_token", "") - result.id_token = token_info.get("id_token", "") + self._apply_oauth_token_info(result, token_info) result.password = self.password or "" result.source = "login" if self._is_existing_account else "register" result.device_id = result.device_id or str(self.device_id or "") @@ -1621,10 +1615,7 @@ def _complete_token_exchange_outlook(self, result: RegistrationResult) -> bool: result.error_message = "处理 OAuth 回调失败" return False - result.account_id = str(token_info.get("account_id") or result.account_id or "").strip() - result.access_token = str(token_info.get("access_token") or result.access_token or "").strip() - result.refresh_token = str(token_info.get("refresh_token") or result.refresh_token or "").strip() - result.id_token = str(token_info.get("id_token") or result.id_token or "").strip() + self._apply_oauth_token_info(result, token_info) result.password = self.password or "" result.source = "login" if self._is_existing_account else "register" result.device_id = result.device_id or str(self.device_id or "") @@ -1740,10 +1731,7 @@ def _capture_native_core_tokens(self, result: RegistrationResult) -> bool: if callback_url and (not callback_has_error): token_info = self._handle_oauth_callback(callback_url) if token_info: - result.account_id = str(token_info.get("account_id") or result.account_id or "").strip() - result.access_token = str(token_info.get("access_token") or result.access_token or "").strip() - result.refresh_token = str(token_info.get("refresh_token") or result.refresh_token or "").strip() - result.id_token = str(token_info.get("id_token") or result.id_token or "").strip() + self._apply_oauth_token_info(result, token_info) self._log( "原生入口 token 抓取结果: " f"account_id={'有' if bool(result.account_id) else '无'}, " @@ -2614,6 
def _apply_oauth_token_info(self, result: "RegistrationResult", token_info: Dict[str, Any]) -> None:
    """Apply the token/account/workspace fields of an OAuth callback to *result*.

    For each of ``account_id``, ``workspace_id``, ``access_token``,
    ``refresh_token`` and ``id_token``, the value from *token_info* is used
    when it is truthy. Otherwise the value already on *result* is kept.
    Every field ends up as a stripped string (``""`` when neither source has
    a value).

    Args:
        result: Object whose five attributes are updated in place.
        token_info: Mapping produced by the OAuth callback handler. A
            non-dict value (e.g. ``None``) is treated as an empty mapping,
            so existing values on *result* are only normalized.
    """
    info = token_info if isinstance(token_info, dict) else {}
    for field_name in (
        "account_id",
        "workspace_id",
        "access_token",
        "refresh_token",
        "id_token",
    ):
        merged = info.get(field_name) or getattr(result, field_name) or ""
        setattr(result, field_name, str(merged).strip())