Skip to content

Commit

Permalink
feat: added various training loggers using confit.Draft
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 17, 2025
1 parent 8cb9048 commit 8d055a2
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 117 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
# ruff
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.6.4'
rev: 'v0.9.6'
hooks:
- id: ruff
args: ['--config', 'pyproject.toml', '--fix', '--show-fixes']
Expand Down
18 changes: 14 additions & 4 deletions docs/training/loggers.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ You can configure loggers in `edsnlp.train` via the `logger` parameter of the `t
from edsnlp.training.loggers import CSVLogger
from edsnlp.training import train

logger = CSVLogger()
logger = CSVLogger.draft()
train(..., logger=logger)
# or train(..., logger="csv")
```
Expand All @@ -34,7 +34,7 @@ You can configure loggers in `edsnlp.train` via the `logger` parameter of the `t
from edsnlp.training.loggers import CSVLogger
from edsnlp.training import train

loggers = ["tensorboard", CSVLogger(...)]
loggers = ["tensorboard", CSVLogger.draft(...)]
train(..., logger=loggers)
```

Expand All @@ -48,8 +48,18 @@ You can configure loggers in `edsnlp.train` via the `logger` parameter of the `t
...
```

`edsnlp.train` will provide a default project name and logging dir for loggers that require these parameters, but it is
recommended to set the project name explicitly in the logger configuration.
!!! note "Draft objects"

`edsnlp.train` will provide a default project name and logging dir for loggers that require these parameters, but it is
recommended to set the project name explicitly in the logger configuration. For these loggers, if you don't want to set
the project name yourself, you can either:

- call `CSVLogger.draft(...)` with the normal init parameters, leaving out the `project_name` and/or `logging_dir` parameters,
which will cause a `Draft[CSVLogger]` object to be returned if some required parameters are missing
- or use `"@loggers": csv` in the config file, which will also cause a `Draft[CSVLogger]` object to be returned if some required
parameters are missing

If you do not want a `Draft` object to be returned, call `CSVLogger` directly.

The supported loggers are listed below.

Expand Down
20 changes: 10 additions & 10 deletions edsnlp/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from spacy.vocab import Vocab, create_vocab
from typing_extensions import Literal, Self

from ..core.registries import PIPE_META, FactoryMeta, PartialPipeFactory, registry
from ..core.registries import PIPE_META, DraftPipe, FactoryMeta, registry
from ..utils.collections import (
FrozenDict,
FrozenList,
Expand Down Expand Up @@ -238,9 +238,9 @@ def create_pipe(
**(config if config is not None else {}),
}
).resolve(registry=registry)
if isinstance(pipe, PartialPipeFactory):
if isinstance(pipe, DraftPipe):
if name is None:
name = signature(pipe.func).parameters.get("name").default
name = signature(pipe._func).parameters.get("name").default
if name is None or name == Parameter.empty:
name = factory
pipe = pipe.instantiate(nlp=self, path=(name,))
Expand Down Expand Up @@ -297,8 +297,8 @@ def add_pipe(
raise ValueError(
"Can't pass config or name with an instantiated component",
)
if isinstance(factory, PartialPipeFactory):
name = name or factory.kwargs.get("name")
if isinstance(factory, DraftPipe):
name = name or factory._kwargs.get("name")
factory = factory.instantiate(nlp=self, path=(name,))

pipe = factory
Expand Down Expand Up @@ -585,13 +585,13 @@ def from_config(
def _add_pipes(
self,
pipeline: Sequence[str],
components: Dict[str, PartialPipeFactory],
components: Dict[str, DraftPipe],
exclude: Container[str],
enable: Container[str],
disable: Container[str],
):
try:
components = PartialPipeFactory.instantiate(components, nlp=self)
components = DraftPipe.instantiate(components, nlp=self)
except ConfitValidationError as e:
e = ConfitValidationError(
e.raw_errors,
Expand Down Expand Up @@ -1277,9 +1277,9 @@ def load_from_huggingface(
owner, model_name = repo_id.split("/")
module_name = model_name.replace("-", "_")

assert (
len(repo_id.split("/")) == 2
), "Invalid repo_id format (expected 'owner/repo_name' format)"
assert len(repo_id.split("/")) == 2, (
"Invalid repo_id format (expected 'owner/repo_name' format)"
)
path = None
mtime = None
try:
Expand Down
53 changes: 25 additions & 28 deletions edsnlp/core/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import spacy
from confit import Config, Registry, RegistryCollection, set_default_registry
from confit.errors import ConfitValidationError, patch_errors
from confit.registry import Partial
from confit.registry import Draft
from spacy.pipe_analysis import validate_attrs

import edsnlp
Expand Down Expand Up @@ -71,14 +71,13 @@ class FactoryMeta:
T = TypeVar("T")


class PartialPipeFactory(Partial[T]):
class DraftPipe(Draft[T]):
def __init__(self, func, kwargs):
super().__init__(func, kwargs)
self.func = func
self.instantiated = None
self.error = None

def maybe_nlp(self) -> Union["PartialPipeFactory", Any]:
def maybe_nlp(self) -> Union["DraftPipe", Any]:
"""
If the factory requires an nlp argument and the user has explicitly
provided it (this is unusual, we usually expect the factory to be
Expand All @@ -91,7 +90,7 @@ def maybe_nlp(self) -> Union["PartialPipeFactory", Any]:
"""
from edsnlp.core.pipeline import Pipeline, PipelineProtocol

sig = inspect.signature(self.func)
sig = inspect.signature(self._func)
if (
not (
"nlp" in sig.parameters
Expand All @@ -100,23 +99,23 @@ def maybe_nlp(self) -> Union["PartialPipeFactory", Any]:
or sig.parameters["nlp"].annotation in (Pipeline, PipelineProtocol)
)
)
or "nlp" in self.kwargs
) and not self.search_curried_factory(self.kwargs):
return self.func(**self.kwargs)
or "nlp" in self._kwargs
) and not self.search_nested_drafts(self._kwargs):
return self._func(**self._kwargs)
return self

@classmethod
def search_curried_factory(cls, obj):
if isinstance(obj, PartialPipeFactory):
def search_nested_drafts(cls, obj):
if isinstance(obj, DraftPipe):
return obj
elif isinstance(obj, dict):
for value in obj.values():
result = cls.search_curried_factory(value)
result = cls.search_nested_drafts(value)
if result is not None:
return result
elif isinstance(obj, (tuple, list, set)):
for value in obj:
result = cls.search_curried_factory(value)
result = cls.search_nested_drafts(value)
if result is not None:
return result
return None
Expand All @@ -131,7 +130,7 @@ def instantiate(
passing in the nlp object and name to factories. Since they can be
nested, we need to add them to every factory in the config.
"""
if isinstance(self, PartialPipeFactory):
if isinstance(self, DraftPipe):
if self.error is not None:
raise self.error

Expand All @@ -140,30 +139,30 @@ def instantiate(

name = path[0] if len(path) == 1 else None
parameters = (
inspect.signature(self.func.__init__).parameters
if isinstance(self.func, type)
else inspect.signature(self.func).parameters
inspect.signature(self._func.__init__).parameters
if isinstance(self._func, type)
else inspect.signature(self._func).parameters
)
kwargs = {
key: PartialPipeFactory.instantiate(
key: DraftPipe.instantiate(
self=value,
nlp=nlp,
path=(*path, key),
)
for key, value in self.kwargs.items()
for key, value in self._kwargs.items()
}
try:
if nlp and "nlp" in parameters:
kwargs["nlp"] = nlp
if name and "name" in parameters:
kwargs["name"] = name
self.instantiated = self.func(**kwargs)
self.instantiated = self._func(**kwargs)
except ConfitValidationError as e:
self.error = e
raise ConfitValidationError(
patch_errors(e.raw_errors, path, model=e.model),
model=e.model,
name=self.func.__module__ + "." + self.func.__qualname__,
name=self._func.__module__ + "." + self._func.__qualname__,
) # .with_traceback(None)
# except Exception as e:
# obj.error = e
Expand All @@ -174,7 +173,7 @@ def instantiate(
errors = []
for key, value in self.items():
try:
instantiated[key] = PartialPipeFactory.instantiate(
instantiated[key] = DraftPipe.instantiate(
self=value,
nlp=nlp,
path=(*path, key),
Expand All @@ -190,7 +189,7 @@ def instantiate(
for i, value in enumerate(self):
try:
instantiated.append(
PartialPipeFactory.instantiate(value, nlp, (*path, str(i)))
DraftPipe.instantiate(value, nlp, (*path, str(i)))
)
except ConfitValidationError as e: # pragma: no cover
errors.append(e.raw_errors)
Expand All @@ -200,9 +199,9 @@ def instantiate(
else:
return self

def _raise_partial_error(self):
def _raise_draft_error(self):
raise TypeError(
f"This component PartialFactory({self.func}) has not been instantiated "
f"This {self} component has not been instantiated "
f"yet, likely because it was missing an `nlp` pipeline argument. You "
f"should either:\n"
f"- add it to a pipeline: `pipe = nlp.add_pipe(pipe)`\n"
Expand Down Expand Up @@ -277,9 +276,7 @@ def check_and_return():

if catalogue.check_exists(*registry_path):
func = catalogue._get(registry_path)
return lambda **kwargs: PartialPipeFactory(
func, kwargs=kwargs
).maybe_nlp()
return lambda **kwargs: DraftPipe(func, kwargs=kwargs).maybe_nlp()

# Steps 1 & 2
func = check_and_return()
Expand Down Expand Up @@ -432,7 +429,7 @@ def invoke(validated_fn, kwargs):

@wraps(fn)
def curried_registered_fn(**kwargs):
return PartialPipeFactory(registered_fn, kwargs).maybe_nlp()
return DraftPipe(registered_fn, kwargs).maybe_nlp()

return (
curried_registered_fn
Expand Down
4 changes: 2 additions & 2 deletions edsnlp/pipes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from spacy.tokens import Doc, Span

from edsnlp.core import PipelineProtocol
from edsnlp.core.registries import PartialPipeFactory
from edsnlp.core.registries import DraftPipe
from edsnlp.utils.span_getters import (
SpanGetter, # noqa: F401
SpanGetterArg, # noqa: F401
Expand Down Expand Up @@ -52,7 +52,7 @@ def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
and sig.parameters["nlp"].default is sig.empty
and bound.arguments.get("nlp", sig.empty) is sig.empty
):
return PartialPipeFactory(cls, bound.arguments)
return DraftPipe(cls, bound.arguments)
if nlp is inspect.Signature.empty:
bound.arguments.pop("nlp", None)
except TypeError: # pragma: no cover
Expand Down
Loading

0 comments on commit 8d055a2

Please sign in to comment.