Skip to content

Commit 79d62ef

Browse files
author
Benjamin Feuer
committed
tabpfn resilience
1 parent 75a8895 commit 79d62ef

File tree

2 files changed

+122
-31
lines changed

2 files changed

+122
-31
lines changed

marvis/data/embeddings.py

Lines changed: 41 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,19 @@ def get_tabpfn_embeddings(
123123
tabpfn: Fitted TabPFN model
124124
y_train_sample: Labels for the sampled training set
125125
"""
126+
# Configure a stable cache location for TabPFN weights (helps on HPC clusters)
127+
try:
128+
preferred_cache_root = os.environ.get("MARVIS_CACHE_DIR") or cache_dir
129+
if preferred_cache_root:
130+
# Use XDG cache convention so tabpfn stores under <root>/tabpfn
131+
os.environ.setdefault("XDG_CACHE_HOME", os.path.abspath(preferred_cache_root))
132+
os.environ.setdefault(
133+
"TABPFN_CACHE_DIR", os.path.join(os.environ["XDG_CACHE_HOME"], "tabpfn")
134+
)
135+
except Exception:
136+
# Non-fatal; fall back to defaults
137+
pass
138+
126139
try:
127140
if task_type == "regression":
128141
from tabpfn import TabPFNRegressor
@@ -412,7 +425,34 @@ def get_tabpfn_embeddings(
412425
n_estimators=N_ensemble,
413426
ignore_pretraining_limits=True,
414427
)
415-
tabpfn.fit(X_train_sample, y_train_sample)
428+
# Fit with fallback retry if checkpoint path missing on this node
429+
try:
430+
tabpfn.fit(X_train_sample, y_train_sample)
431+
except FileNotFoundError as e:
432+
err_msg = str(e)
433+
if "tabpfn" in err_msg and (".cache/tabpfn" in err_msg or "tabpfn-v2" in err_msg):
434+
logger.warning(
435+
"TabPFN checkpoint not found at reported path. Switching cache root and retrying download."
436+
)
437+
# Choose a fallback cache root: MARVIS_CACHE_DIR or /tmp/marvis_cache
438+
fallback_root = os.environ.get("MARVIS_CACHE_DIR") or "/tmp/marvis_cache"
439+
try:
440+
os.makedirs(fallback_root, exist_ok=True)
441+
except Exception:
442+
pass
443+
os.environ["XDG_CACHE_HOME"] = os.path.abspath(fallback_root)
444+
os.environ["TABPFN_CACHE_DIR"] = os.path.join(
445+
os.environ["XDG_CACHE_HOME"], "tabpfn"
446+
)
447+
# Recreate model and retry once
448+
tabpfn = TabPFNModel(
449+
device="cuda" if torch.cuda.is_available() else "cpu",
450+
n_estimators=N_ensemble,
451+
ignore_pretraining_limits=True,
452+
)
453+
tabpfn.fit(X_train_sample, y_train_sample)
454+
else:
455+
raise
416456

417457
# Extract embeddings - Process X_train_sample normally, use chunks for test set
418458
train_embeddings_raw = tabpfn.get_embeddings(X_train_sample)

marvis/utils/model_loader.py

Lines changed: 81 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -659,6 +659,9 @@ def __init__(self, model_name: str, device: str = "auto", **kwargs):
659659
self.low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage", True)
660660
self.device_map = kwargs.get("device_map", "auto" if device == "auto" else None)
661661
self._processor = None
662+
# Fallback mode when a repo provides a VLM through a causal LM head with remote code
663+
# (e.g., custom multimodal repos that don't map to AutoModelForVision2Seq configs).
664+
self._causal_vlm_mode: bool = False
662665

663666
def load(self) -> None:
664667
"""Load VLM using transformers."""
@@ -667,7 +670,7 @@ def load(self) -> None:
667670

668671
logger.info(f"Loading {self.model_name} with transformers VLM backend")
669672

670-
# Load processor
673+
# Load processor first (used by both primary and fallback flows)
671674
self._processor = AutoProcessor.from_pretrained(
672675
self.model_name, use_fast=True, trust_remote_code=True
673676
)
@@ -727,6 +730,7 @@ def load(self) -> None:
727730
)
728731

729732
try:
733+
# Primary path: standard Vision2Seq auto-model
730734
self._model = AutoModelForVision2Seq.from_pretrained(
731735
self.model_name, **model_kwargs
732736
)
@@ -740,10 +744,41 @@ def load(self) -> None:
740744
logger.info("Moving model to MPS device...")
741745
self._model = self._model.to(torch.device("mps"))
742746

743-
logger.info(f"Successfully loaded {self.model_name} with transformers VLM")
744-
except Exception as e:
745-
logger.error(f"Failed to load {self.model_name} with transformers VLM: {e}")
746-
raise
747+
logger.info(
748+
f"Successfully loaded {self.model_name} with transformers VLM (AutoModelForVision2Seq)"
749+
)
750+
except Exception as e_primary:
751+
# Fallback path: some repos expose multimodal models via a Causal LM with remote code
752+
logger.warning(
753+
f"Primary VLM load failed for {self.model_name} ({e_primary}). "
754+
f"Attempting fallback with AutoModelForCausalLM + trust_remote_code..."
755+
)
756+
try:
757+
causal_kwargs = dict(model_kwargs)
758+
# In causal path, device_map/torch_dtype still apply
759+
self._model = AutoModelForCausalLM.from_pretrained(
760+
self.model_name, **causal_kwargs
761+
)
762+
self._causal_vlm_mode = True
763+
764+
# Move to MPS if needed
765+
if (
766+
actual_device == "mps"
767+
and hasattr(torch.backends, "mps")
768+
and torch.backends.mps.is_available()
769+
):
770+
logger.info("Moving causal VLM model to MPS device...")
771+
self._model = self._model.to(torch.device("mps"))
772+
773+
logger.info(
774+
f"Successfully loaded {self.model_name} as causal VLM (AutoModelForCausalLM)"
775+
)
776+
except Exception as e_fallback:
777+
logger.error(
778+
f"Failed to load {self.model_name} with both primary and fallback VLM paths. "
779+
f"Primary error: {e_primary}; Fallback error: {e_fallback}"
780+
)
781+
raise
747782

748783
def generate_from_conversation(
749784
self, conversation: List[Dict], config: GenerationConfig
@@ -752,53 +787,69 @@ def generate_from_conversation(
752787
if not self.is_loaded():
753788
raise RuntimeError("Model not loaded")
754789

755-
# Process the conversation
790+
# Extract all images from the conversation (support single or multiple)
791+
images: List[Any] = []
792+
for message in conversation:
793+
content = message.get("content")
794+
if isinstance(content, list):
795+
for item in content:
796+
if isinstance(item, dict) and item.get("type") == "image":
797+
if "image" in item:
798+
images.append(item["image"])
799+
800+
# Processor chat template (common to both paths)
756801
formatted_text = self._processor.apply_chat_template(
757802
conversation, add_generation_prompt=True, tokenize=False
758803
)
759804

760-
# Extract image from conversation
761-
image = None
762-
for message in conversation:
763-
if isinstance(message.get("content"), list):
764-
for content_item in message["content"]:
765-
if content_item.get("type") == "image":
766-
image = content_item.get("image")
767-
break
768-
if image:
769-
break
805+
# Build processor inputs
806+
proc_kwargs = {"text": formatted_text, "return_tensors": "pt"}
807+
if images:
808+
# Pass list of images if multiple; processor should handle both
809+
proc_kwargs["images"] = images if len(images) > 1 else images[0]
770810

771-
# Process inputs
772-
inputs = self._processor(text=formatted_text, images=image, return_tensors="pt")
811+
inputs = self._processor(**proc_kwargs)
773812

774-
# Move to device
813+
# Determine target device
775814
if hasattr(self._model, "device"):
776815
device = self._model.device
777816
else:
778-
# Try to infer device from model parameters
779817
try:
780818
device = next(self._model.parameters()).device
781819
except (StopIteration, AttributeError, RuntimeError):
782820
device = torch.device("cpu")
783821

784822
# Move inputs to the appropriate device
785823
if device.type != "cpu":
786-
inputs = {
787-
k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()
788-
}
824+
inputs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()}
789825

790-
# Generate
826+
# Prepare generation kwargs (use conservative defaults, compatible with varied repos)
791827
gen_kwargs = config.to_transformers_kwargs()
792-
gen_kwargs["pad_token_id"] = self._processor.tokenizer.eos_token_id
828+
829+
# Determine pad token id (prefer processor tokenizer; otherwise try model)
830+
pad_id = None
831+
try:
832+
pad_id = self._processor.tokenizer.eos_token_id
833+
except Exception:
834+
try:
835+
pad_id = getattr(self._model.config, "eos_token_id", None)
836+
except Exception:
837+
pad_id = None
838+
if pad_id is not None:
839+
gen_kwargs["pad_token_id"] = pad_id
793840

794841
with torch.no_grad():
795-
generate_ids = self._model.generate(**inputs, **gen_kwargs)
842+
outputs = self._model.generate(**inputs, **gen_kwargs)
843+
844+
# Decode response; trim input ids if present
845+
try:
846+
input_len = inputs["input_ids"].shape[1]
847+
trimmed = outputs[:, input_len:]
848+
except Exception:
849+
trimmed = outputs
796850

797-
# Decode response
798851
response = self._processor.batch_decode(
799-
generate_ids[:, inputs["input_ids"].shape[1] :],
800-
skip_special_tokens=True,
801-
clean_up_tokenization_spaces=False,
852+
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
802853
)[0]
803854

804855
return response

0 commit comments

Comments
 (0)