
Commit 791addf

feat(lib): Strengthen error reporting around Google AI (#189)
1 parent: 88066a5

3 files changed: +199 −15 lines

src/lib/dalle.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -65,7 +65,7 @@ impl<Model: self::Model + Send + Sync> super::Generation for OpenAi<Model> {
         })?;
 
         if res.status() != reqwest::StatusCode::OK {
-            tracing::error!(res.status = ?res.status(), "Unexpected status code");
+            tracing::warn!(res.status = ?res.status(), "Unexpected status code");
         }
 
         let res = res.bytes().await.map_err(|cause| {
```
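The same log-and-continue pattern appears in all three files: a non-OK status is downgraded to a warning so the request can go on to surface the API's own error body. A minimal standalone sketch of that pattern (hypothetical helper, not part of this crate; assumes `reqwest`, `tracing`, and `bytes` as dependencies):

```rust
// Hypothetical helper: log a non-OK status as a warning and keep going, so
// the response body (which usually carries the API's own error report) can
// still be read and inspected by the caller.
async fn read_body(res: reqwest::Response) -> Result<bytes::Bytes, &'static str> {
    if res.status() != reqwest::StatusCode::OK {
        tracing::warn!(res.status = ?res.status(), "Unexpected status code");
    }

    // Failing to read the body at all is still a hard error.
    res.bytes().await.map_err(|cause| {
        tracing::error!(?cause, "Failed to read response");
        "Failed to read response"
    })
}
```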

src/lib/google.rs

Lines changed: 197 additions & 13 deletions
```diff
@@ -74,9 +74,11 @@ impl<Model: self::Model + Send + Sync> super::Completion for Google<Model> {
         })?;
 
         if res.status() != reqwest::StatusCode::OK {
-            tracing::error!(res.status = ?res.status(), "Unexpected status code");
+            tracing::warn!(res.status = ?res.status(), "Unexpected status code");
         }
 
+        let http_status = res.status().as_u16();
+
         let res = res.bytes().await.map_err(|cause| {
             tracing::error!(?cause, "Failed to read response");
             "Failed to read response"
```
````diff
@@ -92,17 +94,89 @@ impl<Model: self::Model + Send + Sync> super::Completion for Google<Model> {
             "Failed to deserialize response"
         })?;
 
-        let Some([candidate]) = res.candidates else {
-            let reason = res.prompt_feedback.block_reason;
-            return Err(format!("Prompt blocked! reason: {reason:?}").into());
+        let (candidates, prompt_feedback) = match res {
+            Response::Success {
+                candidates,
+                prompt_feedback,
+            } => (candidates, prompt_feedback),
+            Response::Error { error } => {
+                tracing::error!(?error, "Received error response");
+
+                let Error {
+                    code,
+                    message,
+                    status,
+                } = error;
+
+                if code != http_status as usize {
+                    tracing::warn!(error.code = %code, res.status = %http_status, "Unmatched error code");
+                }
+
+                let cr = code
+                    .try_into()
+                    .ok()
+                    .and_then(|code| reqwest::StatusCode::from_u16(code).ok())
+                    .and_then(|code| code.canonical_reason())
+                    .unwrap_or("Unknown");
+
+                let status = match status {
+                    Either::Lhs(kind) => format!("{:?}", kind),
+                    Either::Rhs(status) => format!("raw: `{status}`"),
+                };
+
+                let reason =
+                    format!("Respond with \"{cr}\" ({code}), status = {status}:\n```{message}```");
+
+                return Err(reason.into());
+            }
+        };
+
+        let Some([candidate]) = candidates else {
+            let reason = prompt_feedback
+                .block_reason
+                .map(|br| format!("{br:?}"))
+                .unwrap_or("<None>".to_owned());
+
+            let ratings = prompt_feedback
+                .safety_ratings
+                .iter()
+                .map(|sr| format!("- **{}** {:?}", sr.category, sr.probability))
+                .fold(String::new(), |c, n| c + &n + "\n");
+
+            let reason =
+                format!("Prompt blocked! reason = {reason}\n### Safety ratings:\n{ratings}");
+
+            return Err(reason.into());
         };
 
         if candidate.finish_reason != FinishReason::Stop {
-            return Err(format!("Unexpected finish reason: {}", candidate.finish_reason).into());
+            tracing::warn!(?candidate.finish_reason, "Unexpected finish reason");
         }
 
+        let Some(content) = &candidate.content else {
+            let Candidate {
+                finish_reason,
+                safety_ratings,
+                content: None,
+            } = candidate
+            else {
+                unreachable!()
+            };
+
+            let ratings = safety_ratings
+                .iter()
+                .map(|sr| format!("- **{}** {:?}", sr.category, sr.probability))
+                .fold(String::new(), |c, n| c + &n + "\n");
+
+            let reason = format!(
+                "Generation is stopped, reason = {finish_reason}.\n### Safety ratings:\n{ratings}"
+            );
+
+            return Err(reason.into());
+        };
+
         let content = {
-            let Content::Model { parts } = &candidate.content else {
+            let Content::Model { parts } = content else {
                 tracing::error!(?candidate.content, "Unexpected content");
                 return Err("Failed to deserialize response".into());
             };
````
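To make the new error path concrete, here is a rough sketch (simplified and standalone; the `error_report` helper is hypothetical) of how an error object such as `{"code": 400, "status": "INVALID_ARGUMENT", "message": "API key not valid."}` would be turned into the report string returned above:

````rust
// Hypothetical standalone mirror of the formatting logic above; the real code
// also handles the `Either<ErrorKind, &str>` form of `status`.
fn error_report(code: u16, status: &str, message: &str) -> String {
    // Recover the canonical HTTP reason phrase from the numeric code,
    // falling back to "Unknown" for anything reqwest does not recognise.
    let cr = reqwest::StatusCode::from_u16(code)
        .ok()
        .and_then(|c| c.canonical_reason())
        .unwrap_or("Unknown");

    format!("Respond with \"{cr}\" ({code}), status = {status}:\n```{message}```")
}

// error_report(400, "INVALID_ARGUMENT", "API key not valid.") produces:
//   Respond with "Bad Request" (400), status = INVALID_ARGUMENT:
//   ```API key not valid.```
````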
```diff
@@ -123,7 +197,7 @@ impl<Model: self::Model + Send + Sync> super::Completion for Google<Model> {
         let contents = messages
             .iter()
             .map(Into::into)
-            .chain([candidate.content])
+            .chain([candidate.content.unwrap()])
             .collect();
 
         let metadata = super::CMetadata {
```
```diff
@@ -240,13 +314,60 @@ impl<'a> From<&'a super::Message> for Content<'a> {
     }
 }
 
+#[derive(Debug, serde::Deserialize)]
+#[serde(untagged)]
+enum Response<'a> {
+    #[serde(rename_all = "camelCase")]
+    Success {
+        // FACT: supported only `1`
+        #[serde(borrow)]
+        candidates: Option<[Candidate<'a>; 1]>,
+        prompt_feedback: PromptFeedback,
+    },
+    #[serde(rename_all = "camelCase")]
+    #[rustfmt::skip]
+    Error {
+        error: Error<'a>,
+    },
+}
+
 #[derive(Debug, serde::Deserialize)]
 #[serde(rename_all = "camelCase")]
-struct Response<'a> {
-    // FACT: supported only `1`
+struct Error<'a> {
+    code: usize,
+    message: alloc::borrow::Cow<'a, str>,
     #[serde(borrow)]
-    candidates: Option<[Candidate<'a>; 1]>,
-    prompt_feedback: PromptFeedback,
+    status: Either<ErrorKind, &'a str>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+enum ErrorKind {
+    InvalidArgument,
+    Internal,
+}
+
+#[derive(Debug, serde::Deserialize)]
+#[serde(untagged)]
+enum Either<L, R> {
+    Lhs(L),
+    Rhs(R),
+}
+
+impl<L, R> Either<L, R> {
+    fn lhs(&self) -> Option<&L> {
+        match self {
+            Self::Lhs(l) => Some(l),
+            Self::Rhs(_) => None,
+        }
+    }
+
+    fn rhs(&self) -> Option<&R> {
+        match self {
+            Self::Lhs(_) => None,
+            Self::Rhs(r) => Some(r),
+        }
+    }
 }
 
 #[derive(Debug, serde::Deserialize)]
```
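Since `Response` is `#[serde(untagged)]`, serde simply tries the variants in order and keeps the first one that fits the JSON. A self-contained sketch of that behaviour (simplified with owned types instead of the borrowed ones above; assumes `serde` with the `derive` feature and `serde_json`):

```rust
use serde::Deserialize;

// Simplified stand-ins for the crate's `Response` / `Error` types.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Reply {
    #[serde(rename_all = "camelCase")]
    Success {
        candidates: Option<Vec<serde_json::Value>>,
        prompt_feedback: serde_json::Value,
    },
    Error { error: ApiError },
}

#[derive(Debug, Deserialize)]
struct ApiError {
    code: u16,
    message: String,
    status: String,
}

fn main() {
    let body = r#"{"error":{"code":400,"message":"API key not valid.","status":"INVALID_ARGUMENT"}}"#;
    // `Success` is tried first and rejected (no `promptFeedback` field),
    // so the payload lands in the `Error` variant.
    let reply: Reply = serde_json::from_str(body).expect("valid JSON");
    println!("{reply:?}");
}
```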
```diff
@@ -296,6 +417,46 @@ enum HarmCategory {
     DangerousContent,
 }
 
+impl core::fmt::Display for HarmCategory {
+    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        match self {
+            Self::Unspecified => {
+                f.write_str("Category is unspecified")
+            },
+            Self::Derogatory => {
+                f.write_str("Negative or harmful comments targeting identity and/or protected attribute")
+            },
+            Self::Toxicity => {
+                f.write_str("Content that is rude, disrespectful, or profane")
+            },
+            Self::Violence => {
+                f.write_str("Describes scenarios depicting violence against an individual or group, or general descriptions of gore")
+            },
+            Self::Sexual => {
+                f.write_str("Contains references to sexual acts or other lewd content")
+            },
+            Self::Medical => {
+                f.write_str("Promotes unchecked medical advice")
+            },
+            Self::Dangerous => {
+                f.write_str("Dangerous content that promotes, facilitates, or encourages harmful acts")
+            },
+            Self::Harassment => {
+                f.write_str("Harassment content")
+            },
+            Self::HateSpeech => {
+                f.write_str("Hate speech and content")
+            },
+            Self::SexuallyExplicit => {
+                f.write_str("Sexually explicit content")
+            },
+            Self::DangerousContent => {
+                f.write_str("Dangerous content")
+            },
+        }
+    }
+}
+
 #[derive(Debug, serde::Deserialize)]
 #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
 enum HarmProbability {
```
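These `Display` texts, together with the `Debug` form of `HarmProbability` below, are what feed the `- **{category}** {probability:?}` bullet lines in the reports built earlier. A minimal illustrative sketch (hypothetical free function; the real code does this inline with an iterator fold):

```rust
// Illustrative: how one safety-rating line of the report is assembled,
// combining Display (category text) with Debug (probability variant name).
fn rating_line(category: &impl core::fmt::Display, probability: &impl core::fmt::Debug) -> String {
    // e.g. "- **Harassment content** Low"
    format!("- **{category}** {probability:?}")
}
```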
```diff
@@ -307,12 +468,35 @@ enum HarmProbability {
     High,
 }
 
+impl core::fmt::Display for HarmProbability {
+    #[rustfmt::skip]
+    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        match self {
+            Self::Unspecified => {
+                f.write_str("Probability is unspecified")
+            },
+            Self::Negligible => {
+                f.write_str("Content has a negligible chance of being unsafe")
+            },
+            Self::Low => {
+                f.write_str("Content has a low chance of being unsafe")
+            },
+            Self::Medium => {
+                f.write_str("Content has a medium chance of being unsafe")
+            },
+            Self::High => {
+                f.write_str("Content has a high chance of being unsafe")
+            },
+        }
+    }
+}
+
 #[derive(Debug, serde::Deserialize)]
 #[serde(rename_all = "camelCase")]
 struct Candidate<'a> {
-    content: Content<'a>,
+    content: Option<Content<'a>>,
     finish_reason: FinishReason,
-    // safety_ratings: Vec<SafetyRating>,
+    safety_ratings: Vec<SafetyRating>,
     // citation_metadata: CitationMetadata,
     // // FACT: doesn't exist in a response
     // token_count: usize,
```

src/lib/openai.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -68,7 +68,7 @@ impl<Model: self::Model + Send + Sync> super::Completion for OpenAi<Model> {
         })?;
 
         if res.status() != reqwest::StatusCode::OK {
-            tracing::error!(res.status = ?res.status(), "Unexpected status code");
+            tracing::warn!(res.status = ?res.status(), "Unexpected status code");
         }
 
         let res = res.bytes().await.map_err(|cause| {
```
