Skip to content

Commit c34c57d

Browse files
Authored and committed by Benjamin Feuer: "patch chat behavior"
1 parent 7035f3f commit c34c57d

File tree

4 files changed

+262
-470
lines changed

4 files changed

+262
-470
lines changed

marvis/models/marvis_tsne.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,7 +2029,13 @@ def chat(self, user_input: str, max_history: int = 10) -> str:
20292029
- Completion Rate: {prediction_context.get('completion_rate', 'N/A')}
20302030
""")
20312031

2032-
# Add visualization context if available
2032+
# Add visualization context
2033+
has_viz_image = hasattr(self, '_last_viz_image') and self._last_viz_image is not None
2034+
if has_viz_image:
2035+
context_parts.append("""
2036+
**Visualization:** The most recent t-SNE visualization from the prediction session is attached to this message. You can see the exact spatial layout of training points, test points, and the query point that was classified. Reference specific visual details (cluster positions, neighbor distributions, color patterns) in your response.
2037+
""")
2038+
20332039
if prediction_context.get('visualization_context'):
20342040
viz_context = prediction_context['visualization_context']
20352041
context_parts.append(f"""
@@ -2078,41 +2084,45 @@ def chat(self, user_input: str, max_history: int = 10) -> str:
20782084
# Generate response using the VLM
20792085
self.logger.info("Generating chat response...")
20802086

2087+
# Get the stored visualization image from the last prediction (if any)
2088+
chat_image = getattr(self, '_last_viz_image', None)
2089+
20812090
# Use the VLM wrapper interface for chat
20822091
if hasattr(self.vlm_wrapper, 'generate_response'):
20832092
# Use the standard generate_response interface
20842093
response = self.vlm_wrapper.generate_response(
20852094
text_input=chat_prompt,
2086-
image_input=None, # Text-only conversation
2095+
image_input=chat_image, # Pass last visualization if available
20872096
max_tokens=1000,
20882097
temperature=0.7 # Slightly higher temperature for conversational responses
20892098
)
2090-
elif hasattr(self.vlm_wrapper, 'generate'):
2091-
# Use the direct generate interface with proper parameters
2099+
elif hasattr(self.vlm_wrapper, 'generate_from_conversation'):
2100+
# Use conversation interface if available
20922101
from marvis.utils.model_loader import GenerationConfig
2102+
from marvis.utils.vlm_prompting import create_vlm_conversation
2103+
conversation = create_vlm_conversation(chat_image, chat_prompt) if chat_image else [{"role": "user", "content": chat_prompt}]
20932104
config = GenerationConfig(
2094-
max_new_tokens=512,
2105+
max_new_tokens=1000,
20952106
temperature=0.7,
20962107
do_sample=True,
20972108
top_p=0.9
20982109
)
2099-
response = self.vlm_wrapper.generate(
2100-
inputs=chat_prompt,
2101-
config=config
2110+
response = self.vlm_wrapper.generate_from_conversation(
2111+
conversation,
2112+
config
21022113
)
2103-
elif hasattr(self.vlm_wrapper, 'generate_from_conversation'):
2104-
# Use conversation interface if available
2114+
elif hasattr(self.vlm_wrapper, 'generate'):
2115+
# Use the direct generate interface with proper parameters
21052116
from marvis.utils.model_loader import GenerationConfig
2106-
conversation = [{"role": "user", "content": chat_prompt}]
21072117
config = GenerationConfig(
2108-
max_new_tokens=512,
2118+
max_new_tokens=1000,
21092119
temperature=0.7,
21102120
do_sample=True,
21112121
top_p=0.9
21122122
)
2113-
response = self.vlm_wrapper.generate_from_conversation(
2114-
conversation,
2115-
config
2123+
response = self.vlm_wrapper.generate(
2124+
inputs=chat_prompt,
2125+
config=config
21162126
)
21172127
else:
21182128
# Final fallback - raise informative error

marvis/models/process_one_sample.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,11 @@ def process_one_sample(
566566

567567
# Process image and generate VLM response
568568
image = _process_image(classifier_instance, image)
569+
570+
# Store the last visualization image and prompt on the classifier for chat access
571+
classifier_instance._last_viz_image = image.copy()
572+
classifier_instance._last_viz_prompt = prompt
573+
569574
response = _generate_vlm_response(classifier_instance, image, prompt)
570575
prediction = _parse_prediction(response, classifier_instance, all_classes)
571576

marvis/utils/class_name_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def normalize_predictions_to_target(
603603
Returns:
604604
List of predictions converted to the same type space as y_reference.
605605
"""
606-
if not y_reference:
606+
if y_reference is None or len(y_reference) == 0:
607607
return predictions
608608

609609
tgt_example = y_reference[0]

0 commit comments

Comments (0)